Separate out Memory datatstore tests

Makes sure custom content can be added to a MemoryStore.
stix2.0
Chris Lenk 2017-10-18 18:31:46 -04:00
parent e1d8c2872e
commit c6d5eee083
3 changed files with 281 additions and 66 deletions

View File

@ -24,8 +24,8 @@ from stix2.sources import DataSink, DataSource, DataStore
from stix2.sources.filters import Filter, apply_common_filters from stix2.sources.filters import Filter, apply_common_filters
def _add(store, stix_data=None): def _add(store, stix_data=None, allow_custom=False):
"""Adds STIX objects to MemoryStore/Sink. """Add STIX objects to MemoryStore/Sink.
Adds STIX objects to an in-memory dictionary for fast lookup. Adds STIX objects to an in-memory dictionary for fast lookup.
Recursive function, breaks down STIX Bundles and lists. Recursive function, breaks down STIX Bundles and lists.
@ -41,35 +41,35 @@ def _add(store, stix_data=None):
elif isinstance(stix_data, dict): elif isinstance(stix_data, dict):
if stix_data["type"] == "bundle": if stix_data["type"] == "bundle":
# adding a json bundle - so just grab STIX objects # adding a json bundle - so just grab STIX objects
for stix_obj in stix_data["objects"]: for stix_obj in stix_data.get("objects", []):
_add(store, stix_obj) _add(store, stix_obj, allow_custom=allow_custom)
else: else:
# adding a json STIX object # adding a json STIX object
store._data[stix_data["id"]] = stix_data store._data[stix_data["id"]] = stix_data
elif isinstance(stix_data, str): elif isinstance(stix_data, str):
# adding json encoded string of STIX content # adding json encoded string of STIX content
stix_data = parse(stix_data) stix_data = parse(stix_data, allow_custom=allow_custom)
if stix_data["type"] == "bundle": if stix_data["type"] == "bundle":
# recurse on each STIX object in bundle # recurse on each STIX object in bundle
for stix_obj in stix_data: for stix_obj in stix_data.get("objects", []):
_add(store, stix_obj) _add(store, stix_obj, allow_custom=allow_custom)
else: else:
_add(store, stix_data) _add(store, stix_data)
elif isinstance(stix_data, list): elif isinstance(stix_data, list):
# STIX objects are in a list- recurse on each object # STIX objects are in a list- recurse on each object
for stix_obj in stix_data: for stix_obj in stix_data:
_add(store, stix_obj) _add(store, stix_obj, allow_custom=allow_custom)
else: else:
raise TypeError("stix_data must be as STIX object(or list of),json formatted STIX (or list of), or a json formatted STIX bundle") raise TypeError("stix_data must be as STIX object(or list of),json formatted STIX (or list of), or a json formatted STIX bundle")
class MemoryStore(DataStore): class MemoryStore(DataStore):
"""Provides an interface to an in-memory dictionary """Interface to an in-memory dictionary of STIX objects.
of STIX objects. MemoryStore is a wrapper around a paired
MemorySink and MemorySource MemoryStore is a wrapper around a paired MemorySink and MemorySource.
Note: It doesn't make sense to create a MemoryStore by passing Note: It doesn't make sense to create a MemoryStore by passing
in existing MemorySource and MemorySink because there could in existing MemorySource and MemorySink because there could
@ -87,26 +87,25 @@ class MemoryStore(DataStore):
""" """
def __init__(self, stix_data=None): def __init__(self, stix_data=None, allow_custom=False):
super(MemoryStore, self).__init__() super(MemoryStore, self).__init__()
self._data = {} self._data = {}
if stix_data: if stix_data:
_add(self, stix_data) _add(self, stix_data, allow_custom=allow_custom)
self.source = MemorySource(stix_data=self._data, _store=True) self.source = MemorySource(stix_data=self._data, _store=True, allow_custom=allow_custom)
self.sink = MemorySink(stix_data=self._data, _store=True) self.sink = MemorySink(stix_data=self._data, _store=True, allow_custom=allow_custom)
def save_to_file(self, file_path): def save_to_file(self, file_path, allow_custom=False):
return self.sink.save_to_file(file_path=file_path) return self.sink.save_to_file(file_path=file_path, allow_custom=allow_custom)
def load_from_file(self, file_path): def load_from_file(self, file_path, allow_custom=False):
return self.source.load_from_file(file_path=file_path) return self.source.load_from_file(file_path=file_path, allow_custom=allow_custom)
class MemorySink(DataSink): class MemorySink(DataSink):
"""Provides an interface for adding/pushing STIX objects """Interface for adding/pushing STIX objects to an in-memory dictionary.
to an in-memory dictionary.
Designed to be paired with a MemorySource, together as the two Designed to be paired with a MemorySource, together as the two
components of a MemoryStore. components of a MemoryStore.
@ -125,24 +124,24 @@ class MemorySink(DataSink):
a MemorySource a MemorySource
""" """
def __init__(self, stix_data=None, _store=False): def __init__(self, stix_data=None, _store=False, allow_custom=False):
super(MemorySink, self).__init__() super(MemorySink, self).__init__()
self._data = {} self._data = {}
if _store: if _store:
self._data = stix_data self._data = stix_data
elif stix_data: elif stix_data:
_add(self, stix_data) _add(self, stix_data, allow_custom=allow_custom)
def add(self, stix_data): def add(self, stix_data, allow_custom=False):
"""add STIX objects to in-memory dictionary maintained by """add STIX objects to in-memory dictionary maintained by
the MemorySink (MemoryStore) the MemorySink (MemoryStore)
see "_add()" for args documentation see "_add()" for args documentation
""" """
_add(self, stix_data) _add(self, stix_data, allow_custom=allow_custom)
def save_to_file(self, file_path): def save_to_file(self, file_path, allow_custom=False):
"""write SITX objects in in-memory dictionary to json file, as a STIX Bundle """write SITX objects in in-memory dictionary to json file, as a STIX Bundle
Args: Args:
@ -153,12 +152,12 @@ class MemorySink(DataSink):
if not os.path.exists(os.path.dirname(file_path)): if not os.path.exists(os.path.dirname(file_path)):
os.makedirs(os.path.dirname(file_path)) os.makedirs(os.path.dirname(file_path))
with open(file_path, "w") as f: with open(file_path, "w") as f:
f.write(str(Bundle(self._data.values()))) f.write(str(Bundle(self._data.values(), allow_custom=allow_custom)))
class MemorySource(DataSource): class MemorySource(DataSource):
"""Provides an interface for searching/retrieving """Interface for searching/retrieving STIX objects from an in-memory
STIX objects from an in-memory dictionary. dictionary.
Designed to be paired with a MemorySink, together as the two Designed to be paired with a MemorySink, together as the two
components of a MemoryStore. components of a MemoryStore.
@ -177,17 +176,17 @@ class MemorySource(DataSource):
a MemorySink a MemorySink
""" """
def __init__(self, stix_data=None, _store=False): def __init__(self, stix_data=None, _store=False, allow_custom=False):
super(MemorySource, self).__init__() super(MemorySource, self).__init__()
self._data = {} self._data = {}
if _store: if _store:
self._data = stix_data self._data = stix_data
elif stix_data: elif stix_data:
_add(self, stix_data) _add(self, stix_data, allow_custom=allow_custom)
def get(self, stix_id, _composite_filters=None): def get(self, stix_id, _composite_filters=None, allow_custom=False):
"""retrieve STIX object from in-memory dict via STIX ID """Retrieve STIX object from in-memory dict via STIX ID.
Args: Args:
stix_id (str): The STIX ID of the STIX object to be retrieved. stix_id (str): The STIX ID of the STIX object to be retrieved.
@ -200,8 +199,8 @@ class MemorySource(DataSource):
ID. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory ID. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory
as they are supplied (either as python dictionary or STIX object), it as they are supplied (either as python dictionary or STIX object), it
is returned in the same form as it as added is returned in the same form as it as added
"""
"""
if _composite_filters is None: if _composite_filters is None:
# if get call is only based on 'id', no need to search, just retrieve from dict # if get call is only based on 'id', no need to search, just retrieve from dict
try: try:
@ -213,15 +212,15 @@ class MemorySource(DataSource):
# if there are filters from the composite level, process full query # if there are filters from the composite level, process full query
query = [Filter("id", "=", stix_id)] query = [Filter("id", "=", stix_id)]
all_data = self.query(query=query, _composite_filters=_composite_filters) all_data = self.query(query=query, _composite_filters=_composite_filters, allow_custom=allow_custom)
# reduce to most recent version # reduce to most recent version
stix_obj = sorted(all_data, key=lambda k: k['modified'])[0] stix_obj = sorted(all_data, key=lambda k: k['modified'])[0]
return stix_obj return stix_obj
def all_versions(self, stix_id, _composite_filters=None): def all_versions(self, stix_id, _composite_filters=None, allow_custom=False):
"""retrieve STIX objects from in-memory dict via STIX ID, all versions of it """Retrieve STIX objects from in-memory dict via STIX ID, all versions of it
Note: Since Memory sources/sinks don't handle multiple versions of a Note: Since Memory sources/sinks don't handle multiple versions of a
STIX object, this operation is unnecessary. Translate call to get(). STIX object, this operation is unnecessary. Translate call to get().
@ -239,10 +238,10 @@ class MemorySource(DataSource):
is returned in the same form as it as added is returned in the same form as it as added
""" """
return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)] return [self.get(stix_id=stix_id, _composite_filters=_composite_filters, allow_custom=allow_custom)]
def query(self, query=None, _composite_filters=None): def query(self, query=None, _composite_filters=None, allow_custom=False):
"""search and retrieve STIX objects based on the complete query """Search and retrieve STIX objects based on the complete query.
A "complete query" includes the filters from the query, the filters A "complete query" includes the filters from the query, the filters
attached to MemorySource, and any filters passed from a attached to MemorySource, and any filters passed from a
@ -281,15 +280,16 @@ class MemorySource(DataSource):
return all_data return all_data
def load_from_file(self, file_path): def load_from_file(self, file_path, allow_custom=False):
"""load STIX data from json file """Load STIX data from json file.
File format is expected to be a single json File format is expected to be a single json
STIX object or json STIX bundle STIX object or json STIX bundle
Args: Args:
file_path (str): file path to load STIX data from file_path (str): file path to load STIX data from
""" """
file_path = os.path.abspath(file_path) file_path = os.path.abspath(file_path)
stix_data = json.load(open(file_path, "r")) stix_data = json.load(open(file_path, "r"))
_add(self, stix_data) _add(self, stix_data, allow_custom=allow_custom)

