Separate out Memory datatstore tests

Makes sure custom content can be added to a MemoryStore.
stix2.0
Chris Lenk 2017-10-18 18:31:46 -04:00
parent e1d8c2872e
commit c6d5eee083
3 changed files with 281 additions and 66 deletions

View File

@ -24,8 +24,8 @@ from stix2.sources import DataSink, DataSource, DataStore
from stix2.sources.filters import Filter, apply_common_filters
def _add(store, stix_data=None):
"""Adds STIX objects to MemoryStore/Sink.
def _add(store, stix_data=None, allow_custom=False):
"""Add STIX objects to MemoryStore/Sink.
Adds STIX objects to an in-memory dictionary for fast lookup.
Recursive function, breaks down STIX Bundles and lists.
@ -41,35 +41,35 @@ def _add(store, stix_data=None):
elif isinstance(stix_data, dict):
if stix_data["type"] == "bundle":
# adding a json bundle - so just grab STIX objects
for stix_obj in stix_data["objects"]:
_add(store, stix_obj)
for stix_obj in stix_data.get("objects", []):
_add(store, stix_obj, allow_custom=allow_custom)
else:
# adding a json STIX object
store._data[stix_data["id"]] = stix_data
elif isinstance(stix_data, str):
# adding json encoded string of STIX content
stix_data = parse(stix_data)
stix_data = parse(stix_data, allow_custom=allow_custom)
if stix_data["type"] == "bundle":
# recurse on each STIX object in bundle
for stix_obj in stix_data:
_add(store, stix_obj)
for stix_obj in stix_data.get("objects", []):
_add(store, stix_obj, allow_custom=allow_custom)
else:
_add(store, stix_data)
elif isinstance(stix_data, list):
# STIX objects are in a list- recurse on each object
for stix_obj in stix_data:
_add(store, stix_obj)
_add(store, stix_obj, allow_custom=allow_custom)
else:
raise TypeError("stix_data must be as STIX object(or list of),json formatted STIX (or list of), or a json formatted STIX bundle")
class MemoryStore(DataStore):
"""Provides an interface to an in-memory dictionary
of STIX objects. MemoryStore is a wrapper around a paired
MemorySink and MemorySource
"""Interface to an in-memory dictionary of STIX objects.
MemoryStore is a wrapper around a paired MemorySink and MemorySource.
Note: It doesn't make sense to create a MemoryStore by passing
in existing MemorySource and MemorySink because there could
@ -87,26 +87,25 @@ class MemoryStore(DataStore):
"""
def __init__(self, stix_data=None):
def __init__(self, stix_data=None, allow_custom=False):
super(MemoryStore, self).__init__()
self._data = {}
if stix_data:
_add(self, stix_data)
_add(self, stix_data, allow_custom=allow_custom)
self.source = MemorySource(stix_data=self._data, _store=True)
self.sink = MemorySink(stix_data=self._data, _store=True)
self.source = MemorySource(stix_data=self._data, _store=True, allow_custom=allow_custom)
self.sink = MemorySink(stix_data=self._data, _store=True, allow_custom=allow_custom)
def save_to_file(self, file_path):
return self.sink.save_to_file(file_path=file_path)
def save_to_file(self, file_path, allow_custom=False):
return self.sink.save_to_file(file_path=file_path, allow_custom=allow_custom)
def load_from_file(self, file_path):
return self.source.load_from_file(file_path=file_path)
def load_from_file(self, file_path, allow_custom=False):
return self.source.load_from_file(file_path=file_path, allow_custom=allow_custom)
class MemorySink(DataSink):
"""Provides an interface for adding/pushing STIX objects
to an in-memory dictionary.
"""Interface for adding/pushing STIX objects to an in-memory dictionary.
Designed to be paired with a MemorySource, together as the two
components of a MemoryStore.
@ -125,24 +124,24 @@ class MemorySink(DataSink):
a MemorySource
"""
def __init__(self, stix_data=None, _store=False):
def __init__(self, stix_data=None, _store=False, allow_custom=False):
super(MemorySink, self).__init__()
self._data = {}
if _store:
self._data = stix_data
elif stix_data:
_add(self, stix_data)
_add(self, stix_data, allow_custom=allow_custom)
def add(self, stix_data):
def add(self, stix_data, allow_custom=False):
"""add STIX objects to in-memory dictionary maintained by
the MemorySink (MemoryStore)
see "_add()" for args documentation
"""
_add(self, stix_data)
_add(self, stix_data, allow_custom=allow_custom)
def save_to_file(self, file_path):
def save_to_file(self, file_path, allow_custom=False):
"""write SITX objects in in-memory dictionary to json file, as a STIX Bundle
Args:
@ -153,12 +152,12 @@ class MemorySink(DataSink):
if not os.path.exists(os.path.dirname(file_path)):
os.makedirs(os.path.dirname(file_path))
with open(file_path, "w") as f:
f.write(str(Bundle(self._data.values())))
f.write(str(Bundle(self._data.values(), allow_custom=allow_custom)))
class MemorySource(DataSource):
"""Provides an interface for searching/retrieving
STIX objects from an in-memory dictionary.
"""Interface for searching/retrieving STIX objects from an in-memory
dictionary.
Designed to be paired with a MemorySink, together as the two
components of a MemoryStore.
@ -177,17 +176,17 @@ class MemorySource(DataSource):
a MemorySink
"""
def __init__(self, stix_data=None, _store=False):
def __init__(self, stix_data=None, _store=False, allow_custom=False):
super(MemorySource, self).__init__()
self._data = {}
if _store:
self._data = stix_data
elif stix_data:
_add(self, stix_data)
_add(self, stix_data, allow_custom=allow_custom)
def get(self, stix_id, _composite_filters=None):
"""retrieve STIX object from in-memory dict via STIX ID
def get(self, stix_id, _composite_filters=None, allow_custom=False):
"""Retrieve STIX object from in-memory dict via STIX ID.
Args:
stix_id (str): The STIX ID of the STIX object to be retrieved.
@ -200,8 +199,8 @@ class MemorySource(DataSource):
ID. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory
as they are supplied (either as python dictionary or STIX object), it
is returned in the same form as it as added
"""
"""
if _composite_filters is None:
# if get call is only based on 'id', no need to search, just retrieve from dict
try:
@ -213,15 +212,15 @@ class MemorySource(DataSource):
# if there are filters from the composite level, process full query
query = [Filter("id", "=", stix_id)]
all_data = self.query(query=query, _composite_filters=_composite_filters)
all_data = self.query(query=query, _composite_filters=_composite_filters, allow_custom=allow_custom)
# reduce to most recent version
stix_obj = sorted(all_data, key=lambda k: k['modified'])[0]
return stix_obj
def all_versions(self, stix_id, _composite_filters=None):
"""retrieve STIX objects from in-memory dict via STIX ID, all versions of it
def all_versions(self, stix_id, _composite_filters=None, allow_custom=False):
"""Retrieve STIX objects from in-memory dict via STIX ID, all versions of it
Note: Since Memory sources/sinks don't handle multiple versions of a
STIX object, this operation is unnecessary. Translate call to get().
@ -239,10 +238,10 @@ class MemorySource(DataSource):
is returned in the same form as it as added
"""
return [self.get(stix_id=stix_id, _composite_filters=_composite_filters)]
return [self.get(stix_id=stix_id, _composite_filters=_composite_filters, allow_custom=allow_custom)]
def query(self, query=None, _composite_filters=None):
"""search and retrieve STIX objects based on the complete query
def query(self, query=None, _composite_filters=None, allow_custom=False):
"""Search and retrieve STIX objects based on the complete query.
A "complete query" includes the filters from the query, the filters
attached to MemorySource, and any filters passed from a
@ -281,15 +280,16 @@ class MemorySource(DataSource):
return all_data
def load_from_file(self, file_path):
"""load STIX data from json file
def load_from_file(self, file_path, allow_custom=False):
"""Load STIX data from json file.
File format is expected to be a single json
STIX object or json STIX bundle
Args:
file_path (str): file path to load STIX data from
"""
file_path = os.path.abspath(file_path)
stix_data = json.load(open(file_path, "r"))
_add(self, stix_data)
_add(self, stix_data, allow_custom=allow_custom)

View File

@ -1,7 +1,7 @@
import pytest
from taxii2client import Collection
from stix2 import Filter, MemorySource, MemoryStore
from stix2 import Filter, MemorySource
from stix2.sources import (CompositeDataSource, DataSink, DataSource,
DataStore, make_id, taxii)
from stix2.sources.filters import apply_common_filters
@ -144,28 +144,6 @@ def test_ds_abstract_class_smoke():
ds3.query([Filter("id", "=", "malware--fdd60b30-b67c-11e3-b0b9-f01faf20d111")])
def test_memory_store_smoke():
# Initialize MemoryStore with dict
ms = MemoryStore(STIX_OBJS1)
# Add item to sink
ms.add(dict(id="bundle--%s" % make_id(),
objects=STIX_OBJS2,
spec_version="2.0",
type="bundle"))
resp = ms.all_versions("indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f")
assert len(resp) == 1
resp = ms.get("indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f")
assert resp["id"] == "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f"
query = [Filter('type', '=', 'malware')]
resp = ms.query(query)
assert len(resp) == 0
def test_ds_taxii(collection):
ds = taxii.TAXIICollectionSource(collection)
assert ds.collection is not None

237
stix2/test/test_memory.py Normal file
View File

@ -0,0 +1,237 @@
import pytest
from stix2 import (Bundle, Campaign, CustomObject, Filter, MemorySource,
MemoryStore, properties)
from stix2.sources import make_id
IND1 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND2 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND3 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.936Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND4 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND5 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND6 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-31T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND7 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
IND8 = {
"created": "2017-01-27T13:49:53.935Z",
"id": "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f",
"labels": [
"url-watchlist"
],
"modified": "2017-01-27T13:49:53.935Z",
"name": "Malicious site hosting downloader",
"pattern": "[url:value = 'http://x4z9arb.cn/4712']",
"type": "indicator",
"valid_from": "2017-01-27T13:49:53.935382Z"
}
STIX_OBJS2 = [IND6, IND7, IND8]
STIX_OBJS1 = [IND1, IND2, IND3, IND4, IND5]
@pytest.fixture
def mem_store():
yield MemoryStore(STIX_OBJS1)
@pytest.fixture
def mem_source():
yield MemorySource(STIX_OBJS1)
def test_memory_source_get(mem_source):
resp = mem_source.get("indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f")
assert resp["id"] == "indicator--d81f86b8-975b-bc0b-775e-810c5ad45a4f"
def test_memory_source_get_nonexistant_object(mem_source):
resp = mem_source.get("tool--d81f86b8-975b-bc0b-775e-810c5ad45a4f")
assert resp is None
def test_memory_store_all_versions(mem_store):
# Add bundle of items to sink
mem_store.add(dict(id="bundle--%s" % make_id(),
objects=STIX_OBJS2,
spec_version="2.0",
type="bundle"))
resp = mem_store.all_versions("indicator--d81f86b9-975b-bc0b-775e-810c5ad45a4f")
assert len(resp) == 1 # MemoryStore can only store 1 version of each object
def test_memory_store_query(mem_store):
query = [Filter('type', '=', 'malware')]
resp = mem_store.query(query)
assert len(resp) == 0
def test_memory_store_add_stix_object_str(mem_store):
# add stix object string
camp_id = "campaign--111111b6-1112-4fb0-111b-b111107ca70a"
camp_name = "Aurelius"
camp_alias = "Purple Robes"
camp = """{
"name": "%s",
"type": "campaign",
"objective": "German and French Intelligence Services",
"aliases": ["%s"],
"id": "%s",
"created": "2017-05-31T21:31:53.197755Z"
}""" % (camp_name, camp_alias, camp_id)
mem_store.add(camp)
camp_r = mem_store.get(camp_id)
assert camp_r["id"] == camp_id
assert camp_r["name"] == camp_name
assert camp_alias in camp_r["aliases"]
def test_memory_store_add_stix_bundle_str(mem_store):
# add stix bundle string
camp_id = "campaign--133111b6-1112-4fb0-111b-b111107ca70a"
camp_name = "Atilla"
camp_alias = "Huns"
bund = """{
"type": "bundle",
"id": "bundle--112211b6-1112-4fb0-111b-b111107ca70a",
"spec_version": "2.0",
"objects": [
{
"name": "%s",
"type": "campaign",
"objective": "Bulgarian, Albanian and Romanian Intelligence Services",
"aliases": ["%s"],
"id": "%s",
"created": "2017-05-31T21:31:53.197755Z"
}
]
}""" % (camp_name, camp_alias, camp_id)
mem_store.add(bund)
camp_r = mem_store.get(camp_id)
assert camp_r["id"] == camp_id
assert camp_r["name"] == camp_name
assert camp_alias in camp_r["aliases"]
def test_memory_store_object_with_custom_property(mem_store):
camp = Campaign(name="Scipio Africanus",
objective="Defeat the Carthaginians",
x_empire="Roman",
allow_custom=True)
mem_store.add(camp, True)
camp_r = mem_store.get(camp.id, True)
assert camp_r.id == camp.id
assert camp_r.x_empire == camp.x_empire
def test_memory_store_object_with_custom_property_in_bundle(mem_store):
camp = Campaign(name="Scipio Africanus",
objective="Defeat the Carthaginians",
x_empire="Roman",
allow_custom=True)
bundle = Bundle(camp, allow_custom=True)
mem_store.add(bundle, True)
bundle_r = mem_store.get(bundle.id, True)
camp_r = bundle_r['objects'][0]
assert camp_r.id == camp.id
assert camp_r.x_empire == camp.x_empire
def test_memory_store_custom_object(mem_store):
@CustomObject('x-new-obj', [
('property1', properties.StringProperty(required=True)),
])
class NewObj():
pass
newobj = NewObj(property1='something')
mem_store.add(newobj, True)
newobj_r = mem_store.get(newobj.id, True)
assert newobj_r.id == newobj.id
assert newobj_r.property1 == 'something'