Remove 'version' from calls to parse since it is no longer necessary

Also, fixed adding STIX2 Bundles to MemorySource. Enhancements to 'save_to_file'. Fix docstrings and encoding support when writing to file. closes #202
stix2.1
Emmanuelle Vargas-Gonzalez 2018-07-10 15:43:58 -04:00
parent b6fefc52d9
commit 9b8cb09b1a
1 changed files with 53 additions and 47 deletions

View File

@ -12,17 +12,18 @@ Note:
""" """
import collections
import io
import json import json
import os import os
from stix2 import Bundle from stix2 import Bundle, v20
from stix2.base import _STIXBase
from stix2.core import parse from stix2.core import parse
from stix2.datastore import DataSink, DataSource, DataStoreMixin from stix2.datastore import DataSink, DataSource, DataStoreMixin
from stix2.datastore.filters import Filter, FilterSet, apply_common_filters from stix2.datastore.filters import Filter, FilterSet, apply_common_filters
def _add(store, stix_data=None, version=None): def _add(store, stix_data=None):
"""Add STIX objects to MemoryStore/Sink. """Add STIX objects to MemoryStore/Sink.
Adds STIX objects to an in-memory dictionary for fast lookup. Adds STIX objects to an in-memory dictionary for fast lookup.
@ -30,19 +31,13 @@ def _add(store, stix_data=None, version=None):
Args: Args:
stix_data (list OR dict OR STIX object): STIX objects to be added stix_data (list OR dict OR STIX object): STIX objects to be added
version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). If
None, use latest version.
""" """
if isinstance(stix_data, _STIXBase): if isinstance(stix_data, collections.abc.Mapping):
# adding a python STIX object
store._data[stix_data['id']] = stix_data
elif isinstance(stix_data, dict):
if stix_data['type'] == 'bundle': if stix_data['type'] == 'bundle':
# adding a json bundle - so just grab STIX objects # adding a json bundle - so just grab STIX objects
for stix_obj in stix_data.get('objects', []): for stix_obj in stix_data.get('objects', []):
_add(store, stix_obj, version=version) _add(store, stix_obj)
else: else:
# adding a json STIX object # adding a json STIX object
store._data[stix_data['id']] = stix_data store._data[stix_data['id']] = stix_data
@ -50,7 +45,7 @@ def _add(store, stix_data=None, version=None):
elif isinstance(stix_data, list): elif isinstance(stix_data, list):
# STIX objects are in a list- recurse on each object # STIX objects are in a list- recurse on each object
for stix_obj in stix_data: for stix_obj in stix_data:
_add(store, stix_obj, version=version) _add(store, stix_obj)
else: else:
raise TypeError("stix_data expected to be a python-stix2 object (or list of), JSON formatted STIX (or list of)," raise TypeError("stix_data expected to be a python-stix2 object (or list of), JSON formatted STIX (or list of),"
@ -71,8 +66,6 @@ class MemoryStore(DataStoreMixin):
allow_custom (bool): whether to allow custom STIX content. allow_custom (bool): whether to allow custom STIX content.
Only applied when export/input functions called, i.e. Only applied when export/input functions called, i.e.
load_from_file() and save_to_file(). Defaults to True. load_from_file() and save_to_file(). Defaults to True.
version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). If
None, use latest version.
Attributes: Attributes:
_data (dict): the in-memory dict that holds STIX objects _data (dict): the in-memory dict that holds STIX objects
@ -80,23 +73,25 @@ class MemoryStore(DataStoreMixin):
sink (MemorySink): MemorySink sink (MemorySink): MemorySink
""" """
def __init__(self, stix_data=None, allow_custom=True, version=None): def __init__(self, stix_data=None, allow_custom=True):
self._data = {} self._data = {}
if stix_data: if stix_data:
_add(self, stix_data, version=version) _add(self, stix_data)
super(MemoryStore, self).__init__( super(MemoryStore, self).__init__(
source=MemorySource(stix_data=self._data, allow_custom=allow_custom, version=version, _store=True), source=MemorySource(stix_data=self._data, allow_custom=allow_custom, _store=True),
sink=MemorySink(stix_data=self._data, allow_custom=allow_custom, version=version, _store=True) sink=MemorySink(stix_data=self._data, allow_custom=allow_custom, _store=True)
) )
def save_to_file(self, *args, **kwargs): def save_to_file(self, *args, **kwargs):
"""Write SITX objects from in-memory dictionary to JSON file, as a STIX """Write SITX objects from in-memory dictionary to JSON file, as a STIX
Bundle. Bundle. If a directory is given, the Bundle 'id' will be used as
filename. Otherwise, the provided value will be used.
Args: Args:
file_path (str): file path to write STIX data to path (str): file path to write STIX data to.
encoding (str): The file encoding. Default utf-8.
""" """
return self.sink.save_to_file(*args, **kwargs) return self.sink.save_to_file(*args, **kwargs)
@ -104,13 +99,11 @@ class MemoryStore(DataStoreMixin):
def load_from_file(self, *args, **kwargs): def load_from_file(self, *args, **kwargs):
"""Load STIX data from JSON file. """Load STIX data from JSON file.
File format is expected to be a single JSON File format is expected to be a single JSON STIX object or JSON STIX
STIX object or JSON STIX bundle. bundle.
Args: Args:
file_path (str): file path to load STIX data from path (str): file path to load STIX data from
version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). If
None, use latest version.
""" """
return self.source.load_from_file(*args, **kwargs) return self.source.load_from_file(*args, **kwargs)
@ -137,7 +130,7 @@ class MemorySink(DataSink):
If part of a MemoryStore, the dict is shared with a MemorySource If part of a MemoryStore, the dict is shared with a MemorySource
""" """
def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False): def __init__(self, stix_data=None, allow_custom=True, _store=False):
super(MemorySink, self).__init__() super(MemorySink, self).__init__()
self._data = {} self._data = {}
self.allow_custom = allow_custom self.allow_custom = allow_custom
@ -145,19 +138,31 @@ class MemorySink(DataSink):
if _store: if _store:
self._data = stix_data self._data = stix_data
elif stix_data: elif stix_data:
_add(self, stix_data, version=version) _add(self, stix_data)
def add(self, stix_data, version=None): def add(self, stix_data):
_add(self, stix_data, version=version) _add(self, stix_data)
add.__doc__ = _add.__doc__ add.__doc__ = _add.__doc__
def save_to_file(self, file_path): def save_to_file(self, path, encoding='utf-8'):
file_path = os.path.abspath(file_path) path = os.path.abspath(path)
all_objs = list(self._data.values())
if not os.path.exists(os.path.dirname(file_path)): if any('spec_version' in x for x in all_objs):
os.makedirs(os.path.dirname(file_path)) bundle = Bundle(all_objs, allow_custom=self.allow_custom)
with open(file_path, 'w') as f: else:
f.write(str(Bundle(list(self._data.values()), allow_custom=self.allow_custom))) bundle = v20.Bundle(all_objs, allow_custom=self.allow_custom)
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
# if the user only provided a directory, use the bundle id for filename
if os.path.isdir(path):
path = os.path.join(path, bundle['id'] + '.json')
with io.open(path, 'w', encoding=encoding) as f:
bundle = bundle.serialize(pretty=True, encoding=encoding, ensure_ascii=False)
f.write(bundle)
save_to_file.__doc__ = MemoryStore.save_to_file.__doc__ save_to_file.__doc__ = MemoryStore.save_to_file.__doc__
@ -183,7 +188,7 @@ class MemorySource(DataSource):
If part of a MemoryStore, the dict is shared with a MemorySink If part of a MemoryStore, the dict is shared with a MemorySink
""" """
def __init__(self, stix_data=None, allow_custom=True, version=None, _store=False): def __init__(self, stix_data=None, allow_custom=True, _store=False):
super(MemorySource, self).__init__() super(MemorySource, self).__init__()
self._data = {} self._data = {}
self.allow_custom = allow_custom self.allow_custom = allow_custom
@ -191,7 +196,7 @@ class MemorySource(DataSource):
if _store: if _store:
self._data = stix_data self._data = stix_data
elif stix_data: elif stix_data:
_add(self, stix_data, version=version) _add(self, stix_data)
def get(self, stix_id, _composite_filters=None): def get(self, stix_id, _composite_filters=None):
"""Retrieve STIX object from in-memory dict via STIX ID. """Retrieve STIX object from in-memory dict via STIX ID.
@ -230,15 +235,16 @@ class MemorySource(DataSource):
return None return None
def all_versions(self, stix_id, _composite_filters=None): def all_versions(self, stix_id, _composite_filters=None):
"""Retrieve STIX objects from in-memory dict via STIX ID, all versions of it """Retrieve STIX objects from in-memory dict via STIX ID, all versions
of it.
Note: Since Memory sources/sinks don't handle multiple versions of a Note: Since Memory sources/sinks don't handle multiple versions of a
STIX object, this operation is unnecessary. Translate call to get(). STIX object, this operation is unnecessary. Translate call to get().
Args: Args:
stix_id (str): The STIX ID of the STIX 2 object to retrieve. stix_id (str): The STIX ID of the STIX 2 object to retrieve.
_composite_filters (FilterSet): collection of filters passed from the parent _composite_filters (FilterSet): collection of filters passed from
CompositeDataSource, not user supplied the parent CompositeDataSource, not user supplied
Returns: Returns:
(list): list of STIX objects that has the supplied ID. As the (list): list of STIX objects that has the supplied ID. As the
@ -259,14 +265,14 @@ class MemorySource(DataSource):
Args: Args:
query (list): list of filters to search on query (list): list of filters to search on
_composite_filters (FilterSet): collection of filters passed from the _composite_filters (FilterSet): collection of filters passed from
CompositeDataSource, not user supplied the CompositeDataSource, not user supplied
Returns: Returns:
(list): list of STIX objects that matches the supplied (list): list of STIX objects that matches the supplied
query. As the MemoryStore(i.e. MemorySink) adds STIX objects to memory query. As the MemoryStore(i.e. MemorySink) adds STIX objects
as they are supplied (either as python dictionary or STIX object), it to memory as they are supplied (either as python dictionary or
is returned in the same form as it as added. STIX object), it is returned in the same form as it as added.
""" """
query = FilterSet(query) query = FilterSet(query)
@ -282,12 +288,12 @@ class MemorySource(DataSource):
return all_data return all_data
def load_from_file(self, file_path, version=None): def load_from_file(self, file_path):
stix_data = json.load(open(os.path.abspath(file_path), 'r')) stix_data = json.load(open(os.path.abspath(file_path), 'r'))
if stix_data['type'] == 'bundle': if stix_data['type'] == 'bundle':
for stix_obj in stix_data['objects']: for stix_obj in stix_data['objects']:
_add(self, stix_data=parse(stix_obj, allow_custom=self.allow_custom)) _add(self, stix_data=parse(stix_obj, allow_custom=self.allow_custom))
else: else:
_add(self, stix_data=parse(stix_data, allow_custom=self.allow_custom, version=version)) _add(self, stix_data=parse(stix_data, allow_custom=self.allow_custom))
load_from_file.__doc__ = MemoryStore.load_from_file.__doc__ load_from_file.__doc__ = MemoryStore.load_from_file.__doc__