Add option for custom content to TAXII datastore

stix2.0
Chris Lenk 2017-10-18 18:34:08 -04:00
parent c6d5eee083
commit 476cd1ed5b
1 changed files with 13 additions and 13 deletions

View File

@ -41,7 +41,7 @@ class TAXIICollectionSink(DataSink):
super(TAXIICollectionSink, self).__init__() super(TAXIICollectionSink, self).__init__()
self.collection = collection self.collection = collection
def add(self, stix_data): def add(self, stix_data, allow_custom=False):
"""add/push STIX content to TAXII Collection endpoint """add/push STIX content to TAXII Collection endpoint
Args: Args:
@ -53,27 +53,27 @@ class TAXIICollectionSink(DataSink):
if isinstance(stix_data, _STIXBase): if isinstance(stix_data, _STIXBase):
# adding python STIX object # adding python STIX object
bundle = dict(Bundle(stix_data)) bundle = dict(Bundle(stix_data, allow_custom=allow_custom))
elif isinstance(stix_data, dict): elif isinstance(stix_data, dict):
# adding python dict (of either Bundle or STIX obj) # adding python dict (of either Bundle or STIX obj)
if stix_data["type"] == "bundle": if stix_data["type"] == "bundle":
bundle = stix_data bundle = stix_data
else: else:
bundle = dict(Bundle(stix_data)) bundle = dict(Bundle(stix_data, allow_custom=allow_custom))
elif isinstance(stix_data, list): elif isinstance(stix_data, list):
# adding list of something - recurse on each # adding list of something - recurse on each
for obj in stix_data: for obj in stix_data:
self.add(obj) self.add(obj, allow_custom=allow_custom)
elif isinstance(stix_data, str): elif isinstance(stix_data, str):
# adding json encoded string of STIX content # adding json encoded string of STIX content
stix_data = parse(stix_data) stix_data = parse(stix_data, allow_custom=allow_custom)
if stix_data["type"] == "bundle": if stix_data["type"] == "bundle":
bundle = dict(stix_data) bundle = dict(stix_data)
else: else:
bundle = dict(Bundle(stix_data)) bundle = dict(Bundle(stix_data, allow_custom=allow_custom))
else: else:
raise TypeError("stix_data must be as STIX object(or list of),json formatted STIX (or list of), or a json formatted STIX bundle") raise TypeError("stix_data must be as STIX object(or list of),json formatted STIX (or list of), or a json formatted STIX bundle")
@ -93,7 +93,7 @@ class TAXIICollectionSource(DataSource):
super(TAXIICollectionSource, self).__init__() super(TAXIICollectionSource, self).__init__()
self.collection = collection self.collection = collection
def get(self, stix_id, _composite_filters=None): def get(self, stix_id, _composite_filters=None, allow_custom=False):
"""retrieve STIX object from local/remote STIX Collection """retrieve STIX object from local/remote STIX Collection
endpoint. endpoint.
@ -125,13 +125,13 @@ class TAXIICollectionSource(DataSource):
if len(stix_obj): if len(stix_obj):
stix_obj = stix_obj[0] stix_obj = stix_obj[0]
stix_obj = parse(stix_obj) stix_obj = parse(stix_obj, allow_custom=allow_custom)
else: else:
stix_obj = None stix_obj = None
return stix_obj return stix_obj
def all_versions(self, stix_id, _composite_filters=None): def all_versions(self, stix_id, _composite_filters=None, allow_custom=False):
"""retrieve STIX object from local/remote TAXII Collection """retrieve STIX object from local/remote TAXII Collection
endpoint, all versions of it endpoint, all versions of it
@ -151,11 +151,11 @@ class TAXIICollectionSource(DataSource):
Filter("match[version]", "=", "all") Filter("match[version]", "=", "all")
] ]
all_data = self.query(query=query, _composite_filters=_composite_filters) all_data = self.query(query=query, _composite_filters=_composite_filters, allow_custom=allow_custom)
return all_data return all_data
def query(self, query=None, _composite_filters=None): def query(self, query=None, _composite_filters=None, allow_custom=False):
"""search and retreive STIX objects based on the complete query """search and retreive STIX objects based on the complete query
A "complete query" includes the filters from the query, the filters A "complete query" includes the filters from the query, the filters
@ -194,7 +194,7 @@ class TAXIICollectionSource(DataSource):
taxii_filters = self._parse_taxii_filters(query) taxii_filters = self._parse_taxii_filters(query)
# query TAXII collection # query TAXII collection
all_data = self.collection.get_objects(filters=taxii_filters)["objects"] all_data = self.collection.get_objects(filters=taxii_filters, allow_custom=allow_custom)["objects"]
# deduplicate data (before filtering as reduces wasted filtering) # deduplicate data (before filtering as reduces wasted filtering)
all_data = deduplicate(all_data) all_data = deduplicate(all_data)
@ -203,7 +203,7 @@ class TAXIICollectionSource(DataSource):
all_data = list(apply_common_filters(all_data, query)) all_data = list(apply_common_filters(all_data, query))
# parse python STIX objects from the STIX object dicts # parse python STIX objects from the STIX object dicts
stix_objs = [parse(stix_obj_dict) for stix_obj_dict in all_data] stix_objs = [parse(stix_obj_dict, allow_custom=allow_custom) for stix_obj_dict in all_data]
return stix_objs return stix_objs