Update Memory datastore to allow for mapping objects

master
Emmanuelle Vargas-Gonzalez 2018-11-01 10:54:58 -04:00
parent 5abe518b8a
commit 8d24015186
1 changed files with 13 additions and 13 deletions

View File

@ -51,14 +51,14 @@ def _add(store, stix_data, allow_custom=True, version=None):
# Map ID directly to the object, if it is a marking. Otherwise, # Map ID directly to the object, if it is a marking. Otherwise,
# map to a family, so we can track multiple versions. # map to a family, so we can track multiple versions.
if _is_marking(stix_obj): if _is_marking(stix_obj):
store._data[stix_obj.id] = stix_obj store._data[stix_obj["id"]] = stix_obj
else: else:
if stix_obj.id in store._data: if stix_obj["id"] in store._data:
obj_family = store._data[stix_obj.id] obj_family = store._data[stix_obj["id"]]
else: else:
obj_family = _ObjectFamily() obj_family = _ObjectFamily()
store._data[stix_obj.id] = obj_family store._data[stix_obj["id"]] = obj_family
obj_family.add(stix_obj) obj_family.add(stix_obj)
@ -71,8 +71,8 @@ def _is_marking(obj_or_id):
:return: True if a marking definition, False otherwise. :return: True if a marking definition, False otherwise.
""" """
if isinstance(obj_or_id, _STIXBase): if isinstance(obj_or_id, (_STIXBase, dict)):
id_ = obj_or_id.id id_ = obj_or_id["id"]
else: else:
id_ = obj_or_id id_ = obj_or_id
@ -91,15 +91,15 @@ class _ObjectFamily(object):
self.latest_version = None self.latest_version = None
def add(self, obj): def add(self, obj):
self.all_versions[obj.modified] = obj self.all_versions[obj["modified"]] = obj
if self.latest_version is None or \ if self.latest_version is None or \
obj.modified > self.latest_version.modified: obj["modified"] > self.latest_version["modified"]:
self.latest_version = obj self.latest_version = obj
def __str__(self): def __str__(self):
return "<<{}; latest={}>>".format( return "<<{}; latest={}>>".format(
self.all_versions, self.all_versions,
self.latest_version.modified, self.latest_version["modified"],
) )
def __repr__(self): def __repr__(self):
@ -199,7 +199,7 @@ class MemorySink(DataSink):
_add(self, stix_data, self.allow_custom) _add(self, stix_data, self.allow_custom)
add.__doc__ = _add.__doc__ add.__doc__ = _add.__doc__
def save_to_file(self, path, encoding='utf-8'): def save_to_file(self, path, encoding="utf-8"):
path = os.path.abspath(path) path = os.path.abspath(path)
all_objs = list(itertools.chain.from_iterable( all_objs = list(itertools.chain.from_iterable(
@ -208,7 +208,7 @@ class MemorySink(DataSink):
for value in self._data.values() for value in self._data.values()
)) ))
if any('spec_version' in x for x in all_objs): if any("spec_version" in x for x in all_objs):
bundle = v21.Bundle(all_objs, allow_custom=self.allow_custom) bundle = v21.Bundle(all_objs, allow_custom=self.allow_custom)
else: else:
bundle = v20.Bundle(all_objs, allow_custom=self.allow_custom) bundle = v20.Bundle(all_objs, allow_custom=self.allow_custom)
@ -218,9 +218,9 @@ class MemorySink(DataSink):
# if the user only provided a directory, use the bundle id for filename # if the user only provided a directory, use the bundle id for filename
if os.path.isdir(path): if os.path.isdir(path):
path = os.path.join(path, bundle['id'] + '.json') path = os.path.join(path, bundle["id"] + ".json")
with io.open(path, 'w', encoding=encoding) as f: with io.open(path, "w", encoding=encoding) as f:
bundle = bundle.serialize(pretty=True, encoding=encoding, ensure_ascii=False) bundle = bundle.serialize(pretty=True, encoding=encoding, ensure_ascii=False)
f.write(bundle) f.write(bundle)
save_to_file.__doc__ = MemoryStore.save_to_file.__doc__ save_to_file.__doc__ = MemoryStore.save_to_file.__doc__