diff --git a/stix2/datastore/taxii.py b/stix2/datastore/taxii.py index 2a8d5cb..7547bd7 100644 --- a/stix2/datastore/taxii.py +++ b/stix2/datastore/taxii.py @@ -1,7 +1,7 @@ """Python STIX 2.x TAXIICollectionStore""" from requests.exceptions import HTTPError -from stix2 import Bundle +from stix2 import Bundle, v20 from stix2.base import _STIXBase from stix2.core import parse from stix2.datastore import (DataSink, DataSource, DataSourceError, @@ -74,44 +74,52 @@ class TAXIICollectionSink(DataSink): self.allow_custom = allow_custom - def add(self, stix_data, version=None): + def add(self, stix_data): """Add/push STIX content to TAXII Collection endpoint Args: - stix_data (STIX object OR dict OR str OR list): valid STIX 2.0 content - in a STIX object (or Bundle), STIX onject dict (or Bundle dict), or a STIX 2.0 - json encoded string, or list of any of the following - version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). If - None, use latest version. + stix_data (STIX object OR dict OR str OR list): valid STIX2 + content in a STIX object (or Bundle), STIX object dict (or + Bundle dict), or a STIX2 json encoded string, or list of + any of the following. """ if isinstance(stix_data, _STIXBase): # adding python STIX object if stix_data['type'] == 'bundle': - bundle = stix_data.serialize(encoding='utf-8') + bundle = stix_data.serialize(encoding='utf-8', ensure_ascii=False) + elif 'spec_version' in stix_data: + # If the spec_version is present, use new Bundle object... + bundle = Bundle(stix_data, allow_custom=self.allow_custom).serialize(encoding='utf-8', ensure_ascii=False) else: - bundle = Bundle(stix_data, allow_custom=self.allow_custom).serialize(encoding='utf-8') + bundle = v20.Bundle(stix_data, allow_custom=self.allow_custom).serialize(encoding='utf-8', ensure_ascii=False) elif isinstance(stix_data, dict): # adding python dict (of either Bundle or STIX obj) if stix_data['type'] == 'bundle': - bundle = parse(stix_data, allow_custom=self.allow_custom, version=version).serialize(encoding='utf-8') + bundle = parse(stix_data, allow_custom=self.allow_custom).serialize(encoding='utf-8', ensure_ascii=False) + elif 'spec_version' in stix_data: + # If the spec_version is present, use new Bundle object... + bundle = Bundle(stix_data, allow_custom=self.allow_custom).serialize(encoding='utf-8', ensure_ascii=False) else: - bundle = Bundle(stix_data, allow_custom=self.allow_custom).serialize(encoding='utf-8') + bundle = v20.Bundle(stix_data, allow_custom=self.allow_custom).serialize(encoding='utf-8', ensure_ascii=False) elif isinstance(stix_data, list): # adding list of something - recurse on each for obj in stix_data: - self.add(obj, version=version) + self.add(obj) return elif isinstance(stix_data, str): # adding json encoded string of STIX content - stix_data = parse(stix_data, allow_custom=self.allow_custom, version=version) + stix_data = parse(stix_data, allow_custom=self.allow_custom) if stix_data['type'] == 'bundle': - bundle = stix_data.serialize(encoding='utf-8') + bundle = stix_data.serialize(encoding='utf-8', ensure_ascii=False) + elif 'spec_version' in stix_data: + # If the spec_version is present, use new Bundle object... + bundle = Bundle(stix_data, allow_custom=self.allow_custom).serialize(encoding='utf-8', ensure_ascii=False) else: - bundle = Bundle(stix_data, allow_custom=self.allow_custom).serialize(encoding='utf-8') + bundle = v20.Bundle(stix_data, allow_custom=self.allow_custom).serialize(encoding='utf-8', ensure_ascii=False) else: raise TypeError("stix_data must be as STIX object(or list of),json formatted STIX (or list of), or a json formatted STIX bundle") @@ -147,16 +155,14 @@ class TAXIICollectionSource(DataSource): self.allow_custom = allow_custom - def get(self, stix_id, version=None, _composite_filters=None): + def get(self, stix_id, _composite_filters=None): """Retrieve STIX object from local/remote STIX Collection endpoint. Args: stix_id (str): The STIX ID of the STIX object to be retrieved. - _composite_filters (FilterSet): collection of filters passed from the parent - CompositeDataSource, not user supplied - version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). If - None, use latest version. + _composite_filters (FilterSet): collection of filters passed from + the parent CompositeDataSource, not user supplied Returns: (STIX object): STIX object that has the supplied STIX ID. @@ -172,21 +178,22 @@ class TAXIICollectionSource(DataSource): if _composite_filters: query.add(_composite_filters) - # dont extract TAXII filters from query (to send to TAXII endpoint) - # as directly retrieveing a STIX object by ID + # don't extract TAXII filters from query (to send to TAXII endpoint) + # as directly retrieving a STIX object by ID try: stix_objs = self.collection.get_object(stix_id)['objects'] stix_obj = list(apply_common_filters(stix_objs, query)) except HTTPError as e: if e.response.status_code == 404: - # if resource not found or access is denied from TAXII server, return None + # if resource not found or access is denied from TAXII server, + # return None stix_obj = [] else: raise DataSourceError("TAXII Collection resource returned error", e) if len(stix_obj): - stix_obj = parse(stix_obj[0], allow_custom=self.allow_custom, version=version) + stix_obj = parse(stix_obj[0], allow_custom=self.allow_custom) if stix_obj.id != stix_id: # check - was added to handle erroneous TAXII servers stix_obj = None @@ -195,7 +202,7 @@ class TAXIICollectionSource(DataSource): return stix_obj - def all_versions(self, stix_id, version=None, _composite_filters=None): + def all_versions(self, stix_id, _composite_filters=None): """Retrieve STIX object from local/remote TAXII Collection endpoint, all versions of it @@ -203,8 +210,6 @@ class TAXIICollectionSource(DataSource): stix_id (str): The STIX ID of the STIX objects to be retrieved. _composite_filters (FilterSet): collection of filters passed from the parent CompositeDataSource, not user supplied - version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). If - None, use latest version. Returns: (see query() as all_versions() is just a wrapper) @@ -219,14 +224,14 @@ class TAXIICollectionSource(DataSource): all_data = self.query(query=query, _composite_filters=_composite_filters) # parse STIX objects from TAXII returned json - all_data = [parse(stix_obj, allow_custom=self.allow_custom, version=version) for stix_obj in all_data] + all_data = [parse(stix_obj, allow_custom=self.allow_custom) for stix_obj in all_data] # check - was added to handle erroneous TAXII servers all_data_clean = [stix_obj for stix_obj in all_data if stix_obj.id == stix_id] return all_data_clean - def query(self, query=None, version=None, _composite_filters=None): + def query(self, query=None, _composite_filters=None): """Search and retreive STIX objects based on the complete query A "complete query" includes the filters from the query, the filters @@ -235,10 +240,8 @@ class TAXIICollectionSource(DataSource): Args: query (list): list of filters to search on - _composite_filters (FilterSet): collection of filters passed from the - CompositeDataSource, not user supplied - version (str): Which STIX2 version to use. (e.g. "2.0", "2.1"). If - None, use latest version. + _composite_filters (FilterSet): collection of filters passed from + the CompositeDataSource, not user supplied Returns: (list): list of STIX objects that matches the supplied @@ -279,7 +282,7 @@ class TAXIICollectionSource(DataSource): " denied. Received error: ", e) # parse python STIX objects from the STIX object dicts - stix_objs = [parse(stix_obj_dict, allow_custom=self.allow_custom, version=version) for stix_obj_dict in all_data] + stix_objs = [parse(stix_obj_dict, allow_custom=self.allow_custom) for stix_obj_dict in all_data] return stix_objs @@ -291,16 +294,15 @@ class TAXIICollectionSource(DataSource): Notes: Currently, the TAXII2Client can handle TAXII filters where the - filter value is list, as both a comma-seperated string or python list + filter value is list, as both a comma-seperated string or python + list. For instance - "?match[type]=indicator,sighting" can be in a filter in any of these formats: Filter("type", "", "indicator,sighting") - Filter("type", "", ["indicator", "sighting"]) - Args: query (list): list of filters to extract which ones are TAXII specific.