Added support to multi-version memory stores for markings. Also
added some more unit tests which test storing/retrieving markings from the stores.stix2.0
parent
2d89cfb0cf
commit
cbe8d22d0a
|
@ -41,13 +41,35 @@ def _add(store, stix_data=None, allow_custom=True, version=None):
|
|||
else:
|
||||
stix_obj = parse(stix_data, allow_custom, version)
|
||||
|
||||
if stix_obj.id in store._data:
|
||||
obj_family = store._data[stix_obj.id]
|
||||
else:
|
||||
obj_family = _ObjectFamily()
|
||||
store._data[stix_obj.id] = obj_family
|
||||
# Map ID directly to the object, if it is a marking. Otherwise,
|
||||
# map to a family, so we can track multiple versions.
|
||||
if _is_marking(stix_obj):
|
||||
store._data[stix_obj.id] = stix_obj
|
||||
|
||||
obj_family.add(stix_obj)
|
||||
else:
|
||||
if stix_obj.id in store._data:
|
||||
obj_family = store._data[stix_obj.id]
|
||||
else:
|
||||
obj_family = _ObjectFamily()
|
||||
store._data[stix_obj.id] = obj_family
|
||||
|
||||
obj_family.add(stix_obj)
|
||||
|
||||
|
||||
def _is_marking(obj_or_id):
|
||||
"""Determines whether the given object or object ID is/is for a marking
|
||||
definition.
|
||||
|
||||
:param obj_or_id: A STIX object or object ID as a string.
|
||||
:return: True if a marking definition, False otherwise.
|
||||
"""
|
||||
|
||||
if isinstance(obj_or_id, _STIXBase):
|
||||
id_ = obj_or_id.id
|
||||
else:
|
||||
id_ = obj_or_id
|
||||
|
||||
return id_.startswith("marking-definition--")
|
||||
|
||||
|
||||
class _ObjectFamily(object):
|
||||
|
@ -174,8 +196,9 @@ class MemorySink(DataSink):
|
|||
file_path = os.path.abspath(file_path)
|
||||
|
||||
all_objs = itertools.chain.from_iterable(
|
||||
obj_family.all_versions.values()
|
||||
for obj_family in self._data.values()
|
||||
value.all_versions.values() if isinstance(value, _ObjectFamily)
|
||||
else [value]
|
||||
for value in self._data.values()
|
||||
)
|
||||
|
||||
if not os.path.exists(os.path.dirname(file_path)):
|
||||
|
@ -234,9 +257,13 @@ class MemorySource(DataSource):
|
|||
|
||||
"""
|
||||
stix_obj = None
|
||||
object_family = self._data.get(stix_id)
|
||||
if object_family:
|
||||
stix_obj = object_family.latest_version
|
||||
|
||||
if _is_marking(stix_id):
|
||||
stix_obj = self._data.get(stix_id)
|
||||
else:
|
||||
object_family = self._data.get(stix_id)
|
||||
if object_family:
|
||||
stix_obj = object_family.latest_version
|
||||
|
||||
if stix_obj:
|
||||
all_filters = list(
|
||||
|
@ -269,9 +296,17 @@ class MemorySource(DataSource):
|
|||
|
||||
"""
|
||||
results = []
|
||||
object_family = self._data.get(stix_id)
|
||||
stix_objs_to_filter = None
|
||||
if _is_marking(stix_id):
|
||||
stix_obj = self._data.get(stix_id)
|
||||
if stix_obj:
|
||||
stix_objs_to_filter = [stix_obj]
|
||||
else:
|
||||
object_family = self._data.get(stix_id)
|
||||
if object_family:
|
||||
stix_objs_to_filter = object_family.all_versions.values()
|
||||
|
||||
if object_family:
|
||||
if stix_objs_to_filter:
|
||||
all_filters = list(
|
||||
itertools.chain(
|
||||
_composite_filters or [],
|
||||
|
@ -280,8 +315,7 @@ class MemorySource(DataSource):
|
|||
)
|
||||
|
||||
results.extend(
|
||||
apply_common_filters(object_family.all_versions.values(),
|
||||
all_filters)
|
||||
apply_common_filters(stix_objs_to_filter, all_filters)
|
||||
)
|
||||
|
||||
return results
|
||||
|
@ -314,8 +348,9 @@ class MemorySource(DataSource):
|
|||
query.add(_composite_filters)
|
||||
|
||||
all_objs = itertools.chain.from_iterable(
|
||||
obj_family.all_versions.values()
|
||||
for obj_family in self._data.values()
|
||||
value.all_versions.values() if isinstance(value, _ObjectFamily)
|
||||
else [value]
|
||||
for value in self._data.values()
|
||||
)
|
||||
|
||||
# Apply STIX common property filters.
|
||||
|
|
|
@ -2,8 +2,9 @@ import pytest
|
|||
|
||||
from stix2.datastore import CompositeDataSource, make_id
|
||||
from stix2.datastore.filters import Filter
|
||||
from stix2.datastore.memory import MemorySink, MemorySource
|
||||
from stix2.datastore.memory import MemorySink, MemorySource, MemoryStore
|
||||
from stix2.utils import parse_into_datetime
|
||||
from stix2.v20.common import TLP_GREEN
|
||||
|
||||
|
||||
def test_add_remove_composite_datasource():
|
||||
|
@ -84,3 +85,65 @@ def test_composite_datasource_operations(stix_objs1, stix_objs2):
|
|||
# nothing returns the same as cds1.query(query1) (the associated query is query2)
|
||||
results = cds1.query([])
|
||||
assert len(results) == 4
|
||||
|
||||
|
||||
def test_source_markings():
|
||||
msrc = MemorySource(TLP_GREEN)
|
||||
|
||||
assert msrc.get(TLP_GREEN.id) == TLP_GREEN
|
||||
assert msrc.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||
assert msrc.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||
|
||||
|
||||
def test_sink_markings():
|
||||
# just make sure there is no crash
|
||||
msink = MemorySink(TLP_GREEN)
|
||||
msink.add(TLP_GREEN)
|
||||
|
||||
|
||||
def test_store_markings():
|
||||
mstore = MemoryStore(TLP_GREEN)
|
||||
|
||||
assert mstore.get(TLP_GREEN.id) == TLP_GREEN
|
||||
assert mstore.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||
assert mstore.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||
|
||||
|
||||
def test_source_mixed(indicator):
|
||||
msrc = MemorySource([TLP_GREEN, indicator])
|
||||
|
||||
assert msrc.get(TLP_GREEN.id) == TLP_GREEN
|
||||
assert msrc.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||
assert msrc.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||
|
||||
assert msrc.get(indicator.id) == indicator
|
||||
assert msrc.all_versions(indicator.id) == [indicator]
|
||||
assert msrc.query(Filter("id", "=", indicator.id)) == [indicator]
|
||||
|
||||
all_objs = msrc.query()
|
||||
assert TLP_GREEN in all_objs
|
||||
assert indicator in all_objs
|
||||
assert len(all_objs) == 2
|
||||
|
||||
|
||||
def test_sink_mixed(indicator):
|
||||
# just make sure there is no crash
|
||||
msink = MemorySink([TLP_GREEN, indicator])
|
||||
msink.add([TLP_GREEN, indicator])
|
||||
|
||||
|
||||
def test_store_mixed(indicator):
|
||||
mstore = MemoryStore([TLP_GREEN, indicator])
|
||||
|
||||
assert mstore.get(TLP_GREEN.id) == TLP_GREEN
|
||||
assert mstore.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||
assert mstore.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||
|
||||
assert mstore.get(indicator.id) == indicator
|
||||
assert mstore.all_versions(indicator.id) == [indicator]
|
||||
assert mstore.query(Filter("id", "=", indicator.id)) == [indicator]
|
||||
|
||||
all_objs = mstore.query()
|
||||
assert TLP_GREEN in all_objs
|
||||
assert indicator in all_objs
|
||||
assert len(all_objs) == 2
|
||||
|
|
Loading…
Reference in New Issue