Added support to multi-version memory stores for markings. Also

added some more unit tests which test storing/retrieving markings
from the stores.
stix2.0
Michael Chisholm 2018-10-17 20:54:53 -04:00
parent 2d89cfb0cf
commit cbe8d22d0a
2 changed files with 116 additions and 18 deletions

View File

@ -41,13 +41,35 @@ def _add(store, stix_data=None, allow_custom=True, version=None):
else:
stix_obj = parse(stix_data, allow_custom, version)
if stix_obj.id in store._data:
obj_family = store._data[stix_obj.id]
else:
obj_family = _ObjectFamily()
store._data[stix_obj.id] = obj_family
# Map ID directly to the object, if it is a marking. Otherwise,
# map to a family, so we can track multiple versions.
if _is_marking(stix_obj):
store._data[stix_obj.id] = stix_obj
obj_family.add(stix_obj)
else:
if stix_obj.id in store._data:
obj_family = store._data[stix_obj.id]
else:
obj_family = _ObjectFamily()
store._data[stix_obj.id] = obj_family
obj_family.add(stix_obj)
def _is_marking(obj_or_id):
"""Determines whether the given object or object ID is/is for a marking
definition.
:param obj_or_id: A STIX object or object ID as a string.
:return: True if a marking definition, False otherwise.
"""
if isinstance(obj_or_id, _STIXBase):
id_ = obj_or_id.id
else:
id_ = obj_or_id
return id_.startswith("marking-definition--")
class _ObjectFamily(object):
@ -174,8 +196,9 @@ class MemorySink(DataSink):
file_path = os.path.abspath(file_path)
all_objs = itertools.chain.from_iterable(
obj_family.all_versions.values()
for obj_family in self._data.values()
value.all_versions.values() if isinstance(value, _ObjectFamily)
else [value]
for value in self._data.values()
)
if not os.path.exists(os.path.dirname(file_path)):
@ -234,9 +257,13 @@ class MemorySource(DataSource):
"""
stix_obj = None
object_family = self._data.get(stix_id)
if object_family:
stix_obj = object_family.latest_version
if _is_marking(stix_id):
stix_obj = self._data.get(stix_id)
else:
object_family = self._data.get(stix_id)
if object_family:
stix_obj = object_family.latest_version
if stix_obj:
all_filters = list(
@ -269,9 +296,17 @@ class MemorySource(DataSource):
"""
results = []
object_family = self._data.get(stix_id)
stix_objs_to_filter = None
if _is_marking(stix_id):
stix_obj = self._data.get(stix_id)
if stix_obj:
stix_objs_to_filter = [stix_obj]
else:
object_family = self._data.get(stix_id)
if object_family:
stix_objs_to_filter = object_family.all_versions.values()
if object_family:
if stix_objs_to_filter:
all_filters = list(
itertools.chain(
_composite_filters or [],
@ -280,8 +315,7 @@ class MemorySource(DataSource):
)
results.extend(
apply_common_filters(object_family.all_versions.values(),
all_filters)
apply_common_filters(stix_objs_to_filter, all_filters)
)
return results
@ -314,8 +348,9 @@ class MemorySource(DataSource):
query.add(_composite_filters)
all_objs = itertools.chain.from_iterable(
obj_family.all_versions.values()
for obj_family in self._data.values()
value.all_versions.values() if isinstance(value, _ObjectFamily)
else [value]
for value in self._data.values()
)
# Apply STIX common property filters.

View File

@ -2,8 +2,9 @@ import pytest
from stix2.datastore import CompositeDataSource, make_id
from stix2.datastore.filters import Filter
from stix2.datastore.memory import MemorySink, MemorySource
from stix2.datastore.memory import MemorySink, MemorySource, MemoryStore
from stix2.utils import parse_into_datetime
from stix2.v20.common import TLP_GREEN
def test_add_remove_composite_datasource():
@ -84,3 +85,65 @@ def test_composite_datasource_operations(stix_objs1, stix_objs2):
# nothing returns the same as cds1.query(query1) (the associated query is query2)
results = cds1.query([])
assert len(results) == 4
def test_source_markings():
msrc = MemorySource(TLP_GREEN)
assert msrc.get(TLP_GREEN.id) == TLP_GREEN
assert msrc.all_versions(TLP_GREEN.id) == [TLP_GREEN]
assert msrc.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
def test_sink_markings():
# just make sure there is no crash
msink = MemorySink(TLP_GREEN)
msink.add(TLP_GREEN)
def test_store_markings():
mstore = MemoryStore(TLP_GREEN)
assert mstore.get(TLP_GREEN.id) == TLP_GREEN
assert mstore.all_versions(TLP_GREEN.id) == [TLP_GREEN]
assert mstore.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
def test_source_mixed(indicator):
msrc = MemorySource([TLP_GREEN, indicator])
assert msrc.get(TLP_GREEN.id) == TLP_GREEN
assert msrc.all_versions(TLP_GREEN.id) == [TLP_GREEN]
assert msrc.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
assert msrc.get(indicator.id) == indicator
assert msrc.all_versions(indicator.id) == [indicator]
assert msrc.query(Filter("id", "=", indicator.id)) == [indicator]
all_objs = msrc.query()
assert TLP_GREEN in all_objs
assert indicator in all_objs
assert len(all_objs) == 2
def test_sink_mixed(indicator):
# just make sure there is no crash
msink = MemorySink([TLP_GREEN, indicator])
msink.add([TLP_GREEN, indicator])
def test_store_mixed(indicator):
mstore = MemoryStore([TLP_GREEN, indicator])
assert mstore.get(TLP_GREEN.id) == TLP_GREEN
assert mstore.all_versions(TLP_GREEN.id) == [TLP_GREEN]
assert mstore.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
assert mstore.get(indicator.id) == indicator
assert mstore.all_versions(indicator.id) == [indicator]
assert mstore.query(Filter("id", "=", indicator.id)) == [indicator]
all_objs = mstore.query()
assert TLP_GREEN in all_objs
assert indicator in all_objs
assert len(all_objs) == 2