diff --git a/stix2/datastore/memory.py b/stix2/datastore/memory.py index c1d202d..98304e0 100644 --- a/stix2/datastore/memory.py +++ b/stix2/datastore/memory.py @@ -12,16 +12,17 @@ Note: """ +import itertools import json import os from stix2.base import _STIXBase from stix2.core import Bundle, parse from stix2.datastore import DataSink, DataSource, DataStoreMixin -from stix2.datastore.filters import Filter, FilterSet, apply_common_filters +from stix2.datastore.filters import FilterSet, apply_common_filters -def _add(store, stix_data=None, version=None): +def _add(store, stix_data=None, allow_custom=True, version=None): """Add STIX objects to MemoryStore/Sink. Adds STIX objects to an in-memory dictionary for fast lookup. @@ -33,27 +34,55 @@ def _add(store, stix_data=None, version=None): None, use latest version. """ - if isinstance(stix_data, _STIXBase): - # adding a python STIX object - store._data[stix_data["id"]] = stix_data - - elif isinstance(stix_data, dict): - if stix_data["type"] == "bundle": - # adding a json bundle - so just grab STIX objects - for stix_obj in stix_data.get("objects", []): - _add(store, stix_obj, version=version) - else: - # adding a json STIX object - store._data[stix_data["id"]] = stix_data - - elif isinstance(stix_data, list): + if isinstance(stix_data, list): # STIX objects are in a list- recurse on each object for stix_obj in stix_data: - _add(store, stix_obj, version=version) + _add(store, stix_obj, allow_custom, version) + + elif stix_data["type"] == "bundle": + # adding a json bundle - so just grab STIX objects + for stix_obj in stix_data.get("objects", []): + _add(store, stix_obj, allow_custom, version) else: - raise TypeError("stix_data expected to be a python-stix2 object (or list of), JSON formatted STIX (or list of)," - " or a JSON formatted STIX bundle. stix_data was of type: " + str(type(stix_data))) + # Adding a single non-bundle object + if isinstance(stix_data, _STIXBase): + stix_obj = stix_data + else: + stix_obj = parse(stix_data, allow_custom, version) + + if stix_obj.id in store._data: + obj_family = store._data[stix_obj.id] + else: + obj_family = _ObjectFamily() + store._data[stix_obj.id] = obj_family + + obj_family.add(stix_obj) + + +class _ObjectFamily(object): + """ + An internal implementation detail of memory sources/sinks/stores. + Represents a "family" of STIX objects: all objects with a particular + ID. (I.e. all versions.) The latest version is also tracked so that it + can be obtained quickly. + """ + def __init__(self): + self.all_versions = {} + self.latest_version = None + + def add(self, obj): + self.all_versions[obj.modified] = obj + if self.latest_version is None or \ + obj.modified > self.latest_version.modified: + self.latest_version = obj + + def __str__(self): + return "<<{}; latest={}>>".format(self.all_versions, + self.latest_version.modified) + + def __repr__(self): + return str(self) class MemoryStore(DataStoreMixin): @@ -83,7 +112,7 @@ class MemoryStore(DataStoreMixin): self._data = {} if stix_data: - _add(self, stix_data, version=version) + _add(self, stix_data, allow_custom, version=version) super(MemoryStore, self).__init__( source=MemorySource(stix_data=self._data, allow_custom=allow_custom, version=version, _store=True), @@ -138,25 +167,31 @@ class MemorySink(DataSink): """ def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False): super(MemorySink, self).__init__() - self._data = {} self.allow_custom = allow_custom if _store: self._data = stix_data - elif stix_data: - _add(self, stix_data, version=version) + else: + self._data = {} + if stix_data: + _add(self, stix_data, allow_custom, version=version) def add(self, stix_data, version=None): - _add(self, stix_data, version=version) + _add(self, stix_data, self.allow_custom, version) add.__doc__ = _add.__doc__ def save_to_file(self, file_path): file_path = os.path.abspath(file_path) + all_objs = itertools.chain.from_iterable( + obj_family.all_versions.values() + for obj_family in self._data.values() + ) + if not os.path.exists(os.path.dirname(file_path)): os.makedirs(os.path.dirname(file_path)) with open(file_path, "w") as f: - f.write(str(Bundle(list(self._data.values()), allow_custom=self.allow_custom))) + f.write(str(Bundle(list(all_objs), allow_custom=self.allow_custom))) save_to_file.__doc__ = MemoryStore.save_to_file.__doc__ @@ -184,13 +219,14 @@ class MemorySource(DataSource): """ def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False): super(MemorySource, self).__init__() - self._data = {} self.allow_custom = allow_custom if _store: self._data = stix_data - elif stix_data: - _add(self, stix_data, version=version) + else: + self._data = {} + if stix_data: + _add(self, stix_data, allow_custom, version=version) def get(self, stix_id, _composite_filters=None): """Retrieve STIX object from in-memory dict via STIX ID. @@ -207,26 +243,22 @@ class MemorySource(DataSource): is returned in the same form as it as added """ - if _composite_filters is None: - # if get call is only based on 'id', no need to search, just retrieve from dict - try: - stix_obj = self._data[stix_id] - except KeyError: - stix_obj = None - return stix_obj + stix_obj = None + object_family = self._data.get(stix_id) + if object_family: + stix_obj = object_family.latest_version - # if there are filters from the composite level, process full query - query = [Filter("id", "=", stix_id)] + if stix_obj: + all_filters = list( + itertools.chain( + _composite_filters or [], + self.filters + ) + ) - all_data = self.query(query=query, _composite_filters=_composite_filters) + stix_obj = next(apply_common_filters([stix_obj], all_filters), None) - if all_data: - # reduce to most recent version - stix_obj = sorted(all_data, key=lambda k: k['modified'])[0] - - return stix_obj - else: - return None + return stix_obj def all_versions(self, stix_id, _composite_filters=None): """Retrieve STIX objects from in-memory dict via STIX ID, all versions of it @@ -246,8 +278,23 @@ class MemorySource(DataSource): is returned in the same form as it as added """ + results = [] + object_family = self._data.get(stix_id) - return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)] + if object_family: + all_filters = list( + itertools.chain( + _composite_filters or [], + self.filters + ) + ) + + results.extend( + apply_common_filters(object_family.all_versions.values(), + all_filters) + ) + + return results def query(self, query=None, _composite_filters=None): """Search and retrieve STIX objects based on the complete query. @@ -265,7 +312,7 @@ class MemorySource(DataSource): (list): list of STIX objects that matches the supplied query. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory as they are supplied (either as python dictionary or STIX object), it - is returned in the same form as it as added. + is returned in the same form as it was added. """ query = FilterSet(query) @@ -276,17 +323,23 @@ class MemorySource(DataSource): if _composite_filters: query.add(_composite_filters) + all_objs = itertools.chain.from_iterable( + obj_family.all_versions.values() + for obj_family in self._data.values() + ) + # Apply STIX common property filters. - all_data = list(apply_common_filters(self._data.values(), query)) + all_data = list(apply_common_filters(all_objs, query)) return all_data def load_from_file(self, file_path, version=None): - stix_data = json.load(open(os.path.abspath(file_path), "r")) + with open(os.path.abspath(file_path), "r") as f: + stix_data = json.load(f) + # Override user version selection if loading a bundle if stix_data["type"] == "bundle": - for stix_obj in stix_data["objects"]: - _add(self, stix_data=parse(stix_obj, allow_custom=self.allow_custom, version=stix_data["spec_version"])) - else: - _add(self, stix_data=parse(stix_data, allow_custom=self.allow_custom, version=version)) + version = stix_data["spec_version"] + + _add(self, stix_data, self.allow_custom, version) load_from_file.__doc__ = MemoryStore.load_from_file.__doc__ diff --git a/stix2/test/test_datastore_memory.py b/stix2/test/test_datastore_memory.py index 3d69953..60ea33b 100644 --- a/stix2/test/test_datastore_memory.py +++ b/stix2/test/test_datastore_memory.py @@ -3,6 +3,7 @@ import pytest from stix2.datastore import CompositeDataSource, make_id from stix2.datastore.filters import Filter from stix2.datastore.memory import MemorySink, MemorySource +from stix2.utils import parse_into_datetime def test_add_remove_composite_datasource(): @@ -44,14 +45,14 @@ def test_composite_datasource_operations(stix_objs1, stix_objs2): indicators = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001") # In STIX_OBJS2 changed the 'modified' property to a later time... - assert len(indicators) == 2 + assert len(indicators) == 3 cds1.add_data_sources([cds2]) indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001") assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001" - assert indicator["modified"] == "2017-01-31T13:49:53.935Z" + assert indicator["modified"] == parse_into_datetime("2017-01-31T13:49:53.935Z") assert indicator["type"] == "indicator" query1 = [ @@ -68,20 +69,18 @@ def test_composite_datasource_operations(stix_objs1, stix_objs2): # STIX_OBJS2 has indicator with later time, one with different id, one with # original time in STIX_OBJS1 - assert len(results) == 3 + assert len(results) == 4 indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001") assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001" - assert indicator["modified"] == "2017-01-31T13:49:53.935Z" + assert indicator["modified"] == parse_into_datetime("2017-01-31T13:49:53.935Z") assert indicator["type"] == "indicator" - # There is only one indicator with different ID. Since we use the same data - # when deduplicated, only two indicators (one with different modified). results = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001") - assert len(results) == 2 + assert len(results) == 3 # Since we have filters already associated with our CompositeSource providing # nothing returns the same as cds1.query(query1) (the associated query is query2) results = cds1.query([]) - assert len(results) == 3 + assert len(results) == 4 diff --git a/stix2/test/test_environment.py b/stix2/test/test_environment.py index d179ae9..a3ec469 100644 --- a/stix2/test/test_environment.py +++ b/stix2/test/test_environment.py @@ -113,7 +113,7 @@ def test_environment_functions(): # Get both versions of the object resp = env.all_versions(INDICATOR_ID) - assert len(resp) == 1 # should be 2, but MemoryStore only keeps 1 version of objects + assert len(resp) == 2 # Get just the most recent version of the object resp = env.get(INDICATOR_ID) diff --git a/stix2/test/test_memory.py b/stix2/test/test_memory.py index 44f90ba..ba70af4 100644 --- a/stix2/test/test_memory.py +++ b/stix2/test/test_memory.py @@ -11,6 +11,7 @@ from stix2.datastore import make_id from .constants import (CAMPAIGN_ID, CAMPAIGN_KWARGS, IDENTITY_ID, IDENTITY_KWARGS, INDICATOR_ID, INDICATOR_KWARGS, MALWARE_ID, MALWARE_KWARGS, RELATIONSHIP_IDS) +from stix2.utils import parse_into_datetime IND1 = { "created": "2017-01-27T13:49:53.935Z", @@ -167,7 +168,7 @@ def test_memory_store_all_versions(mem_store): type="bundle")) resp = mem_store.all_versions("indicator--00000000-0000-4000-8000-000000000001") - assert len(resp) == 1 # MemoryStore can only store 1 version of each object + assert len(resp) == 3 def test_memory_store_query(mem_store): @@ -179,25 +180,27 @@ def test_memory_store_query(mem_store): def test_memory_store_query_single_filter(mem_store): query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001') resp = mem_store.query(query) - assert len(resp) == 1 + assert len(resp) == 2 def test_memory_store_query_empty_query(mem_store): resp = mem_store.query() # sort since returned in random order - resp = sorted(resp, key=lambda k: k['id']) - assert len(resp) == 2 + resp = sorted(resp, key=lambda k: (k['id'], k['modified'])) + assert len(resp) == 3 assert resp[0]['id'] == 'indicator--00000000-0000-4000-8000-000000000001' - assert resp[0]['modified'] == '2017-01-27T13:49:53.936Z' - assert resp[1]['id'] == 'indicator--00000000-0000-4000-8000-000000000002' - assert resp[1]['modified'] == '2017-01-27T13:49:53.935Z' + assert resp[0]['modified'] == parse_into_datetime('2017-01-27T13:49:53.935Z') + assert resp[1]['id'] == 'indicator--00000000-0000-4000-8000-000000000001' + assert resp[1]['modified'] == parse_into_datetime('2017-01-27T13:49:53.936Z') + assert resp[2]['id'] == 'indicator--00000000-0000-4000-8000-000000000002' + assert resp[2]['modified'] == parse_into_datetime('2017-01-27T13:49:53.935Z') def test_memory_store_query_multiple_filters(mem_store): mem_store.source.filters.add(Filter('type', '=', 'indicator')) query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001') resp = mem_store.query(query) - assert len(resp) == 1 + assert len(resp) == 2 def test_memory_store_save_load_file(mem_store, fs_mem_store): @@ -218,12 +221,8 @@ def test_memory_store_save_load_file(mem_store, fs_mem_store): def test_memory_store_add_invalid_object(mem_store): ind = ('indicator', IND1) # tuple isn't valid - with pytest.raises(TypeError) as excinfo: + with pytest.raises(TypeError): mem_store.add(ind) - assert 'stix_data expected to be' in str(excinfo.value) - assert 'a python-stix2 object' in str(excinfo.value) - assert 'JSON formatted STIX' in str(excinfo.value) - assert 'JSON formatted STIX bundle' in str(excinfo.value) def test_memory_store_object_with_custom_property(mem_store): @@ -246,10 +245,9 @@ def test_memory_store_object_with_custom_property_in_bundle(mem_store): allow_custom=True) bundle = Bundle(camp, allow_custom=True) - mem_store.add(bundle, True) + mem_store.add(bundle) - bundle_r = mem_store.get(bundle.id) - camp_r = bundle_r['objects'][0] + camp_r = mem_store.get(camp.id) assert camp_r.id == camp.id assert camp_r.x_empire == camp.x_empire