Fixed Memory source/sink/store so that it supports multiple versions of objects. Fixed several bugs too.
stix2.0
Michael Chisholm 2018-10-15 17:57:57 -04:00
parent 5a0e102959
commit d9f6a213c1
4 changed files with 128 additions and 78 deletions

View File

@ -12,16 +12,17 @@ Note:
""" """
import itertools
import json import json
import os import os
from stix2.base import _STIXBase from stix2.base import _STIXBase
from stix2.core import Bundle, parse from stix2.core import Bundle, parse
from stix2.datastore import DataSink, DataSource, DataStoreMixin from stix2.datastore import DataSink, DataSource, DataStoreMixin
from stix2.datastore.filters import Filter, FilterSet, apply_common_filters from stix2.datastore.filters import FilterSet, apply_common_filters
def _add(store, stix_data=None, version=None): def _add(store, stix_data=None, allow_custom=True, version=None):
"""Add STIX objects to MemoryStore/Sink. """Add STIX objects to MemoryStore/Sink.
Adds STIX objects to an in-memory dictionary for fast lookup. Adds STIX objects to an in-memory dictionary for fast lookup.
@ -33,27 +34,55 @@ def _add(store, stix_data=None, version=None):
None, use latest version. None, use latest version.
""" """
if isinstance(stix_data, _STIXBase): if isinstance(stix_data, list):
# adding a python STIX object
store._data[stix_data["id"]] = stix_data
elif isinstance(stix_data, dict):
if stix_data["type"] == "bundle":
# adding a json bundle - so just grab STIX objects
for stix_obj in stix_data.get("objects", []):
_add(store, stix_obj, version=version)
else:
# adding a json STIX object
store._data[stix_data["id"]] = stix_data
elif isinstance(stix_data, list):
# STIX objects are in a list- recurse on each object # STIX objects are in a list- recurse on each object
for stix_obj in stix_data: for stix_obj in stix_data:
_add(store, stix_obj, version=version) _add(store, stix_obj, allow_custom, version)
elif stix_data["type"] == "bundle":
# adding a json bundle - so just grab STIX objects
for stix_obj in stix_data.get("objects", []):
_add(store, stix_obj, allow_custom, version)
else: else:
raise TypeError("stix_data expected to be a python-stix2 object (or list of), JSON formatted STIX (or list of)," # Adding a single non-bundle object
" or a JSON formatted STIX bundle. stix_data was of type: " + str(type(stix_data))) if isinstance(stix_data, _STIXBase):
stix_obj = stix_data
else:
stix_obj = parse(stix_data, allow_custom, version)
if stix_obj.id in store._data:
obj_family = store._data[stix_obj.id]
else:
obj_family = _ObjectFamily()
store._data[stix_obj.id] = obj_family
obj_family.add(stix_obj)
class _ObjectFamily(object):
    """
    An internal implementation detail of memory sources/sinks/stores.
    Represents a "family" of STIX objects: all objects with a particular
    ID.  (I.e. all versions.)  The latest version is also tracked so that it
    can be obtained quickly.

    Attributes:
        all_versions (dict): maps the ``modified`` value of each added
            object to that object.
        latest_version: the added object with the greatest ``modified``
            value, or None if nothing has been added yet.
    """

    def __init__(self):
        self.all_versions = {}
        self.latest_version = None

    def add(self, obj):
        """Add one version of the object to this family.

        Args:
            obj: an object exposing a ``modified`` attribute which is
                hashable and orderable against the other versions.  An
                object whose ``modified`` equals that of an existing
                version replaces that version.
        """
        self.all_versions[obj.modified] = obj

        # ">=" rather than ">": when an existing version is replaced by an
        # object with the same timestamp, latest_version must follow it,
        # otherwise it would keep referencing an object that is no longer
        # present in all_versions.
        if self.latest_version is None or \
                obj.modified >= self.latest_version.modified:
            self.latest_version = obj

    def __str__(self):
        # An empty family has no latest version; report None instead of
        # failing on the attribute access.
        latest = self.latest_version
        return "<<{}; latest={}>>".format(
            self.all_versions,
            None if latest is None else latest.modified,
        )

    def __repr__(self):
        return str(self)
class MemoryStore(DataStoreMixin): class MemoryStore(DataStoreMixin):
@ -83,7 +112,7 @@ class MemoryStore(DataStoreMixin):
self._data = {} self._data = {}
if stix_data: if stix_data:
_add(self, stix_data, version=version) _add(self, stix_data, allow_custom, version=version)
super(MemoryStore, self).__init__( super(MemoryStore, self).__init__(
source=MemorySource(stix_data=self._data, allow_custom=allow_custom, version=version, _store=True), source=MemorySource(stix_data=self._data, allow_custom=allow_custom, version=version, _store=True),
@ -138,25 +167,31 @@ class MemorySink(DataSink):
""" """
def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False): def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False):
super(MemorySink, self).__init__() super(MemorySink, self).__init__()
self._data = {}
self.allow_custom = allow_custom self.allow_custom = allow_custom
if _store: if _store:
self._data = stix_data self._data = stix_data
elif stix_data: else:
_add(self, stix_data, version=version) self._data = {}
if stix_data:
_add(self, stix_data, allow_custom, version=version)
def add(self, stix_data, version=None): def add(self, stix_data, version=None):
_add(self, stix_data, version=version) _add(self, stix_data, self.allow_custom, version)
add.__doc__ = _add.__doc__ add.__doc__ = _add.__doc__
def save_to_file(self, file_path): def save_to_file(self, file_path):
file_path = os.path.abspath(file_path) file_path = os.path.abspath(file_path)
all_objs = itertools.chain.from_iterable(
obj_family.all_versions.values()
for obj_family in self._data.values()
)
if not os.path.exists(os.path.dirname(file_path)): if not os.path.exists(os.path.dirname(file_path)):
os.makedirs(os.path.dirname(file_path)) os.makedirs(os.path.dirname(file_path))
with open(file_path, "w") as f: with open(file_path, "w") as f:
f.write(str(Bundle(list(self._data.values()), allow_custom=self.allow_custom))) f.write(str(Bundle(list(all_objs), allow_custom=self.allow_custom)))
save_to_file.__doc__ = MemoryStore.save_to_file.__doc__ save_to_file.__doc__ = MemoryStore.save_to_file.__doc__
@ -184,13 +219,14 @@ class MemorySource(DataSource):
""" """
def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False): def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False):
super(MemorySource, self).__init__() super(MemorySource, self).__init__()
self._data = {}
self.allow_custom = allow_custom self.allow_custom = allow_custom
if _store: if _store:
self._data = stix_data self._data = stix_data
elif stix_data: else:
_add(self, stix_data, version=version) self._data = {}
if stix_data:
_add(self, stix_data, allow_custom, version=version)
def get(self, stix_id, _composite_filters=None): def get(self, stix_id, _composite_filters=None):
"""Retrieve STIX object from in-memory dict via STIX ID. """Retrieve STIX object from in-memory dict via STIX ID.
@ -207,26 +243,22 @@ class MemorySource(DataSource):
is returned in the same form as it as added is returned in the same form as it as added
""" """
if _composite_filters is None: stix_obj = None
# if get call is only based on 'id', no need to search, just retrieve from dict object_family = self._data.get(stix_id)
try: if object_family:
stix_obj = self._data[stix_id] stix_obj = object_family.latest_version
except KeyError:
stix_obj = None
return stix_obj
# if there are filters from the composite level, process full query if stix_obj:
query = [Filter("id", "=", stix_id)] all_filters = list(
itertools.chain(
_composite_filters or [],
self.filters
)
)
all_data = self.query(query=query, _composite_filters=_composite_filters) stix_obj = next(apply_common_filters([stix_obj], all_filters), None)
if all_data: return stix_obj
# reduce to most recent version
stix_obj = sorted(all_data, key=lambda k: k['modified'])[0]
return stix_obj
else:
return None
def all_versions(self, stix_id, _composite_filters=None): def all_versions(self, stix_id, _composite_filters=None):
"""Retrieve STIX objects from in-memory dict via STIX ID, all versions of it """Retrieve STIX objects from in-memory dict via STIX ID, all versions of it
@ -246,8 +278,23 @@ class MemorySource(DataSource):
is returned in the same form as it as added is returned in the same form as it as added
""" """
results = []
object_family = self._data.get(stix_id)
return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)] if object_family:
all_filters = list(
itertools.chain(
_composite_filters or [],
self.filters
)
)
results.extend(
apply_common_filters(object_family.all_versions.values(),
all_filters)
)
return results
def query(self, query=None, _composite_filters=None): def query(self, query=None, _composite_filters=None):
"""Search and retrieve STIX objects based on the complete query. """Search and retrieve STIX objects based on the complete query.
@ -265,7 +312,7 @@ class MemorySource(DataSource):
(list): list of STIX objects that matches the supplied (list): list of STIX objects that matches the supplied
query. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory query. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory
as they are supplied (either as python dictionary or STIX object), it as they are supplied (either as python dictionary or STIX object), it
is returned in the same form as it as added. is returned in the same form as it was added.
""" """
query = FilterSet(query) query = FilterSet(query)
@ -276,17 +323,23 @@ class MemorySource(DataSource):
if _composite_filters: if _composite_filters:
query.add(_composite_filters) query.add(_composite_filters)
all_objs = itertools.chain.from_iterable(
obj_family.all_versions.values()
for obj_family in self._data.values()
)
# Apply STIX common property filters. # Apply STIX common property filters.
all_data = list(apply_common_filters(self._data.values(), query)) all_data = list(apply_common_filters(all_objs, query))
return all_data return all_data
def load_from_file(self, file_path, version=None): def load_from_file(self, file_path, version=None):
stix_data = json.load(open(os.path.abspath(file_path), "r")) with open(os.path.abspath(file_path), "r") as f:
stix_data = json.load(f)
# Override user version selection if loading a bundle
if stix_data["type"] == "bundle": if stix_data["type"] == "bundle":
for stix_obj in stix_data["objects"]: version = stix_data["spec_version"]
_add(self, stix_data=parse(stix_obj, allow_custom=self.allow_custom, version=stix_data["spec_version"]))
else: _add(self, stix_data, self.allow_custom, version)
_add(self, stix_data=parse(stix_data, allow_custom=self.allow_custom, version=version))
load_from_file.__doc__ = MemoryStore.load_from_file.__doc__ load_from_file.__doc__ = MemoryStore.load_from_file.__doc__

View File

@ -3,6 +3,7 @@ import pytest
from stix2.datastore import CompositeDataSource, make_id from stix2.datastore import CompositeDataSource, make_id
from stix2.datastore.filters import Filter from stix2.datastore.filters import Filter
from stix2.datastore.memory import MemorySink, MemorySource from stix2.datastore.memory import MemorySink, MemorySource
from stix2.utils import parse_into_datetime
def test_add_remove_composite_datasource(): def test_add_remove_composite_datasource():
@ -44,14 +45,14 @@ def test_composite_datasource_operations(stix_objs1, stix_objs2):
indicators = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001") indicators = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001")
# In STIX_OBJS2 changed the 'modified' property to a later time... # In STIX_OBJS2 changed the 'modified' property to a later time...
assert len(indicators) == 2 assert len(indicators) == 3
cds1.add_data_sources([cds2]) cds1.add_data_sources([cds2])
indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001") indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001")
assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001" assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001"
assert indicator["modified"] == "2017-01-31T13:49:53.935Z" assert indicator["modified"] == parse_into_datetime("2017-01-31T13:49:53.935Z")
assert indicator["type"] == "indicator" assert indicator["type"] == "indicator"
query1 = [ query1 = [
@ -68,20 +69,18 @@ def test_composite_datasource_operations(stix_objs1, stix_objs2):
# STIX_OBJS2 has indicator with later time, one with different id, one with # STIX_OBJS2 has indicator with later time, one with different id, one with
# original time in STIX_OBJS1 # original time in STIX_OBJS1
assert len(results) == 3 assert len(results) == 4
indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001") indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001")
assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001" assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001"
assert indicator["modified"] == "2017-01-31T13:49:53.935Z" assert indicator["modified"] == parse_into_datetime("2017-01-31T13:49:53.935Z")
assert indicator["type"] == "indicator" assert indicator["type"] == "indicator"
# There is only one indicator with different ID. Since we use the same data
# when deduplicated, only two indicators (one with different modified).
results = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001") results = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001")
assert len(results) == 2 assert len(results) == 3
# Since we have filters already associated with our CompositeSource providing # Since we have filters already associated with our CompositeSource providing
# nothing returns the same as cds1.query(query1) (the associated query is query2) # nothing returns the same as cds1.query(query1) (the associated query is query2)
results = cds1.query([]) results = cds1.query([])
assert len(results) == 3 assert len(results) == 4

View File

@ -113,7 +113,7 @@ def test_environment_functions():
# Get both versions of the object # Get both versions of the object
resp = env.all_versions(INDICATOR_ID) resp = env.all_versions(INDICATOR_ID)
assert len(resp) == 1 # should be 2, but MemoryStore only keeps 1 version of objects assert len(resp) == 2
# Get just the most recent version of the object # Get just the most recent version of the object
resp = env.get(INDICATOR_ID) resp = env.get(INDICATOR_ID)

View File

@ -11,6 +11,7 @@ from stix2.datastore import make_id
from .constants import (CAMPAIGN_ID, CAMPAIGN_KWARGS, IDENTITY_ID, from .constants import (CAMPAIGN_ID, CAMPAIGN_KWARGS, IDENTITY_ID,
IDENTITY_KWARGS, INDICATOR_ID, INDICATOR_KWARGS, IDENTITY_KWARGS, INDICATOR_ID, INDICATOR_KWARGS,
MALWARE_ID, MALWARE_KWARGS, RELATIONSHIP_IDS) MALWARE_ID, MALWARE_KWARGS, RELATIONSHIP_IDS)
from stix2.utils import parse_into_datetime
IND1 = { IND1 = {
"created": "2017-01-27T13:49:53.935Z", "created": "2017-01-27T13:49:53.935Z",
@ -167,7 +168,7 @@ def test_memory_store_all_versions(mem_store):
type="bundle")) type="bundle"))
resp = mem_store.all_versions("indicator--00000000-0000-4000-8000-000000000001") resp = mem_store.all_versions("indicator--00000000-0000-4000-8000-000000000001")
assert len(resp) == 1 # MemoryStore can only store 1 version of each object assert len(resp) == 3
def test_memory_store_query(mem_store): def test_memory_store_query(mem_store):
@ -179,25 +180,27 @@ def test_memory_store_query(mem_store):
def test_memory_store_query_single_filter(mem_store): def test_memory_store_query_single_filter(mem_store):
query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001') query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001')
resp = mem_store.query(query) resp = mem_store.query(query)
assert len(resp) == 1 assert len(resp) == 2
def test_memory_store_query_empty_query(mem_store): def test_memory_store_query_empty_query(mem_store):
resp = mem_store.query() resp = mem_store.query()
# sort since returned in random order # sort since returned in random order
resp = sorted(resp, key=lambda k: k['id']) resp = sorted(resp, key=lambda k: (k['id'], k['modified']))
assert len(resp) == 2 assert len(resp) == 3
assert resp[0]['id'] == 'indicator--00000000-0000-4000-8000-000000000001' assert resp[0]['id'] == 'indicator--00000000-0000-4000-8000-000000000001'
assert resp[0]['modified'] == '2017-01-27T13:49:53.936Z' assert resp[0]['modified'] == parse_into_datetime('2017-01-27T13:49:53.935Z')
assert resp[1]['id'] == 'indicator--00000000-0000-4000-8000-000000000002' assert resp[1]['id'] == 'indicator--00000000-0000-4000-8000-000000000001'
assert resp[1]['modified'] == '2017-01-27T13:49:53.935Z' assert resp[1]['modified'] == parse_into_datetime('2017-01-27T13:49:53.936Z')
assert resp[2]['id'] == 'indicator--00000000-0000-4000-8000-000000000002'
assert resp[2]['modified'] == parse_into_datetime('2017-01-27T13:49:53.935Z')
def test_memory_store_query_multiple_filters(mem_store): def test_memory_store_query_multiple_filters(mem_store):
mem_store.source.filters.add(Filter('type', '=', 'indicator')) mem_store.source.filters.add(Filter('type', '=', 'indicator'))
query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001') query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001')
resp = mem_store.query(query) resp = mem_store.query(query)
assert len(resp) == 1 assert len(resp) == 2
def test_memory_store_save_load_file(mem_store, fs_mem_store): def test_memory_store_save_load_file(mem_store, fs_mem_store):
@ -218,12 +221,8 @@ def test_memory_store_save_load_file(mem_store, fs_mem_store):
def test_memory_store_add_invalid_object(mem_store): def test_memory_store_add_invalid_object(mem_store):
ind = ('indicator', IND1) # tuple isn't valid ind = ('indicator', IND1) # tuple isn't valid
with pytest.raises(TypeError) as excinfo: with pytest.raises(TypeError):
mem_store.add(ind) mem_store.add(ind)
assert 'stix_data expected to be' in str(excinfo.value)
assert 'a python-stix2 object' in str(excinfo.value)
assert 'JSON formatted STIX' in str(excinfo.value)
assert 'JSON formatted STIX bundle' in str(excinfo.value)
def test_memory_store_object_with_custom_property(mem_store): def test_memory_store_object_with_custom_property(mem_store):
@ -246,10 +245,9 @@ def test_memory_store_object_with_custom_property_in_bundle(mem_store):
allow_custom=True) allow_custom=True)
bundle = Bundle(camp, allow_custom=True) bundle = Bundle(camp, allow_custom=True)
mem_store.add(bundle, True) mem_store.add(bundle)
bundle_r = mem_store.get(bundle.id) camp_r = mem_store.get(camp.id)
camp_r = bundle_r['objects'][0]
assert camp_r.id == camp.id assert camp_r.id == camp.id
assert camp_r.x_empire == camp.x_empire assert camp_r.x_empire == camp.x_empire