commit
dbe3f7a000
|
@ -123,8 +123,11 @@ def apply_common_filters(stix_objs, query):
|
||||||
Supports only STIX 2.0 common property properties.
|
Supports only STIX 2.0 common property properties.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stix_objs (list): list of STIX objects to apply the query to
|
stix_objs (iterable): iterable of STIX objects to apply the query to
|
||||||
query (set): set of filters (combined form complete query)
|
query (non-iterator iterable): iterable of filters. Can't be an
|
||||||
|
iterator (e.g. generator iterators won't work), since this is
|
||||||
|
used in an inner loop of a nested loop. So we require the ability
|
||||||
|
to traverse the filters repeatedly.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
STIX objects that successfully evaluate against the query.
|
STIX objects that successfully evaluate against the query.
|
||||||
|
|
|
@ -1,27 +1,18 @@
|
||||||
"""
|
"""
|
||||||
Python STIX 2.0 Memory Source/Sink
|
Python STIX 2.0 Memory Source/Sink
|
||||||
|
|
||||||
TODO:
|
|
||||||
Use deduplicate() calls only when memory corpus is dirty (been added to)
|
|
||||||
can save a lot of time for successive queries
|
|
||||||
|
|
||||||
Note:
|
|
||||||
Not worrying about STIX versioning. The in memory STIX data at anytime
|
|
||||||
will only hold one version of a STIX object. As such, when save() is called,
|
|
||||||
the single versions of all the STIX objects are what is written to file.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import itertools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from stix2.base import _STIXBase
|
from stix2.base import _STIXBase
|
||||||
from stix2.core import Bundle, parse
|
from stix2.core import Bundle, parse
|
||||||
from stix2.datastore import DataSink, DataSource, DataStoreMixin
|
from stix2.datastore import DataSink, DataSource, DataStoreMixin
|
||||||
from stix2.datastore.filters import Filter, FilterSet, apply_common_filters
|
from stix2.datastore.filters import FilterSet, apply_common_filters
|
||||||
|
|
||||||
|
|
||||||
def _add(store, stix_data=None, version=None):
|
def _add(store, stix_data=None, allow_custom=True, version=None):
|
||||||
"""Add STIX objects to MemoryStore/Sink.
|
"""Add STIX objects to MemoryStore/Sink.
|
||||||
|
|
||||||
Adds STIX objects to an in-memory dictionary for fast lookup.
|
Adds STIX objects to an in-memory dictionary for fast lookup.
|
||||||
|
@ -33,27 +24,77 @@ def _add(store, stix_data=None, version=None):
|
||||||
None, use latest version.
|
None, use latest version.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if isinstance(stix_data, _STIXBase):
|
if isinstance(stix_data, list):
|
||||||
# adding a python STIX object
|
|
||||||
store._data[stix_data["id"]] = stix_data
|
|
||||||
|
|
||||||
elif isinstance(stix_data, dict):
|
|
||||||
if stix_data["type"] == "bundle":
|
|
||||||
# adding a json bundle - so just grab STIX objects
|
|
||||||
for stix_obj in stix_data.get("objects", []):
|
|
||||||
_add(store, stix_obj, version=version)
|
|
||||||
else:
|
|
||||||
# adding a json STIX object
|
|
||||||
store._data[stix_data["id"]] = stix_data
|
|
||||||
|
|
||||||
elif isinstance(stix_data, list):
|
|
||||||
# STIX objects are in a list- recurse on each object
|
# STIX objects are in a list- recurse on each object
|
||||||
for stix_obj in stix_data:
|
for stix_obj in stix_data:
|
||||||
_add(store, stix_obj, version=version)
|
_add(store, stix_obj, allow_custom, version)
|
||||||
|
|
||||||
|
elif stix_data["type"] == "bundle":
|
||||||
|
# adding a json bundle - so just grab STIX objects
|
||||||
|
for stix_obj in stix_data.get("objects", []):
|
||||||
|
_add(store, stix_obj, allow_custom, version)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("stix_data expected to be a python-stix2 object (or list of), JSON formatted STIX (or list of),"
|
# Adding a single non-bundle object
|
||||||
" or a JSON formatted STIX bundle. stix_data was of type: " + str(type(stix_data)))
|
if isinstance(stix_data, _STIXBase):
|
||||||
|
stix_obj = stix_data
|
||||||
|
else:
|
||||||
|
stix_obj = parse(stix_data, allow_custom, version)
|
||||||
|
|
||||||
|
# Map ID directly to the object, if it is a marking. Otherwise,
|
||||||
|
# map to a family, so we can track multiple versions.
|
||||||
|
if _is_marking(stix_obj):
|
||||||
|
store._data[stix_obj.id] = stix_obj
|
||||||
|
|
||||||
|
else:
|
||||||
|
if stix_obj.id in store._data:
|
||||||
|
obj_family = store._data[stix_obj.id]
|
||||||
|
else:
|
||||||
|
obj_family = _ObjectFamily()
|
||||||
|
store._data[stix_obj.id] = obj_family
|
||||||
|
|
||||||
|
obj_family.add(stix_obj)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_marking(obj_or_id):
|
||||||
|
"""Determines whether the given object or object ID is/is for a marking
|
||||||
|
definition.
|
||||||
|
|
||||||
|
:param obj_or_id: A STIX object or object ID as a string.
|
||||||
|
:return: True if a marking definition, False otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(obj_or_id, _STIXBase):
|
||||||
|
id_ = obj_or_id.id
|
||||||
|
else:
|
||||||
|
id_ = obj_or_id
|
||||||
|
|
||||||
|
return id_.startswith("marking-definition--")
|
||||||
|
|
||||||
|
|
||||||
|
class _ObjectFamily(object):
|
||||||
|
"""
|
||||||
|
An internal implementation detail of memory sources/sinks/stores.
|
||||||
|
Represents a "family" of STIX objects: all objects with a particular
|
||||||
|
ID. (I.e. all versions.) The latest version is also tracked so that it
|
||||||
|
can be obtained quickly.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.all_versions = {}
|
||||||
|
self.latest_version = None
|
||||||
|
|
||||||
|
def add(self, obj):
|
||||||
|
self.all_versions[obj.modified] = obj
|
||||||
|
if self.latest_version is None or \
|
||||||
|
obj.modified > self.latest_version.modified:
|
||||||
|
self.latest_version = obj
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "<<{}; latest={}>>".format(self.all_versions,
|
||||||
|
self.latest_version.modified)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self)
|
||||||
|
|
||||||
|
|
||||||
class MemoryStore(DataStoreMixin):
|
class MemoryStore(DataStoreMixin):
|
||||||
|
@ -83,7 +124,7 @@ class MemoryStore(DataStoreMixin):
|
||||||
self._data = {}
|
self._data = {}
|
||||||
|
|
||||||
if stix_data:
|
if stix_data:
|
||||||
_add(self, stix_data, version=version)
|
_add(self, stix_data, allow_custom, version=version)
|
||||||
|
|
||||||
super(MemoryStore, self).__init__(
|
super(MemoryStore, self).__init__(
|
||||||
source=MemorySource(stix_data=self._data, allow_custom=allow_custom, version=version, _store=True),
|
source=MemorySource(stix_data=self._data, allow_custom=allow_custom, version=version, _store=True),
|
||||||
|
@ -138,25 +179,32 @@ class MemorySink(DataSink):
|
||||||
"""
|
"""
|
||||||
def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False):
|
def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False):
|
||||||
super(MemorySink, self).__init__()
|
super(MemorySink, self).__init__()
|
||||||
self._data = {}
|
|
||||||
self.allow_custom = allow_custom
|
self.allow_custom = allow_custom
|
||||||
|
|
||||||
if _store:
|
if _store:
|
||||||
self._data = stix_data
|
self._data = stix_data
|
||||||
elif stix_data:
|
else:
|
||||||
_add(self, stix_data, version=version)
|
self._data = {}
|
||||||
|
if stix_data:
|
||||||
|
_add(self, stix_data, allow_custom, version=version)
|
||||||
|
|
||||||
def add(self, stix_data, version=None):
|
def add(self, stix_data, version=None):
|
||||||
_add(self, stix_data, version=version)
|
_add(self, stix_data, self.allow_custom, version)
|
||||||
add.__doc__ = _add.__doc__
|
add.__doc__ = _add.__doc__
|
||||||
|
|
||||||
def save_to_file(self, file_path):
|
def save_to_file(self, file_path):
|
||||||
file_path = os.path.abspath(file_path)
|
file_path = os.path.abspath(file_path)
|
||||||
|
|
||||||
|
all_objs = itertools.chain.from_iterable(
|
||||||
|
value.all_versions.values() if isinstance(value, _ObjectFamily)
|
||||||
|
else [value]
|
||||||
|
for value in self._data.values()
|
||||||
|
)
|
||||||
|
|
||||||
if not os.path.exists(os.path.dirname(file_path)):
|
if not os.path.exists(os.path.dirname(file_path)):
|
||||||
os.makedirs(os.path.dirname(file_path))
|
os.makedirs(os.path.dirname(file_path))
|
||||||
with open(file_path, "w") as f:
|
with open(file_path, "w") as f:
|
||||||
f.write(str(Bundle(list(self._data.values()), allow_custom=self.allow_custom)))
|
f.write(str(Bundle(list(all_objs), allow_custom=self.allow_custom)))
|
||||||
save_to_file.__doc__ = MemoryStore.save_to_file.__doc__
|
save_to_file.__doc__ = MemoryStore.save_to_file.__doc__
|
||||||
|
|
||||||
|
|
||||||
|
@ -184,13 +232,14 @@ class MemorySource(DataSource):
|
||||||
"""
|
"""
|
||||||
def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False):
|
def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False):
|
||||||
super(MemorySource, self).__init__()
|
super(MemorySource, self).__init__()
|
||||||
self._data = {}
|
|
||||||
self.allow_custom = allow_custom
|
self.allow_custom = allow_custom
|
||||||
|
|
||||||
if _store:
|
if _store:
|
||||||
self._data = stix_data
|
self._data = stix_data
|
||||||
elif stix_data:
|
else:
|
||||||
_add(self, stix_data, version=version)
|
self._data = {}
|
||||||
|
if stix_data:
|
||||||
|
_add(self, stix_data, allow_custom, version=version)
|
||||||
|
|
||||||
def get(self, stix_id, _composite_filters=None):
|
def get(self, stix_id, _composite_filters=None):
|
||||||
"""Retrieve STIX object from in-memory dict via STIX ID.
|
"""Retrieve STIX object from in-memory dict via STIX ID.
|
||||||
|
@ -207,26 +256,26 @@ class MemorySource(DataSource):
|
||||||
is returned in the same form as it as added
|
is returned in the same form as it as added
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if _composite_filters is None:
|
|
||||||
# if get call is only based on 'id', no need to search, just retrieve from dict
|
|
||||||
try:
|
|
||||||
stix_obj = self._data[stix_id]
|
|
||||||
except KeyError:
|
|
||||||
stix_obj = None
|
stix_obj = None
|
||||||
return stix_obj
|
|
||||||
|
|
||||||
# if there are filters from the composite level, process full query
|
if _is_marking(stix_id):
|
||||||
query = [Filter("id", "=", stix_id)]
|
stix_obj = self._data.get(stix_id)
|
||||||
|
|
||||||
all_data = self.query(query=query, _composite_filters=_composite_filters)
|
|
||||||
|
|
||||||
if all_data:
|
|
||||||
# reduce to most recent version
|
|
||||||
stix_obj = sorted(all_data, key=lambda k: k['modified'])[0]
|
|
||||||
|
|
||||||
return stix_obj
|
|
||||||
else:
|
else:
|
||||||
return None
|
object_family = self._data.get(stix_id)
|
||||||
|
if object_family:
|
||||||
|
stix_obj = object_family.latest_version
|
||||||
|
|
||||||
|
if stix_obj:
|
||||||
|
all_filters = list(
|
||||||
|
itertools.chain(
|
||||||
|
_composite_filters or [],
|
||||||
|
self.filters
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
stix_obj = next(apply_common_filters([stix_obj], all_filters), None)
|
||||||
|
|
||||||
|
return stix_obj
|
||||||
|
|
||||||
def all_versions(self, stix_id, _composite_filters=None):
|
def all_versions(self, stix_id, _composite_filters=None):
|
||||||
"""Retrieve STIX objects from in-memory dict via STIX ID, all versions of it
|
"""Retrieve STIX objects from in-memory dict via STIX ID, all versions of it
|
||||||
|
@ -246,8 +295,30 @@ class MemorySource(DataSource):
|
||||||
is returned in the same form as it as added
|
is returned in the same form as it as added
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
results = []
|
||||||
|
stix_objs_to_filter = None
|
||||||
|
if _is_marking(stix_id):
|
||||||
|
stix_obj = self._data.get(stix_id)
|
||||||
|
if stix_obj:
|
||||||
|
stix_objs_to_filter = [stix_obj]
|
||||||
|
else:
|
||||||
|
object_family = self._data.get(stix_id)
|
||||||
|
if object_family:
|
||||||
|
stix_objs_to_filter = object_family.all_versions.values()
|
||||||
|
|
||||||
return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)]
|
if stix_objs_to_filter:
|
||||||
|
all_filters = list(
|
||||||
|
itertools.chain(
|
||||||
|
_composite_filters or [],
|
||||||
|
self.filters
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
results.extend(
|
||||||
|
apply_common_filters(stix_objs_to_filter, all_filters)
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
def query(self, query=None, _composite_filters=None):
|
def query(self, query=None, _composite_filters=None):
|
||||||
"""Search and retrieve STIX objects based on the complete query.
|
"""Search and retrieve STIX objects based on the complete query.
|
||||||
|
@ -265,7 +336,7 @@ class MemorySource(DataSource):
|
||||||
(list): list of STIX objects that matches the supplied
|
(list): list of STIX objects that matches the supplied
|
||||||
query. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory
|
query. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory
|
||||||
as they are supplied (either as python dictionary or STIX object), it
|
as they are supplied (either as python dictionary or STIX object), it
|
||||||
is returned in the same form as it as added.
|
is returned in the same form as it was added.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
query = FilterSet(query)
|
query = FilterSet(query)
|
||||||
|
@ -276,17 +347,24 @@ class MemorySource(DataSource):
|
||||||
if _composite_filters:
|
if _composite_filters:
|
||||||
query.add(_composite_filters)
|
query.add(_composite_filters)
|
||||||
|
|
||||||
|
all_objs = itertools.chain.from_iterable(
|
||||||
|
value.all_versions.values() if isinstance(value, _ObjectFamily)
|
||||||
|
else [value]
|
||||||
|
for value in self._data.values()
|
||||||
|
)
|
||||||
|
|
||||||
# Apply STIX common property filters.
|
# Apply STIX common property filters.
|
||||||
all_data = list(apply_common_filters(self._data.values(), query))
|
all_data = list(apply_common_filters(all_objs, query))
|
||||||
|
|
||||||
return all_data
|
return all_data
|
||||||
|
|
||||||
def load_from_file(self, file_path, version=None):
|
def load_from_file(self, file_path, version=None):
|
||||||
stix_data = json.load(open(os.path.abspath(file_path), "r"))
|
with open(os.path.abspath(file_path), "r") as f:
|
||||||
|
stix_data = json.load(f)
|
||||||
|
|
||||||
|
# Override user version selection if loading a bundle
|
||||||
if stix_data["type"] == "bundle":
|
if stix_data["type"] == "bundle":
|
||||||
for stix_obj in stix_data["objects"]:
|
version = stix_data["spec_version"]
|
||||||
_add(self, stix_data=parse(stix_obj, allow_custom=self.allow_custom, version=stix_data["spec_version"]))
|
|
||||||
else:
|
_add(self, stix_data, self.allow_custom, version)
|
||||||
_add(self, stix_data=parse(stix_data, allow_custom=self.allow_custom, version=version))
|
|
||||||
load_from_file.__doc__ = MemoryStore.load_from_file.__doc__
|
load_from_file.__doc__ = MemoryStore.load_from_file.__doc__
|
||||||
|
|
|
@ -2,7 +2,9 @@ import pytest
|
||||||
|
|
||||||
from stix2.datastore import CompositeDataSource, make_id
|
from stix2.datastore import CompositeDataSource, make_id
|
||||||
from stix2.datastore.filters import Filter
|
from stix2.datastore.filters import Filter
|
||||||
from stix2.datastore.memory import MemorySink, MemorySource
|
from stix2.datastore.memory import MemorySink, MemorySource, MemoryStore
|
||||||
|
from stix2.utils import parse_into_datetime
|
||||||
|
from stix2.v20.common import TLP_GREEN
|
||||||
|
|
||||||
|
|
||||||
def test_add_remove_composite_datasource():
|
def test_add_remove_composite_datasource():
|
||||||
|
@ -44,14 +46,14 @@ def test_composite_datasource_operations(stix_objs1, stix_objs2):
|
||||||
indicators = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001")
|
indicators = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001")
|
||||||
|
|
||||||
# In STIX_OBJS2 changed the 'modified' property to a later time...
|
# In STIX_OBJS2 changed the 'modified' property to a later time...
|
||||||
assert len(indicators) == 2
|
assert len(indicators) == 3
|
||||||
|
|
||||||
cds1.add_data_sources([cds2])
|
cds1.add_data_sources([cds2])
|
||||||
|
|
||||||
indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001")
|
indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001")
|
||||||
|
|
||||||
assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001"
|
assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001"
|
||||||
assert indicator["modified"] == "2017-01-31T13:49:53.935Z"
|
assert indicator["modified"] == parse_into_datetime("2017-01-31T13:49:53.935Z")
|
||||||
assert indicator["type"] == "indicator"
|
assert indicator["type"] == "indicator"
|
||||||
|
|
||||||
query1 = [
|
query1 = [
|
||||||
|
@ -68,20 +70,80 @@ def test_composite_datasource_operations(stix_objs1, stix_objs2):
|
||||||
|
|
||||||
# STIX_OBJS2 has indicator with later time, one with different id, one with
|
# STIX_OBJS2 has indicator with later time, one with different id, one with
|
||||||
# original time in STIX_OBJS1
|
# original time in STIX_OBJS1
|
||||||
assert len(results) == 3
|
assert len(results) == 4
|
||||||
|
|
||||||
indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001")
|
indicator = cds1.get("indicator--00000000-0000-4000-8000-000000000001")
|
||||||
|
|
||||||
assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001"
|
assert indicator["id"] == "indicator--00000000-0000-4000-8000-000000000001"
|
||||||
assert indicator["modified"] == "2017-01-31T13:49:53.935Z"
|
assert indicator["modified"] == parse_into_datetime("2017-01-31T13:49:53.935Z")
|
||||||
assert indicator["type"] == "indicator"
|
assert indicator["type"] == "indicator"
|
||||||
|
|
||||||
# There is only one indicator with different ID. Since we use the same data
|
|
||||||
# when deduplicated, only two indicators (one with different modified).
|
|
||||||
results = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001")
|
results = cds1.all_versions("indicator--00000000-0000-4000-8000-000000000001")
|
||||||
assert len(results) == 2
|
assert len(results) == 3
|
||||||
|
|
||||||
# Since we have filters already associated with our CompositeSource providing
|
# Since we have filters already associated with our CompositeSource providing
|
||||||
# nothing returns the same as cds1.query(query1) (the associated query is query2)
|
# nothing returns the same as cds1.query(query1) (the associated query is query2)
|
||||||
results = cds1.query([])
|
results = cds1.query([])
|
||||||
assert len(results) == 3
|
assert len(results) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_source_markings():
|
||||||
|
msrc = MemorySource(TLP_GREEN)
|
||||||
|
|
||||||
|
assert msrc.get(TLP_GREEN.id) == TLP_GREEN
|
||||||
|
assert msrc.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||||
|
assert msrc.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sink_markings():
|
||||||
|
# just make sure there is no crash
|
||||||
|
msink = MemorySink(TLP_GREEN)
|
||||||
|
msink.add(TLP_GREEN)
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_markings():
|
||||||
|
mstore = MemoryStore(TLP_GREEN)
|
||||||
|
|
||||||
|
assert mstore.get(TLP_GREEN.id) == TLP_GREEN
|
||||||
|
assert mstore.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||||
|
assert mstore.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||||
|
|
||||||
|
|
||||||
|
def test_source_mixed(indicator):
|
||||||
|
msrc = MemorySource([TLP_GREEN, indicator])
|
||||||
|
|
||||||
|
assert msrc.get(TLP_GREEN.id) == TLP_GREEN
|
||||||
|
assert msrc.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||||
|
assert msrc.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||||
|
|
||||||
|
assert msrc.get(indicator.id) == indicator
|
||||||
|
assert msrc.all_versions(indicator.id) == [indicator]
|
||||||
|
assert msrc.query(Filter("id", "=", indicator.id)) == [indicator]
|
||||||
|
|
||||||
|
all_objs = msrc.query()
|
||||||
|
assert TLP_GREEN in all_objs
|
||||||
|
assert indicator in all_objs
|
||||||
|
assert len(all_objs) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_sink_mixed(indicator):
|
||||||
|
# just make sure there is no crash
|
||||||
|
msink = MemorySink([TLP_GREEN, indicator])
|
||||||
|
msink.add([TLP_GREEN, indicator])
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_mixed(indicator):
|
||||||
|
mstore = MemoryStore([TLP_GREEN, indicator])
|
||||||
|
|
||||||
|
assert mstore.get(TLP_GREEN.id) == TLP_GREEN
|
||||||
|
assert mstore.all_versions(TLP_GREEN.id) == [TLP_GREEN]
|
||||||
|
assert mstore.query(Filter("id", "=", TLP_GREEN.id)) == [TLP_GREEN]
|
||||||
|
|
||||||
|
assert mstore.get(indicator.id) == indicator
|
||||||
|
assert mstore.all_versions(indicator.id) == [indicator]
|
||||||
|
assert mstore.query(Filter("id", "=", indicator.id)) == [indicator]
|
||||||
|
|
||||||
|
all_objs = mstore.query()
|
||||||
|
assert TLP_GREEN in all_objs
|
||||||
|
assert indicator in all_objs
|
||||||
|
assert len(all_objs) == 2
|
||||||
|
|
|
@ -113,7 +113,7 @@ def test_environment_functions():
|
||||||
|
|
||||||
# Get both versions of the object
|
# Get both versions of the object
|
||||||
resp = env.all_versions(INDICATOR_ID)
|
resp = env.all_versions(INDICATOR_ID)
|
||||||
assert len(resp) == 1 # should be 2, but MemoryStore only keeps 1 version of objects
|
assert len(resp) == 2
|
||||||
|
|
||||||
# Get just the most recent version of the object
|
# Get just the most recent version of the object
|
||||||
resp = env.get(INDICATOR_ID)
|
resp = env.get(INDICATOR_ID)
|
||||||
|
|
|
@ -7,6 +7,7 @@ from stix2 import (Bundle, Campaign, CustomObject, Filter, Identity, Indicator,
|
||||||
Malware, MemorySource, MemoryStore, Relationship,
|
Malware, MemorySource, MemoryStore, Relationship,
|
||||||
properties)
|
properties)
|
||||||
from stix2.datastore import make_id
|
from stix2.datastore import make_id
|
||||||
|
from stix2.utils import parse_into_datetime
|
||||||
|
|
||||||
from .constants import (CAMPAIGN_ID, CAMPAIGN_KWARGS, IDENTITY_ID,
|
from .constants import (CAMPAIGN_ID, CAMPAIGN_KWARGS, IDENTITY_ID,
|
||||||
IDENTITY_KWARGS, INDICATOR_ID, INDICATOR_KWARGS,
|
IDENTITY_KWARGS, INDICATOR_ID, INDICATOR_KWARGS,
|
||||||
|
@ -167,7 +168,7 @@ def test_memory_store_all_versions(mem_store):
|
||||||
type="bundle"))
|
type="bundle"))
|
||||||
|
|
||||||
resp = mem_store.all_versions("indicator--00000000-0000-4000-8000-000000000001")
|
resp = mem_store.all_versions("indicator--00000000-0000-4000-8000-000000000001")
|
||||||
assert len(resp) == 1 # MemoryStore can only store 1 version of each object
|
assert len(resp) == 3
|
||||||
|
|
||||||
|
|
||||||
def test_memory_store_query(mem_store):
|
def test_memory_store_query(mem_store):
|
||||||
|
@ -179,25 +180,27 @@ def test_memory_store_query(mem_store):
|
||||||
def test_memory_store_query_single_filter(mem_store):
|
def test_memory_store_query_single_filter(mem_store):
|
||||||
query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001')
|
query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001')
|
||||||
resp = mem_store.query(query)
|
resp = mem_store.query(query)
|
||||||
assert len(resp) == 1
|
assert len(resp) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_memory_store_query_empty_query(mem_store):
|
def test_memory_store_query_empty_query(mem_store):
|
||||||
resp = mem_store.query()
|
resp = mem_store.query()
|
||||||
# sort since returned in random order
|
# sort since returned in random order
|
||||||
resp = sorted(resp, key=lambda k: k['id'])
|
resp = sorted(resp, key=lambda k: (k['id'], k['modified']))
|
||||||
assert len(resp) == 2
|
assert len(resp) == 3
|
||||||
assert resp[0]['id'] == 'indicator--00000000-0000-4000-8000-000000000001'
|
assert resp[0]['id'] == 'indicator--00000000-0000-4000-8000-000000000001'
|
||||||
assert resp[0]['modified'] == '2017-01-27T13:49:53.936Z'
|
assert resp[0]['modified'] == parse_into_datetime('2017-01-27T13:49:53.935Z')
|
||||||
assert resp[1]['id'] == 'indicator--00000000-0000-4000-8000-000000000002'
|
assert resp[1]['id'] == 'indicator--00000000-0000-4000-8000-000000000001'
|
||||||
assert resp[1]['modified'] == '2017-01-27T13:49:53.935Z'
|
assert resp[1]['modified'] == parse_into_datetime('2017-01-27T13:49:53.936Z')
|
||||||
|
assert resp[2]['id'] == 'indicator--00000000-0000-4000-8000-000000000002'
|
||||||
|
assert resp[2]['modified'] == parse_into_datetime('2017-01-27T13:49:53.935Z')
|
||||||
|
|
||||||
|
|
||||||
def test_memory_store_query_multiple_filters(mem_store):
|
def test_memory_store_query_multiple_filters(mem_store):
|
||||||
mem_store.source.filters.add(Filter('type', '=', 'indicator'))
|
mem_store.source.filters.add(Filter('type', '=', 'indicator'))
|
||||||
query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001')
|
query = Filter('id', '=', 'indicator--00000000-0000-4000-8000-000000000001')
|
||||||
resp = mem_store.query(query)
|
resp = mem_store.query(query)
|
||||||
assert len(resp) == 1
|
assert len(resp) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_memory_store_save_load_file(mem_store, fs_mem_store):
|
def test_memory_store_save_load_file(mem_store, fs_mem_store):
|
||||||
|
@ -218,12 +221,8 @@ def test_memory_store_save_load_file(mem_store, fs_mem_store):
|
||||||
|
|
||||||
def test_memory_store_add_invalid_object(mem_store):
|
def test_memory_store_add_invalid_object(mem_store):
|
||||||
ind = ('indicator', IND1) # tuple isn't valid
|
ind = ('indicator', IND1) # tuple isn't valid
|
||||||
with pytest.raises(TypeError) as excinfo:
|
with pytest.raises(TypeError):
|
||||||
mem_store.add(ind)
|
mem_store.add(ind)
|
||||||
assert 'stix_data expected to be' in str(excinfo.value)
|
|
||||||
assert 'a python-stix2 object' in str(excinfo.value)
|
|
||||||
assert 'JSON formatted STIX' in str(excinfo.value)
|
|
||||||
assert 'JSON formatted STIX bundle' in str(excinfo.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_store_object_with_custom_property(mem_store):
|
def test_memory_store_object_with_custom_property(mem_store):
|
||||||
|
@ -246,10 +245,9 @@ def test_memory_store_object_with_custom_property_in_bundle(mem_store):
|
||||||
allow_custom=True)
|
allow_custom=True)
|
||||||
|
|
||||||
bundle = Bundle(camp, allow_custom=True)
|
bundle = Bundle(camp, allow_custom=True)
|
||||||
mem_store.add(bundle, True)
|
mem_store.add(bundle)
|
||||||
|
|
||||||
bundle_r = mem_store.get(bundle.id)
|
camp_r = mem_store.get(camp.id)
|
||||||
camp_r = bundle_r['objects'][0]
|
|
||||||
assert camp_r.id == camp.id
|
assert camp_r.id == camp.id
|
||||||
assert camp_r.x_empire == camp.x_empire
|
assert camp_r.x_empire == camp.x_empire
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue