diff --git a/stix2/sources/__init__.py b/stix2/sources/__init__.py index 8f252f7..e25aeb0 100644 --- a/stix2/sources/__init__.py +++ b/stix2/sources/__init__.py @@ -64,34 +64,20 @@ class DataStore(object): An implementer will create a concrete subclass from this abstract class for the specific data store. """ - def __init__(self, name="DataStore", source=None, sink=None): + def __init__(self, name="DataStore"): self.name = name self.id = make_id() - if source: - self.source = source - else: - self.source = DataSource() - if sink: - self.sink = sink - else: - self.sink = DataSink() + self.source = DataSource() + self.sink = DataSink() @property def source(self): return self.source - @source.setter - def source(self, source): - self.source = source - @property def sink(self): return self.sink - @sink.setter - def sink(self, sink): - self.sink = sink - def get(self, stix_id): """ Implement: diff --git a/stix2/sources/filesystem.py b/stix2/sources/filesystem.py index 141f432..a8b6de0 100644 --- a/stix2/sources/filesystem.py +++ b/stix2/sources/filesystem.py @@ -20,44 +20,25 @@ class FileSystemStore(DataStore): """ """ - def __init__(self, stix_dir="stix_data", source=None, sink=None, name="FileSystemStore"): + def __init__(self, stix_dir="stix_data", name="FileSystemStore"): self.name = name self.id = make_id() - - if source: - self.source = source - else: - self.source = FileSystemSource(stix_dir=stix_dir) - - if sink: - self.sink = sink - else: - self.sink = FileSystemSink(stix_dir=stix_dir) + self.source = FileSystemSource(stix_dir=stix_dir) + self.sink = FileSystemSink(stix_dir=stix_dir) @property def source(self): return self.source - @source.setter - def source(self, source): - self.source = source - @property def sink(self): return self.sink - @sink.setter - def sink(self, sink): - self.sink = sink - # file system sink API calls def add(self, stix_objs): return self.sink.add(stix_objs=stix_objs) - def remove(self, stix_ids): - return self.sink.remove(stix_ids=stix_ids) - # file sytem source API calls def get(self, stix_id): @@ -99,17 +80,6 @@ class FileSystemSink(DataSink): path = os.path.join(self.stix_dir, stix_obj["type"], stix_obj["id"]) json.dump(Bundle([stix_obj]), open(path, 'w+', indent=4)) - def remove(self, stix_ids=None): - if not stix_ids: - stix_ids = [] - for stix_id in stix_ids: - stix_type = stix_id.split("--")[0] - try: - os.remove(os.path.join(self.stix_dir, stix_type, stix_id)) - except OSError: - # log error? nonexistent object in data with directory - continue - class FileSystemSource(DataSource): """ diff --git a/stix2/sources/memory.py b/stix2/sources/memory.py index ff231e4..b14e683 100644 --- a/stix2/sources/memory.py +++ b/stix2/sources/memory.py @@ -28,73 +28,16 @@ from stix2validator import validate_string class MemoryStore(DataStore): """ """ - def __init__(self, stix_data=None, source=None, sink=None, name="MemoryStore"): + def __init__(self, stix_data=None, name="MemoryStore"): + """ + Note: It doesnt make sense to create a MemoryStore by passing + in existing MemorySource and MemorySink because there could + be data concurrency issues. Just as easy to create new MemoryStore. + """ self.name = name self.id = make_id() - - if source: - self.source = source - else: - self.source = MemorySource(stix_data=stix_data) - - if sink: - self.sink = sink - else: - self.sink = MemorySink(stix_data=stix_data) - - @property - def source(self): - return self.source - - @source.setter - def source(self, source): - self.source = source - - @property - def sink(self): - return self.sink - - @sink.setter - def sink(self, sink): - self.sink = sink - - # memory sink API calls - - def add(self, stix_data): - return self.sink.add(stix_data=stix_data) - - def remove(self, stix_ids): - return self.sink.remove(stix_ids=stix_ids) - - def save(self): - return self.sink.save() - - # memory source API calls - - def get(self, stix_id): - return self.source.get(stix_id=stix_id) - - def all_versions(self, stix_id): - return self.source.all_versions(stix_id=stix_id) - - def query(self, query): - return self.source.query(query=query) - - -class MemorySink(DataSink): - """ - - """ - def __init__(self, stix_data=None, name="MemorySink"): - """ - Args: - - data (dictionary OR list): valid STIX 2.0 content in bundle or a list - name (string): optional name tag of the data source - - """ - super(MemorySink, self).__init__(name=name) self.data = {} + if stix_data: if type(stix_data) == dict: # stix objects are in a bundle @@ -108,7 +51,6 @@ class MemorySink(DataSink): else: print("Error: json data passed to MemorySink() was found to not be validated by STIX 2 Validator") print(r) - self.data = {} elif type(stix_data) == list: # stix objects are in a list for stix_obj in stix_data: @@ -118,8 +60,86 @@ class MemorySink(DataSink): else: print("Error: STIX object %s is not valid under STIX 2 validator.") % stix_obj["id"] print(r) - else: - raise ValueError("stix_data must be in bundle format or raw list") + + self.source = MemorySource(stix_data=self.data, _store=True) + self.sink = MemorySink(stix_data=self.data, _store=True) + + @property + def source(self): + return self.source + + @property + def sink(self): + return self.sink + + # memory sink API calls + + def add(self, stix_data): + return self.sink.add(stix_data=stix_data) + + def save_to_file(self, file_path): + return self.sink.save(file_path=file_path) + + # memory source API calls + + def get(self, stix_id): + return self.source.get(stix_id=stix_id) + + def all_versions(self, stix_id): + return self.source.all_versions(stix_id=stix_id) + + def query(self, query): + return self.source.query(query=query) + + def load_from_file(self, file_path): + return self.source.load_from_file(file_path=file_path) + + +class MemorySink(DataSink): + """ + + """ + def __init__(self, stix_data=None, name="MemorySink", _store=False): + """ + Args: + + data (dictionary OR list): valid STIX 2.0 content in bundle or a list + name (string): optional name tag of the data source + _store (bool): if the MemorySink is a part of a DataStore, in which case + "stix_data" is a direct reference to shared memory with DataSource + + """ + super(MemorySink, self).__init__(name=name) + + if _store: + self.data = stix_data + else: + self.data = {} + if stix_data: + if type(stix_data) == dict: + # stix objects are in a bundle + # verify STIX json data + r = validate_string(json.dumps(stix_data)) + # make dictionary of the objects for easy lookup + if r.is_valid: + for stix_obj in stix_data["objects"]: + + self.data[stix_obj["id"]] = stix_obj + else: + print("Error: json data passed to MemorySink() was found to not be validated by STIX 2 Validator") + print(r) + self.data = {} + elif type(stix_data) == list: + # stix objects are in a list + for stix_obj in stix_data: + r = validate_string(json.dumps(stix_obj)) + if r.is_valid: + self.data[stix_obj["id"]] = stix_obj + else: + print("Error: STIX object %s is not valid under STIX 2 validator.") % stix_obj["id"] + print(r) + else: + raise ValueError("stix_data must be in bundle format or raw list") def add(self, stix_data): """ @@ -145,60 +165,54 @@ class MemorySink(DataSink): else: raise ValueError("stix_data must be in bundle format or raw list") - def remove(self, stix_ids): + def save_to_file(self, file_path): """ """ - for stix_id in stix_ids: - try: - del self.data[stix_id] - except KeyError: - pass - - def save(self, file_path=None): - """ - """ - if not file_path: - file_path = os.path.dirname(os.path.realpath(__file__)) json.dump(Bundle(self.data.values()), file_path, indent=4) class MemorySource(DataSource): - def __init__(self, stix_data=None, name="MemorySource"): + def __init__(self, stix_data=None, name="MemorySource", _store=False): """ Args: data (dictionary OR list): valid STIX 2.0 content in bundle or list name (string): optional name tag of the data source + _store (bool): if the MemorySource is a part of a DataStore, in which case + "stix_data" is a direct reference to shared memory with DataSink """ super(MemorySource, self).__init__(name=name) - self.data = {} - if stix_data: - if type(stix_data) == dict: - # stix objects are in a bundle - # verify STIX json data - r = validate_string(json.dumps(stix_data)) - # make dictionary of the objects for easy lookup - if r.is_valid: - for stix_obj in stix_data["objects"]: - self.data[stix_obj["id"]] = stix_obj - else: - print("Error: json data passed to MemorySink() was found to not be validated by STIX 2 Validator") - print(r) - self.data = {} - elif type(stix_data) == list: - # stix objects are in a list - for stix_obj in stix_data: - r = validate_string(json.dumps(stix_obj)) + if _store: + self.data = stix_data + else: + self.data = {} + if stix_data: + if type(stix_data) == dict: + # stix objects are in a bundle + # verify STIX json data + r = validate_string(json.dumps(stix_data)) + # make dictionary of the objects for easy lookup if r.is_valid: - self.data[stix_obj["id"]] = stix_obj + for stix_obj in stix_data["objects"]: + self.data[stix_obj["id"]] = stix_obj else: - print("Error: STIX object %s is not valid under STIX 2 validator.") % stix_obj["id"] + print("Error: json data passed to MemorySink() was found to not be validated by STIX 2 Validator") print(r) - else: - raise ValueError("stix_data must be in bundle format or raw list") + self.data = {} + elif type(stix_data) == list: + # stix objects are in a list + for stix_obj in stix_data: + r = validate_string(json.dumps(stix_obj)) + if r.is_valid: + self.data[stix_obj["id"]] = stix_obj + else: + print("Error: STIX object %s is not valid under STIX 2 validator.") % stix_obj["id"] + print(r) + else: + raise ValueError("stix_data must be in bundle format or raw list") def get(self, stix_id, _composite_filters=None): """ @@ -266,3 +280,18 @@ class MemorySource(DataSource): all_data = self.apply_common_filters(self.data.values(), query) return all_data + + def load_from_file(self, file_path): + """ + """ + file_path = os.path.abspath(file_path) + stix_data = json.load(open(file_path, "r")) + + r = validate_string(json.dumps(stix_data)) + + if r.is_valid: + for stix_obj in stix_data["objects"]: + self.data[stix_obj["id"]] = stix_obj + else: + print("Error: STIX data loaded from file (%s) was found to not be validated by STIX 2 Validator") % file_path + print(r) diff --git a/stix2/sources/taxii.py b/stix2/sources/taxii.py index d2669f3..b6aea31 100644 --- a/stix2/sources/taxii.py +++ b/stix2/sources/taxii.py @@ -14,7 +14,6 @@ import json import uuid from stix2.sources import DataSink, DataSource, DataStore, make_id -from taxii2_client import TAXII2Client TAXII_FILTERS = ['added_after', 'id', 'type', 'version'] @@ -23,9 +22,7 @@ class TAXIICollectionStore(DataStore): """ """ def __init__(self, - source=None, - sink=None, - server_uri=None, + taxii_client=None, api_root_name=None, collection_id=None, user=None, @@ -34,33 +31,17 @@ class TAXIICollectionStore(DataStore): self.name = name self.id = make_id() - - if source: - self.source = source - else: - self.source = TAXIICollectionSource(server_uri, api_root_name, collection_id, user, password) - - if sink: - self.sink = sink - else: - self.TAXIICollectionSink(server_uri, api_root_name, collection_id, user, password) + self.source = TAXIICollectionSource(taxii_client, api_root_name, collection_id, user, password) + self.sink = self.TAXIICollectionSink(taxii_client, api_root_name, collection_id, user, password) @property def source(self): return self.source - @source.setter - def source(self, source): - self.source = source - @property def sink(self): return self.sink - @sink.setter - def sink(self, sink): - self.sink = sink - # file system sink API calls def add(self, stix_objs): @@ -82,10 +63,10 @@ class TAXIICollectionSink(DataSink): """ """ - def __init__(self, server_uri=None, api_root_name=None, collection_id=None, user=None, password=None, name="TAXIICollectionSink"): + def __init__(self, taxii_client=None, api_root_name=None, collection_id=None, user=None, password=None, name="TAXIICollectionSink"): super(TAXIICollectionSink, self).__init__(name=name) - self.taxii_client = TAXII2Client(server_uri, user, password) + self.taxii_client = taxii_client self.taxii_client.populate_available_information() if not api_root_name: @@ -110,7 +91,7 @@ class TAXIICollectionSink(DataSink): raise ValueError("The collection %s is not found on the api_root %s of this taxii server" % (collection_id, api_root_name)) - def save(self, stix_obj): + def add(self, stix_obj): """ """ self.collection.add_objects(self.create_bundle([json.loads(str(stix_obj))])) @@ -142,10 +123,10 @@ class TAXIICollectionSink(DataSink): class TAXIICollectionSource(DataSource): """ """ - def __init__(self, server_uri=None, api_root_name=None, collection_id=None, user=None, password=None, name="TAXIICollectionSourc"): + def __init__(self, taxii_client=None, api_root_name=None, collection_id=None, user=None, password=None, name="TAXIICollectionSourc"): super(TAXIICollectionSource, self).__init__(name=name) - self.taxii_client = TAXII2Client(server_uri, user, password) + self.taxii_client = taxii_client self.taxii_client.populate_available_information() if not api_root_name: