From 9b8cb09b1ab584017474008088d6ddda17b89f41 Mon Sep 17 00:00:00 2001 From: Emmanuelle Vargas-Gonzalez Date: Tue, 10 Jul 2018 15:43:58 -0400 Subject: [PATCH] Remove 'version' from calls to parse since it is no longer necessary Also, fixed adding STIX2 Bundles to MemorySource. Enhancements to 'save_to_file'. Fix docstrings and encoding support when writing to file. closes #202 --- stix2/datastore/memory.py | 100 ++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 47 deletions(-) diff --git a/stix2/datastore/memory.py b/stix2/datastore/memory.py index 13f5452..23d8329 100644 --- a/stix2/datastore/memory.py +++ b/stix2/datastore/memory.py @@ -12,17 +12,18 @@ Note: """ +import collections +import io import json import os -from stix2 import Bundle -from stix2.base import _STIXBase +from stix2 import Bundle, v20 from stix2.core import parse from stix2.datastore import DataSink, DataSource, DataStoreMixin from stix2.datastore.filters import Filter, FilterSet, apply_common_filters -def _add(store, stix_data=None, version=None): +def _add(store, stix_data=None): """Add STIX objects to MemoryStore/Sink. Adds STIX objects to an in-memory dictionary for fast lookup. @@ -30,19 +31,13 @@ def _add(store, stix_data=None, version=None): Args: stix_data (list OR dict OR STIX object): STIX objects to be added - version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). If - None, use latest version. """ - if isinstance(stix_data, _STIXBase): - # adding a python STIX object - store._data[stix_data['id']] = stix_data - - elif isinstance(stix_data, dict): + if isinstance(stix_data, collections.abc.Mapping): if stix_data['type'] == 'bundle': # adding a json bundle - so just grab STIX objects for stix_obj in stix_data.get('objects', []): - _add(store, stix_obj, version=version) + _add(store, stix_obj) else: # adding a json STIX object store._data[stix_data['id']] = stix_data @@ -50,7 +45,7 @@ def _add(store, stix_data=None, version=None): elif isinstance(stix_data, list): # STIX objects are in a list- recurse on each object for stix_obj in stix_data: - _add(store, stix_obj, version=version) + _add(store, stix_obj) else: raise TypeError("stix_data expected to be a python-stix2 object (or list of), JSON formatted STIX (or list of)," @@ -71,8 +66,6 @@ class MemoryStore(DataStoreMixin): allow_custom (bool): whether to allow custom STIX content. Only applied when export/input functions called, i.e. load_from_file() and save_to_file(). Defaults to True. - version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). If - None, use latest version. Attributes: _data (dict): the in-memory dict that holds STIX objects @@ -80,23 +73,25 @@ class MemoryStore(DataStoreMixin): sink (MemorySink): MemorySink """ - def __init__(self, stix_data=None, allow_custom=True, version=None): + def __init__(self, stix_data=None, allow_custom=True): self._data = {} if stix_data: - _add(self, stix_data, version=version) + _add(self, stix_data) super(MemoryStore, self).__init__( - source=MemorySource(stix_data=self._data, allow_custom=allow_custom, version=version, _store=True), - sink=MemorySink(stix_data=self._data, allow_custom=allow_custom, version=version, _store=True) + source=MemorySource(stix_data=self._data, allow_custom=allow_custom, _store=True), + sink=MemorySink(stix_data=self._data, allow_custom=allow_custom, _store=True) ) def save_to_file(self, *args, **kwargs): """Write SITX objects from in-memory dictionary to JSON file, as a STIX - Bundle. + Bundle. If a directory is given, the Bundle 'id' will be used as + filename. Otherwise, the provided value will be used. Args: - file_path (str): file path to write STIX data to + path (str): file path to write STIX data to. + encoding (str): The file encoding. Default utf-8. """ return self.sink.save_to_file(*args, **kwargs) @@ -104,13 +99,11 @@ class MemoryStore(DataStoreMixin): def load_from_file(self, *args, **kwargs): """Load STIX data from JSON file. - File format is expected to be a single JSON - STIX object or JSON STIX bundle. + File format is expected to be a single JSON STIX object or JSON STIX + bundle. Args: - file_path (str): file path to load STIX data from - version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). If - None, use latest version. + path (str): file path to load STIX data from """ return self.source.load_from_file(*args, **kwargs) @@ -137,7 +130,7 @@ class MemorySink(DataSink): If part of a MemoryStore, the dict is shared with a MemorySource """ - def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False): + def __init__(self, stix_data=None, allow_custom=True, _store=False): super(MemorySink, self).__init__() self._data = {} self.allow_custom = allow_custom @@ -145,19 +138,31 @@ class MemorySink(DataSink): if _store: self._data = stix_data elif stix_data: - _add(self, stix_data, version=version) + _add(self, stix_data) - def add(self, stix_data, version=None): - _add(self, stix_data, version=version) + def add(self, stix_data): + _add(self, stix_data) add.__doc__ = _add.__doc__ - def save_to_file(self, file_path): - file_path = os.path.abspath(file_path) + def save_to_file(self, path, encoding='utf-8'): + path = os.path.abspath(path) + all_objs = list(self._data.values()) - if not os.path.exists(os.path.dirname(file_path)): - os.makedirs(os.path.dirname(file_path)) - with open(file_path, 'w') as f: - f.write(str(Bundle(list(self._data.values()), allow_custom=self.allow_custom))) + if any('spec_version' in x for x in all_objs): + bundle = Bundle(all_objs, allow_custom=self.allow_custom) + else: + bundle = v20.Bundle(all_objs, allow_custom=self.allow_custom) + + if not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) + + # if the user only provided a directory, use the bundle id for filename + if os.path.isdir(path): + path = os.path.join(path, bundle['id'] + '.json') + + with io.open(path, 'w', encoding=encoding) as f: + bundle = bundle.serialize(pretty=True, encoding=encoding, ensure_ascii=False) + f.write(bundle) save_to_file.__doc__ = MemoryStore.save_to_file.__doc__ @@ -183,7 +188,7 @@ class MemorySource(DataSource): If part of a MemoryStore, the dict is shared with a MemorySink """ - def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False): + def __init__(self, stix_data=None, allow_custom=True, _store=False): super(MemorySource, self).__init__() self._data = {} self.allow_custom = allow_custom @@ -191,7 +196,7 @@ class MemorySource(DataSource): if _store: self._data = stix_data elif stix_data: - _add(self, stix_data, version=version) + _add(self, stix_data) def get(self, stix_id, _composite_filters=None): """Retrieve STIX object from in-memory dict via STIX ID. @@ -230,15 +235,16 @@ class MemorySource(DataSource): return None def all_versions(self, stix_id, _composite_filters=None): - """Retrieve STIX objects from in-memory dict via STIX ID, all versions of it + """Retrieve STIX objects from in-memory dict via STIX ID, all versions + of it. Note: Since Memory sources/sinks don't handle multiple versions of a STIX object, this operation is unnecessary. Translate call to get(). Args: stix_id (str): The STIX ID of the STIX 2 object to retrieve. - _composite_filters (FilterSet): collection of filters passed from the parent - CompositeDataSource, not user supplied + _composite_filters (FilterSet): collection of filters passed from + the parent CompositeDataSource, not user supplied Returns: (list): list of STIX objects that has the supplied ID. As the @@ -259,14 +265,14 @@ class MemorySource(DataSource): Args: query (list): list of filters to search on - _composite_filters (FilterSet): collection of filters passed from the - CompositeDataSource, not user supplied + _composite_filters (FilterSet): collection of filters passed from + the CompositeDataSource, not user supplied Returns: (list): list of STIX objects that matches the supplied - query. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory - as they are supplied (either as python dictionary or STIX object), it - is returned in the same form as it as added. + query. As the MemoryStore(i.e. MemorySink) adds STIX objects + to memory as they are supplied (either as python dictionary or + STIX object), it is returned in the same form as it as added. """ query = FilterSet(query) @@ -282,12 +288,12 @@ class MemorySource(DataSource): return all_data - def load_from_file(self, file_path, version=None): + def load_from_file(self, file_path): stix_data = json.load(open(os.path.abspath(file_path), 'r')) if stix_data['type'] == 'bundle': for stix_obj in stix_data['objects']: _add(self, stix_data=parse(stix_obj, allow_custom=self.allow_custom)) else: - _add(self, stix_data=parse(stix_data, allow_custom=self.allow_custom, version=version)) + _add(self, stix_data=parse(stix_data, allow_custom=self.allow_custom)) load_from_file.__doc__ = MemoryStore.load_from_file.__doc__