diff --git a/stix2/datastore/memory.py b/stix2/datastore/memory.py index 52bf4c8..52da168 100644 --- a/stix2/datastore/memory.py +++ b/stix2/datastore/memory.py @@ -10,7 +10,6 @@ from stix2.base import _STIXBase from stix2.core import parse from stix2.datastore import DataSink, DataSource, DataStoreMixin from stix2.datastore.filters import FilterSet, apply_common_filters -from stix2.utils import is_marking def _add(store, stix_data, allow_custom=True, version=None): @@ -47,12 +46,10 @@ def _add(store, stix_data, allow_custom=True, version=None): else: stix_obj = parse(stix_data, allow_custom, version) - # Map ID directly to the object, if it is a marking. Otherwise, - # map to a family, so we can track multiple versions. - if is_marking(stix_obj): - store._data[stix_obj["id"]] = stix_obj - - else: + # Map ID to a _ObjectFamily if the object is versioned, so we can track + # multiple versions. Otherwise, map directly to the object. All + # versioned objects should have a "modified" property. + if "modified" in stix_obj: if stix_obj["id"] in store._data: obj_family = store._data[stix_obj["id"]] else: @@ -61,6 +58,9 @@ def _add(store, stix_data, allow_custom=True, version=None): obj_family.add(stix_obj) + else: + store._data[stix_obj["id"]] = stix_obj + class _ObjectFamily(object): """ @@ -267,12 +267,12 @@ class MemorySource(DataSource): """ stix_obj = None - if is_marking(stix_id): - stix_obj = self._data.get(stix_id) - else: - object_family = self._data.get(stix_id) - if object_family: - stix_obj = object_family.latest_version + mapped_value = self._data.get(stix_id) + if mapped_value: + if isinstance(mapped_value, _ObjectFamily): + stix_obj = mapped_value.latest_version + else: + stix_obj = mapped_value if stix_obj: all_filters = list( @@ -300,17 +300,13 @@ class MemorySource(DataSource): """ results = [] - stix_objs_to_filter = None - if is_marking(stix_id): - stix_obj = self._data.get(stix_id) - if stix_obj: - stix_objs_to_filter = [stix_obj] - else: - object_family = self._data.get(stix_id) - if object_family: - stix_objs_to_filter = object_family.all_versions.values() + mapped_value = self._data.get(stix_id) + if mapped_value: + if isinstance(mapped_value, _ObjectFamily): + stix_objs_to_filter = mapped_value.all_versions.values() + else: + stix_objs_to_filter = [mapped_value] - if stix_objs_to_filter: all_filters = list( itertools.chain( _composite_filters or [], diff --git a/stix2/test/v20/test_datastore_memory.py b/stix2/test/v20/test_datastore_memory.py index fba96dd..c86f94d 100644 --- a/stix2/test/v20/test_datastore_memory.py +++ b/stix2/test/v20/test_datastore_memory.py @@ -423,3 +423,24 @@ def test_object_family_internal_components(mem_source): assert "latest=2017-01-27 13:49:53.936000+00:00>>" in str_representation assert "latest=2017-01-27 13:49:53.936000+00:00>>" in repr_representation + + +def test_unversioned_objects(mem_store): + marking = { + "type": "marking-definition", + "id": "marking-definition--48e83cde-e902-4404-85b3-6e81f75ccb62", + "created": "1988-01-02T16:44:04.000Z", + "definition_type": "statement", + "definition": { + "statement": "Copyright (C) ACME Corp." + } + } + + mem_store.add(marking) + + obj = mem_store.get(marking["id"]) + assert obj["id"] == marking["id"] + + objs = mem_store.all_versions(marking["id"]) + assert len(objs) == 1 + assert objs[0]["id"] == marking["id"] diff --git a/stix2/test/v21/test_datastore_memory.py b/stix2/test/v21/test_datastore_memory.py index 4f63a06..ad61431 100644 --- a/stix2/test/v21/test_datastore_memory.py +++ b/stix2/test/v21/test_datastore_memory.py @@ -438,3 +438,38 @@ def test_object_family_internal_components(mem_source): assert "latest=2017-01-27 13:49:53.936000+00:00>>" in str_representation assert "latest=2017-01-27 13:49:53.936000+00:00>>" in repr_representation + + +def test_unversioned_objects(mem_store): + marking = { + "type": "marking-definition", + "spec_version": "2.1", + "id": "marking-definition--48e83cde-e902-4404-85b3-6e81f75ccb62", + "created": "1988-01-02T16:44:04.000Z", + "definition_type": "statement", + "definition": { + "statement": "Copyright (C) ACME Corp." + } + } + + file_sco = { + "type": "file", + "id": "file--bbd59c0c-1aa4-44f1-96de-80b8325372c7", + "name": "cats.png" + } + + mem_store.add([marking, file_sco]) + + obj = mem_store.get(marking["id"]) + assert obj["id"] == marking["id"] + + obj = mem_store.get(file_sco["id"]) + assert obj["id"] == file_sco["id"] + + objs = mem_store.all_versions(marking["id"]) + assert len(objs) == 1 + assert objs[0]["id"] == marking["id"] + + objs = mem_store.all_versions(file_sco["id"]) + assert len(objs) == 1 + assert objs[0]["id"] == file_sco["id"]