Merge pull request #342 from chisholm/sco_tlo_memorystore

Fix the memory store to support the new top-level 2.1 SCOs.
master
Chris Lenk 2020-02-14 10:16:06 -05:00 committed by GitHub
commit 8aca39a0b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 75 additions and 23 deletions

View File

@ -10,7 +10,6 @@ from stix2.base import _STIXBase
from stix2.core import parse
from stix2.datastore import DataSink, DataSource, DataStoreMixin
from stix2.datastore.filters import FilterSet, apply_common_filters
from stix2.utils import is_marking
def _add(store, stix_data, allow_custom=True, version=None):
@ -47,12 +46,10 @@ def _add(store, stix_data, allow_custom=True, version=None):
else:
stix_obj = parse(stix_data, allow_custom, version)
# Map ID directly to the object, if it is a marking. Otherwise,
# map to a family, so we can track multiple versions.
if is_marking(stix_obj):
store._data[stix_obj["id"]] = stix_obj
else:
# Map ID to a _ObjectFamily if the object is versioned, so we can track
# multiple versions. Otherwise, map directly to the object. All
# versioned objects should have a "modified" property.
if "modified" in stix_obj:
if stix_obj["id"] in store._data:
obj_family = store._data[stix_obj["id"]]
else:
@ -61,6 +58,9 @@ def _add(store, stix_data, allow_custom=True, version=None):
obj_family.add(stix_obj)
else:
store._data[stix_obj["id"]] = stix_obj
class _ObjectFamily(object):
"""
@ -267,12 +267,12 @@ class MemorySource(DataSource):
"""
stix_obj = None
if is_marking(stix_id):
stix_obj = self._data.get(stix_id)
else:
object_family = self._data.get(stix_id)
if object_family:
stix_obj = object_family.latest_version
mapped_value = self._data.get(stix_id)
if mapped_value:
if isinstance(mapped_value, _ObjectFamily):
stix_obj = mapped_value.latest_version
else:
stix_obj = mapped_value
if stix_obj:
all_filters = list(
@ -300,17 +300,13 @@ class MemorySource(DataSource):
"""
results = []
stix_objs_to_filter = None
if is_marking(stix_id):
stix_obj = self._data.get(stix_id)
if stix_obj:
stix_objs_to_filter = [stix_obj]
else:
object_family = self._data.get(stix_id)
if object_family:
stix_objs_to_filter = object_family.all_versions.values()
mapped_value = self._data.get(stix_id)
if mapped_value:
if isinstance(mapped_value, _ObjectFamily):
stix_objs_to_filter = mapped_value.all_versions.values()
else:
stix_objs_to_filter = [mapped_value]
if stix_objs_to_filter:
all_filters = list(
itertools.chain(
_composite_filters or [],

View File

@ -423,3 +423,24 @@ def test_object_family_internal_components(mem_source):
assert "latest=2017-01-27 13:49:53.936000+00:00>>" in str_representation
assert "latest=2017-01-27 13:49:53.936000+00:00>>" in repr_representation
def test_unversioned_objects(mem_store):
marking = {
"type": "marking-definition",
"id": "marking-definition--48e83cde-e902-4404-85b3-6e81f75ccb62",
"created": "1988-01-02T16:44:04.000Z",
"definition_type": "statement",
"definition": {
"statement": "Copyright (C) ACME Corp.",
},
}
mem_store.add(marking)
obj = mem_store.get(marking["id"])
assert obj["id"] == marking["id"]
objs = mem_store.all_versions(marking["id"])
assert len(objs) == 1
assert objs[0]["id"] == marking["id"]

View File

@ -438,3 +438,38 @@ def test_object_family_internal_components(mem_source):
assert "latest=2017-01-27 13:49:53.936000+00:00>>" in str_representation
assert "latest=2017-01-27 13:49:53.936000+00:00>>" in repr_representation
def test_unversioned_objects(mem_store):
marking = {
"type": "marking-definition",
"spec_version": "2.1",
"id": "marking-definition--48e83cde-e902-4404-85b3-6e81f75ccb62",
"created": "1988-01-02T16:44:04.000Z",
"definition_type": "statement",
"definition": {
"statement": "Copyright (C) ACME Corp.",
},
}
file_sco = {
"type": "file",
"id": "file--bbd59c0c-1aa4-44f1-96de-80b8325372c7",
"name": "cats.png",
}
mem_store.add([marking, file_sco])
obj = mem_store.get(marking["id"])
assert obj["id"] == marking["id"]
obj = mem_store.get(file_sco["id"])
assert obj["id"] == file_sco["id"]
objs = mem_store.all_versions(marking["id"])
assert len(objs) == 1
assert objs[0]["id"] == marking["id"]
objs = mem_store.all_versions(file_sco["id"])
assert len(objs) == 1
assert objs[0]["id"] == file_sco["id"]