Fixed Memory source/sink/store so that it supports multiple versions
of objects. Fixed several bugs too.stix2.0
parent
5a0e102959
commit
d9f6a213c1
|
@ -12,16 +12,17 @@ Note:
|
|||
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
|
||||
from stix2.base import _STIXBase
|
||||
from stix2.core import Bundle, parse
|
||||
from stix2.datastore import DataSink, DataSource, DataStoreMixin
|
||||
from stix2.datastore.filters import Filter, FilterSet, apply_common_filters
|
||||
from stix2.datastore.filters import FilterSet, apply_common_filters
|
||||
|
||||
|
||||
def _add(store, stix_data=None, version=None):
|
||||
def _add(store, stix_data=None, allow_custom=True, version=None):
|
||||
"""Add STIX objects to MemoryStore/Sink.
|
||||
|
||||
Adds STIX objects to an in-memory dictionary for fast lookup.
|
||||
|
@ -33,27 +34,55 @@ def _add(store, stix_data=None, version=None):
|
|||
None, use latest version.
|
||||
|
||||
"""
|
||||
if isinstance(stix_data, _STIXBase):
|
||||
# adding a python STIX object
|
||||
store._data[stix_data["id"]] = stix_data
|
||||
|
||||
elif isinstance(stix_data, dict):
|
||||
if stix_data["type"] == "bundle":
|
||||
# adding a json bundle - so just grab STIX objects
|
||||
for stix_obj in stix_data.get("objects", []):
|
||||
_add(store, stix_obj, version=version)
|
||||
else:
|
||||
# adding a json STIX object
|
||||
store._data[stix_data["id"]] = stix_data
|
||||
|
||||
elif isinstance(stix_data, list):
|
||||
if isinstance(stix_data, list):
|
||||
# STIX objects are in a list- recurse on each object
|
||||
for stix_obj in stix_data:
|
||||
_add(store, stix_obj, version=version)
|
||||
_add(store, stix_obj, allow_custom, version)
|
||||
|
||||
elif stix_data["type"] == "bundle":
|
||||
# adding a json bundle - so just grab STIX objects
|
||||
for stix_obj in stix_data.get("objects", []):
|
||||
_add(store, stix_obj, allow_custom, version)
|
||||
|
||||
else:
|
||||
raise TypeError("stix_data expected to be a python-stix2 object (or list of), JSON formatted STIX (or list of),"
|
||||
" or a JSON formatted STIX bundle. stix_data was of type: " + str(type(stix_data)))
|
||||
# Adding a single non-bundle object
|
||||
if isinstance(stix_data, _STIXBase):
|
||||
stix_obj = stix_data
|
||||
else:
|
||||
stix_obj = parse(stix_data, allow_custom, version)
|
||||
|
||||
if stix_obj.id in store._data:
|
||||
obj_family = store._data[stix_obj.id]
|
||||
else:
|
||||
obj_family = _ObjectFamily()
|
||||
store._data[stix_obj.id] = obj_family
|
||||
|
||||
obj_family.add(stix_obj)
|
||||
|
||||
|
||||
class _ObjectFamily(object):
|
||||
"""
|
||||
An internal implementation detail of memory sources/sinks/stores.
|
||||
Represents a "family" of STIX objects: all objects with a particular
|
||||
ID. (I.e. all versions.) The latest version is also tracked so that it
|
||||
can be obtained quickly.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.all_versions = {}
|
||||
self.latest_version = None
|
||||
|
||||
def add(self, obj):
|
||||
self.all_versions[obj.modified] = obj
|
||||
if self.latest_version is None or \
|
||||
obj.modified > self.latest_version.modified:
|
||||
self.latest_version = obj
|
||||
|
||||
def __str__(self):
|
||||
return "<<{}; latest={}>>".format(self.all_versions,
|
||||
self.latest_version.modified)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class MemoryStore(DataStoreMixin):
|
||||
|
@ -83,7 +112,7 @@ class MemoryStore(DataStoreMixin):
|
|||
self._data = {}
|
||||
|
||||
if stix_data:
|
||||
_add(self, stix_data, version=version)
|
||||
_add(self, stix_data, allow_custom, version=version)
|
||||
|
||||
super(MemoryStore, self).__init__(
|
||||
source=MemorySource(stix_data=self._data, allow_custom=allow_custom, version=version, _store=True),
|
||||
|
@ -138,25 +167,31 @@ class MemorySink(DataSink):
|
|||
"""
|
||||
def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False):
|
||||
super(MemorySink, self).__init__()
|
||||
self._data = {}
|
||||
self.allow_custom = allow_custom
|
||||
|
||||
if _store:
|
||||
self._data = stix_data
|
||||
elif stix_data:
|
||||
_add(self, stix_data, version=version)
|
||||
else:
|
||||
self._data = {}
|
||||
if stix_data:
|
||||
_add(self, stix_data, allow_custom, version=version)
|
||||
|
||||
def add(self, stix_data, version=None):
|
||||
_add(self, stix_data, version=version)
|
||||
_add(self, stix_data, self.allow_custom, version)
|
||||
add.__doc__ = _add.__doc__
|
||||
|
||||
def save_to_file(self, file_path):
|
||||
file_path = os.path.abspath(file_path)
|
||||
|
||||
all_objs = itertools.chain.from_iterable(
|
||||
obj_family.all_versions.values()
|
||||
for obj_family in self._data.values()
|
||||
)
|
||||
|
||||
if not os.path.exists(os.path.dirname(file_path)):
|
||||
os.makedirs(os.path.dirname(file_path))
|
||||
with open(file_path, "w") as f:
|
||||
f.write(str(Bundle(list(self._data.values()), allow_custom=self.allow_custom)))
|
||||
f.write(str(Bundle(list(all_objs), allow_custom=self.allow_custom)))
|
||||
save_to_file.__doc__ = MemoryStore.save_to_file.__doc__
|
||||
|
||||
|
||||
|
@ -184,13 +219,14 @@ class MemorySource(DataSource):
|
|||
"""
|
||||
def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False):
|
||||
super(MemorySource, self).__init__()
|
||||
self._data = {}
|
||||
self.allow_custom = allow_custom
|
||||
|
||||
if _store:
|
||||
self._data = stix_data
|
||||
elif stix_data:
|
||||
_add(self, stix_data, version=version)
|
||||
else:
|
||||
self._data = {}
|
||||
if stix_data:
|
||||
_add(self, stix_data, allow_custom, version=version)
|
||||
|
||||
def get(self, stix_id, _composite_filters=None):
|
||||
"""Retrieve STIX object from in-memory dict via STIX ID.
|
||||
|
@ -207,26 +243,22 @@ class MemorySource(DataSource):
|
|||
is returned in the same form as it as added
|
||||
|
||||
"""
|
||||
if _composite_filters is None:
|
||||
# if get call is only based on 'id', no need to search, just retrieve from dict
|
||||
try:
|
||||
stix_obj = self._data[stix_id]
|
||||
except KeyError:
|
||||
stix_obj = None
|
||||
return stix_obj
|
||||
stix_obj = None
|
||||
object_family = self._data.get(stix_id)
|
||||
if object_family:
|
||||
stix_obj = object_family.latest_version
|
||||
|
||||
# if there are filters from the composite level, process full query
|
||||
query = [Filter("id", "=", stix_id)]
|
||||
if stix_obj:
|
||||
all_filters = list(
|
||||
itertools.chain(
|
||||
_composite_filters or [],
|
||||
self.filters
|
||||
)
|
||||
)
|
||||
|
||||
all_data = self.query(query=query, _composite_filters=_composite_filters)
|
||||
stix_obj = next(apply_common_filters([stix_obj], all_filters), None)
|
||||
|
||||
if all_data:
|
||||
# reduce to most recent version
|
||||
stix_obj = sorted(all_data, key=lambda k: k['modified'])[0]
|
||||
|
||||
return stix_obj
|
||||
else:
|
||||
return None
|
||||
return stix_obj
|
||||
|
||||
def all_versions(self, stix_id, _composite_filters=None):
|
||||
"""Retrieve STIX objects from in-memory dict via STIX ID, all versions of it
|
||||
|
@ -246,8 +278,23 @@ class MemorySource(DataSource):
|
|||
is returned in the same form as it as added
|
||||
|
||||
"""
|
||||
results = []
|
||||
object_family = self._data.get(stix_id)
|
||||
|
||||
return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)]
|
||||
if object_family:
|
||||
all_filters = list(
|
||||
itertools.chain(
|
||||
_composite_filters or [],
|
||||
self.filters
|
||||
)
|
||||
)
|
||||
|
||||
results.extend(
|
||||
apply_common_filters(object_family.all_versions.values(),
|
||||
all_filters)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def query(self, query=None, _composite_filters=None):
|
||||
"""Search and retrieve STIX objects based on the complete query.
|
||||
|
@ -265,7 +312,7 @@ class MemorySource(DataSource):
|
|||
(list): list of STIX objects that matches the supplied
|
||||
query. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory
|
||||
as they are supplied (either as python dictionary or STIX object), it
|
||||
is returned in the same form as it as added.
|
||||
is returned in the same form as it was added.
|
||||
|
||||
"""
|
||||
query = FilterSet(query)
|
||||
|
@ -276,17 +323,23 @@ class MemorySource(DataSource):
|
|||
if _composite_filters:
|
||||
query.add(_composite_filters)
|
||||
|
||||
all_objs = itertools.chain.from_iterable(
|
||||
obj_family.all_versions.values()
|
||||
for obj_family in self._data.values()
|
||||
)
|
||||
|
||||
# Apply STIX common property filters.
|
||||
all_data = list(apply_common_filters(self._data.values(), query))
|
||||
all_data = list(apply_common_filters(all_objs, query))
|
||||
|
||||
return all_data
|
||||
|
||||
def load_from_file(self, file_path, version=None):
|
||||
stix_data = json.load(open(os.path.abspath(file_path), "r"))
|
||||
with open(os.path.abspath(file_path), "r") as f:
|
||||
stix_data = json.load(f)
|
||||
|
||||
# Override user version selection if loading a bundle
|
||||
if stix_data["type"] == "bundle":
|
||||
for stix_obj in stix_data["objects"]:
|
||||
_add(self, stix_data=parse(stix_obj, allow_custom=self.allow_custom, version=stix_data["spec_version"]))
|
||||
else:
|
||||
_add(self, stix_data=parse(stix_data, allow_custom=self.allow_custom, version=version))
|
||||
version = stix_data["spec_version"]
|
||||
|
||||
_add(self, stix_data, self.allow_custom, version)
|
||||
load_from_file.__doc__ = MemoryStore.load_from_file.__doc__
|
||||
|
|
|
@ -3,6 +3,7 @@ import pytest
|
|||
from stix2.datastore import CompositeDataSource, make_id
|
||||
from stix2.datastore.filters import Filter
|
||||
from stix2.datastore.memory import MemorySink, MemorySource
|
||||
from stix2.utils import parse_into_datetime
|
||||
|
||||
|
||||
def test_add_remove_composite_datasource():
|
||||
|
@ -44,14 +45,14 @@ def test_composite_datasource_operations(stix_objs1, stix_objs2):
|
|||
indicators = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001")
|
||||
|
||||
# In STIX_OBJS2 changed the 'modified' property to a later time...
|
||||
assert len(indicators) == 2
|
||||
assert len(indicators) == 3
|
||||
|
||||
cds1.add_data_sources([cds2])
|
||||
|
||||
indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001")
|
||||
|
||||
assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001"
|
||||
assert indicator["modified"] == "2017-01-31T13:49:53.935Z"
|
||||
assert indicator["modified"] == parse_into_datetime("2017-01-31T13:49:53.935Z")
|
||||
assert indicator["type"] == "indicator"
|
||||
|
||||
query1 = [
|
||||
|
@ -68,20 +69,18 @@ def test_composite_datasource_operations(stix_objs1, stix_objs2):
|
|||
|
||||
# STIX_OBJS2 has indicator with later time, one with different id, one with
|
||||
# original time in STIX_OBJS1
|
||||
assert len(results) == 3
|
||||
assert len(results) == 4
|
||||
|
||||
indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001")
|
||||
|
||||
assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001"
|
||||
assert indicator["modified"] == "2017-01-31T13:49:53.935Z"
|
||||
assert indicator["modified"] == parse_into_datetime("2017-01-31T13:49:53.935Z")
|
||||
assert indicator["type"] == "indicator"
|
||||
|
||||
# There is only one indicator with different ID. Since we use the same data
|
||||
# when deduplicated, only two indicators (one with different modified).
|
||||
results = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001")
|
||||
assert len(results) == 2
|
||||
assert len(results) == 3
|
||||
|
||||
# Since we have filters already associated with our CompositeSource providing
|
||||
# nothing returns the same as cds1.query(query1) (the associated query is query2)
|
||||
results = cds1.query([])
|
||||
assert len(results) == 3
|
||||
assert len(results) == 4
|
||||
|
|
|
@ -113,7 +113,7 @@ def test_environment_functions():
|
|||
|
||||
# Get both versions of the object
|
||||
resp = env.all_versions(INDICATOR_ID)
|
||||
assert len(resp) == 1 # should be 2, but MemoryStore only keeps 1 version of objects
|
||||
assert len(resp) == 2
|
||||
|
||||
# Get just the most recent version of the object
|
||||
resp = env.get(INDICATOR_ID)
|
||||
|
|
|
@ -11,6 +11,7 @@ from stix2.datastore import make_id
|
|||
from .constants import (CAMPAIGN_ID, CAMPAIGN_KWARGS, IDENTITY_ID,
|
||||
IDENTITY_KWARGS, INDICATOR_ID, INDICATOR_KWARGS,
|
||||
MALWARE_ID, MALWARE_KWARGS, RELATIONSHIP_IDS)
|
||||
from stix2.utils import parse_into_datetime
|
||||
|
||||
IND1 = {
|
||||
"created": "2017-01-27T13:49:53.935Z",
|
||||
|
@ -167,7 +168,7 @@ def test_memory_store_all_versions(mem_store):
|
|||
type="bundle"))
|
||||
|
||||
resp = mem_store.all_versions("indicator--00000000-0000-4000-8000-000000000001")
|
||||
assert len(resp) == 1 # MemoryStore can only store 1 version of each object
|
||||
assert len(resp) == 3
|
||||
|
||||
|
||||
def test_memory_store_query(mem_store):
|
||||
|
@ -179,25 +180,27 @@ def test_memory_store_query(mem_store):
|
|||
def test_memory_store_query_single_filter(mem_store):
|
||||
query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001')
|
||||
resp = mem_store.query(query)
|
||||
assert len(resp) == 1
|
||||
assert len(resp) == 2
|
||||
|
||||
|
||||
def test_memory_store_query_empty_query(mem_store):
|
||||
resp = mem_store.query()
|
||||
# sort since returned in random order
|
||||
resp = sorted(resp, key=lambda k: k['id'])
|
||||
assert len(resp) == 2
|
||||
resp = sorted(resp, key=lambda k: (k['id'], k['modified']))
|
||||
assert len(resp) == 3
|
||||
assert resp[0]['id'] == 'indicator--00000000-0000-4000-8000-000000000001'
|
||||
assert resp[0]['modified'] == '2017-01-27T13:49:53.936Z'
|
||||
assert resp[1]['id'] == 'indicator--00000000-0000-4000-8000-000000000002'
|
||||
assert resp[1]['modified'] == '2017-01-27T13:49:53.935Z'
|
||||
assert resp[0]['modified'] == parse_into_datetime('2017-01-27T13:49:53.935Z')
|
||||
assert resp[1]['id'] == 'indicator--00000000-0000-4000-8000-000000000001'
|
||||
assert resp[1]['modified'] == parse_into_datetime('2017-01-27T13:49:53.936Z')
|
||||
assert resp[2]['id'] == 'indicator--00000000-0000-4000-8000-000000000002'
|
||||
assert resp[2]['modified'] == parse_into_datetime('2017-01-27T13:49:53.935Z')
|
||||
|
||||
|
||||
def test_memory_store_query_multiple_filters(mem_store):
|
||||
mem_store.source.filters.add(Filter('type', '=', 'indicator'))
|
||||
query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001')
|
||||
resp = mem_store.query(query)
|
||||
assert len(resp) == 1
|
||||
assert len(resp) == 2
|
||||
|
||||
|
||||
def test_memory_store_save_load_file(mem_store, fs_mem_store):
|
||||
|
@ -218,12 +221,8 @@ def test_memory_store_save_load_file(mem_store, fs_mem_store):
|
|||
|
||||
def test_memory_store_add_invalid_object(mem_store):
|
||||
ind = ('indicator', IND1) # tuple isn't valid
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
with pytest.raises(TypeError):
|
||||
mem_store.add(ind)
|
||||
assert 'stix_data expected to be' in str(excinfo.value)
|
||||
assert 'a python-stix2 object' in str(excinfo.value)
|
||||
assert 'JSON formatted STIX' in str(excinfo.value)
|
||||
assert 'JSON formatted STIX bundle' in str(excinfo.value)
|
||||
|
||||
|
||||
def test_memory_store_object_with_custom_property(mem_store):
|
||||
|
@ -246,10 +245,9 @@ def test_memory_store_object_with_custom_property_in_bundle(mem_store):
|
|||
allow_custom=True)
|
||||
|
||||
bundle = Bundle(camp, allow_custom=True)
|
||||
mem_store.add(bundle, True)
|
||||
mem_store.add(bundle)
|
||||
|
||||
bundle_r = mem_store.get(bundle.id)
|
||||
camp_r = bundle_r['objects'][0]
|
||||
camp_r = mem_store.get(camp.id)
|
||||
assert camp_r.id == camp.id
|
||||
assert camp_r.x_empire == camp.x_empire
|
||||
|
||||
|
|
Loading…
Reference in New Issue