Added support to multi-version memory stores for markings. Also
added some more unit tests which test storing/retrieving markings from the stores.stix2.0
parent
2d89cfb0cf
commit
cbe8d22d0a
|
@ -41,6 +41,12 @@ def _add(store, stix_data=None, allow_custom=True, version=None):
|
||||||
else:
|
else:
|
||||||
stix_obj = parse(stix_data, allow_custom, version)
|
stix_obj = parse(stix_data, allow_custom, version)
|
||||||
|
|
||||||
|
# Map ID directly to the object, if it is a marking. Otherwise,
|
||||||
|
# map to a family, so we can track multiple versions.
|
||||||
|
if _is_marking(stix_obj):
|
||||||
|
store._data[stix_obj.id] = stix_obj
|
||||||
|
|
||||||
|
else:
|
||||||
if stix_obj.id in store._data:
|
if stix_obj.id in store._data:
|
||||||
obj_family = store._data[stix_obj.id]
|
obj_family = store._data[stix_obj.id]
|
||||||
else:
|
else:
|
||||||
|
@ -50,6 +56,22 @@ def _add(store, stix_data=None, allow_custom=True, version=None):
|
||||||
obj_family.add(stix_obj)
|
obj_family.add(stix_obj)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_marking(obj_or_id):
|
||||||
|
"""Determines whether the given object or object ID is/is for a marking
|
||||||
|
definition.
|
||||||
|
|
||||||
|
:param obj_or_id: A STIX object or object ID as a string.
|
||||||
|
:return: True if a marking definition, False otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(obj_or_id, _STIXBase):
|
||||||
|
id_ = obj_or_id.id
|
||||||
|
else:
|
||||||
|
id_ = obj_or_id
|
||||||
|
|
||||||
|
return id_.startswith("marking-definition--")
|
||||||
|
|
||||||
|
|
||||||
class _ObjectFamily(object):
|
class _ObjectFamily(object):
|
||||||
"""
|
"""
|
||||||
An internal implementation detail of memory sources/sinks/stores.
|
An internal implementation detail of memory sources/sinks/stores.
|
||||||
|
@ -174,8 +196,9 @@ class MemorySink(DataSink):
|
||||||
file_path = os.path.abspath(file_path)
|
file_path = os.path.abspath(file_path)
|
||||||
|
|
||||||
all_objs = itertools.chain.from_iterable(
|
all_objs = itertools.chain.from_iterable(
|
||||||
obj_family.all_versions.values()
|
value.all_versions.values() if isinstance(value, _ObjectFamily)
|
||||||
for obj_family in self._data.values()
|
else [value]
|
||||||
|
for value in self._data.values()
|
||||||
)
|
)
|
||||||
|
|
||||||
if not os.path.exists(os.path.dirname(file_path)):
|
if not os.path.exists(os.path.dirname(file_path)):
|
||||||
|
@ -234,6 +257,10 @@ class MemorySource(DataSource):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
stix_obj = None
|
stix_obj = None
|
||||||
|
|
||||||
|
if _is_marking(stix_id):
|
||||||
|
stix_obj = self._data.get(stix_id)
|
||||||
|
else:
|
||||||
object_family = self._data.get(stix_id)
|
object_family = self._data.get(stix_id)
|
||||||
if object_family:
|
if object_family:
|
||||||
stix_obj = object_family.latest_version
|
stix_obj = object_family.latest_version
|
||||||
|
@ -269,9 +296,17 @@ class MemorySource(DataSource):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
results = []
|
results = []
|
||||||
|
stix_objs_to_filter = None
|
||||||
|
if _is_marking(stix_id):
|
||||||
|
stix_obj = self._data.get(stix_id)
|
||||||
|
if stix_obj:
|
||||||
|
stix_objs_to_filter = [stix_obj]
|
||||||
|
else:
|
||||||
object_family = self._data.get(stix_id)
|
object_family = self._data.get(stix_id)
|
||||||
|
|
||||||
if object_family:
|
if object_family:
|
||||||
|
stix_objs_to_filter = object_family.all_versions.values()
|
||||||
|
|
||||||
|
if stix_objs_to_filter:
|
||||||
all_filters = list(
|
all_filters = list(
|
||||||
itertools.chain(
|
itertools.chain(
|
||||||
_composite_filters or [],
|
_composite_filters or [],
|
||||||
|
@ -280,8 +315,7 @@ class MemorySource(DataSource):
|
||||||
)
|
)
|
||||||
|
|
||||||
results.extend(
|
results.extend(
|
||||||
apply_common_filters(object_family.all_versions.values(),
|
apply_common_filters(stix_objs_to_filter, all_filters)
|
||||||
all_filters)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
@ -314,8 +348,9 @@ class MemorySource(DataSource):
|
||||||
query.add(_composite_filters)
|
query.add(_composite_filters)
|
||||||
|
|
||||||
all_objs = itertools.chain.from_iterable(
|
all_objs = itertools.chain.from_iterable(
|
||||||
obj_family.all_versions.values()
|
value.all_versions.values() if isinstance(value, _ObjectFamily)
|
||||||
for obj_family in self._data.values()
|
else [value]
|
||||||
|
for value in self._data.values()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply STIX common property filters.
|
# Apply STIX common property filters.
|
||||||
|
|
|
@ -2,8 +2,9 @@ import pytest
|
||||||
|
|
||||||
from stix2.datastore import CompositeDataSource, make_id
|
from stix2.datastore import CompositeDataSource, make_id
|
||||||
from stix2.datastore.filters import Filter
|
from stix2.datastore.filters import Filter
|
||||||
from stix2.datastore.memory import MemorySink, MemorySource
|
from stix2.datastore.memory import MemorySink, MemorySource, MemoryStore
|
||||||
from stix2.utils import parse_into_datetime
|
from stix2.utils import parse_into_datetime
|
||||||
|
from stix2.v20.common import TLP_GREEN
|
||||||
|
|
||||||
|
|
||||||
def test_add_remove_composite_datasource():
|
def test_add_remove_composite_datasource():
|
||||||
|
@ -84,3 +85,65 @@ def test_composite_datasource_operations(stix_objs1, stix_objs2):
|
||||||
# nothing returns the same as cds1.query(query1) (the associated query is query2)
|
# nothing returns the same as cds1.query(query1) (the associated query is query2)
|
||||||
results = cds1.query([])
|
results = cds1.query([])
|
||||||
assert len(results) == 4
|
assert len(results) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_source_markings():
|
||||||
|
msrc = MemorySource(TLP_GREEN)
|
||||||
|
|
||||||
|
assert msrc.get(TLP_GREEN.id) == TLP_GREEN
|
||||||
|
assert msrc.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||||
|
assert msrc.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sink_markings():
|
||||||
|
# just make sure there is no crash
|
||||||
|
msink = MemorySink(TLP_GREEN)
|
||||||
|
msink.add(TLP_GREEN)
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_markings():
|
||||||
|
mstore = MemoryStore(TLP_GREEN)
|
||||||
|
|
||||||
|
assert mstore.get(TLP_GREEN.id) == TLP_GREEN
|
||||||
|
assert mstore.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||||
|
assert mstore.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||||
|
|
||||||
|
|
||||||
|
def test_source_mixed(indicator):
|
||||||
|
msrc = MemorySource([TLP_GREEN, indicator])
|
||||||
|
|
||||||
|
assert msrc.get(TLP_GREEN.id) == TLP_GREEN
|
||||||
|
assert msrc.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||||
|
assert msrc.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||||
|
|
||||||
|
assert msrc.get(indicator.id) == indicator
|
||||||
|
assert msrc.all_versions(indicator.id) == [indicator]
|
||||||
|
assert msrc.query(Filter("id", "=", indicator.id)) == [indicator]
|
||||||
|
|
||||||
|
all_objs = msrc.query()
|
||||||
|
assert TLP_GREEN in all_objs
|
||||||
|
assert indicator in all_objs
|
||||||
|
assert len(all_objs) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_sink_mixed(indicator):
|
||||||
|
# just make sure there is no crash
|
||||||
|
msink = MemorySink([TLP_GREEN, indicator])
|
||||||
|
msink.add([TLP_GREEN, indicator])
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_mixed(indicator):
|
||||||
|
mstore = MemoryStore([TLP_GREEN, indicator])
|
||||||
|
|
||||||
|
assert mstore.get(TLP_GREEN.id) == TLP_GREEN
|
||||||
|
assert mstore.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||||
|
assert mstore.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||||
|
|
||||||
|
assert mstore.get(indicator.id) == indicator
|
||||||
|
assert mstore.all_versions(indicator.id) == [indicator]
|
||||||
|
assert mstore.query(Filter("id", "=", indicator.id)) == [indicator]
|
||||||
|
|
||||||
|
all_objs = mstore.query()
|
||||||
|
assert TLP_GREEN in all_objs
|
||||||
|
assert indicator in all_objs
|
||||||
|
assert len(all_objs) == 2
|
||||||
|
|
Loading…
Reference in New Issue