View File

@ -1,7 +1,7 @@
import pytest import pytest
from taxii2client import Collection from taxii2client import Collection
from stix2 import Filter, MemorySource, MemoryStore from stix2 import Filter, MemorySource
from stix2.sources import (CompositeDataSource, DataSink, DataSource, from stix2.sources import (CompositeDataSource, DataSink, DataSource,
DataStore, make_id, taxii) DataStore, make_id, taxii)
from stix2.sources.filters import apply_common_filters from stix2.sources.filters import apply_common_filters
@ -144,28 +144,6 @@ def test_ds_abstract_class_smoke():
ds3.query([Filter("id", "=", "malware--fdd60b30-b67c-11e3-b0b9-f01faf20d111")]) ds3.query([Filter("id", "=", "malware--fdd60b30-b67c-11e3-b0b9-f01faf20d111")])
def test_memory_store_smoke():
# Initialize MemoryStore with dict
ms = MemoryStore(STIX_OBJS1)
# Add item to sink
ms.add(dict(id="bundle--%s" % make_id(),
objects=STIX_OBJS2,
spec_version="2.0",
type="bundle"))
resp = ms.all_versions("indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f")
assert len(resp) == 1
resp = ms.get("indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f")
assert resp["id"] == "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f"
query = [Filter('type', '=', 'malware')]
resp = ms.query(query)
assert len(resp) == 0
def test_ds_taxii(collection): def test_ds_taxii(collection):
ds = taxii.TAXIICollectionSource(collection) ds = taxii.TAXIICollectionSource(collection)
assert ds.collection is not None assert ds.collection is not None

237
stix2/test/test_memory.py Normal file
View File

@ -0,0 +1,237 @@
import pytest
from stix2 import (Bundle, Campaign, CustomObject, Filter, MemorySource,
MemoryStore, properties)
from stix2.sources import make_id
IND1 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND2 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND3 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.936Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND4 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND5 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND6 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-31T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND7 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND8 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
STIX_OBJS2 = [IND6, IND7, IND8]
STIX_OBJS1 = [IND1, IND2, IND3, IND4, IND5]
@pytest.fixture
def mem_store():
yield MemoryStore(STIX_OBJS1)
@pytest.fixture
def mem_source():
yield MemorySource(STIX_OBJS1)
def test_memory_source_get(mem_source):
resp = mem_source.get("indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f")
assert resp["id"] == "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f"
def test_memory_source_get_nonexistant_object(mem_source):
resp = mem_source.get("tool--d81f86b8-975b-bc0b-775e-810c5ad45a4f")
assert resp is None
def test_memory_store_all_versions(mem_store):
# Add bundle of items to sink
mem_store.add(dict(id="bundle--%s" % make_id(),
objects=STIX_OBJS2,
spec_version="2.0",
type="bundle"))
resp = mem_store.all_versions("indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f")
assert len(resp) == 1 # MemoryStore can only store 1 version of each object
def test_memory_store_query(mem_store):
query = [Filter('type', '=', 'malware')]
resp = mem_store.query(query)
assert len(resp) == 0
def test_memory_store_add_stix_object_str(mem_store):
# add stix object string
camp_id = "campaign--111111b6-1112-4fb0-111b-b111107ca70a"
camp_name = "Aurelius"
camp_alias = "Purple Robes"
camp = """{
"name": "%s",
"type": "campaign",
"objective": "German and French Intelligence Services",
"aliases": ["%s"],
"id": "%s",
"created": "2017-05-31T21:31:53.197755Z"
}""" % (camp_name, camp_alias, camp_id)
mem_store.add(camp)
camp_r = mem_store.get(camp_id)
assert camp_r["id"] == camp_id
assert camp_r["name"] == camp_name
assert camp_alias in camp_r["aliases"]
def test_memory_store_add_stix_bundle_str(mem_store):
# add stix bundle string
camp_id = "campaign--133111b6-1112-4fb0-111b-b111107ca70a"
camp_name = "Atilla"
camp_alias = "Huns"
bund = """{
"type": "bundle",
"id": "bundle--112211b6-1112-4fb0-111b-b111107ca70a",
"spec_version": "2.0",
"objects": [
{
"name": "%s",
"type": "campaign",
"objective": "Bulgarian, Albanian and Romanian Intelligence Services",
"aliases": ["%s"],
"id": "%s",
"created": "2017-05-31T21:31:53.197755Z"
}
]
}""" % (camp_name, camp_alias, camp_id)
mem_store.add(bund)
camp_r = mem_store.get(camp_id)
assert camp_r["id"] == camp_id
assert camp_r["name"] == camp_name
assert camp_alias in camp_r["aliases"]
def test_memory_store_object_with_custom_property(mem_store):
camp = Campaign(name="Scipio Africanus",
objective="Defeat the Carthaginians",
x_empire="Roman",
allow_custom=True)
mem_store.add(camp, True)
camp_r = mem_store.get(camp.id, True)
assert camp_r.id == camp.id
assert camp_r.x_empire == camp.x_empire
def test_memory_store_object_with_custom_property_in_bundle(mem_store):
camp = Campaign(name="Scipio Africanus",
objective="Defeat the Carthaginians",
x_empire="Roman",
allow_custom=True)
bundle = Bundle(camp, allow_custom=True)
mem_store.add(bundle, True)
bundle_r = mem_store.get(bundle.id, True)
camp_r = bundle_r['objects'][0]
assert camp_r.id == camp.id
assert camp_r.x_empire == camp.x_empire
def test_memory_store_custom_object(mem_store):
@CustomObject('x-new-obj', [
('property1', properties.StringProperty(required=True)),
])
class NewObj():
pass
newobj = NewObj(property1='something')
mem_store.add(newobj, True)
newobj_r = mem_store.get(newobj.id, True)
assert newobj_r.id == newobj.id
assert newobj_r.property1 == 'something'