diff --git a/stix2/datastore/taxii.py b/stix2/datastore/taxii.py index 1600253..61d2366 100644 --- a/stix2/datastore/taxii.py +++ b/stix2/datastore/taxii.py @@ -300,10 +300,10 @@ class TAXIICollectionSource(DataSource): # The while loop will not be executed if the response is received in full. while envelope.get("more", False): - envelope = self.collection.get_objects(limit=self.items_per_page, next=envelope.get("next", "")) + envelope = self.collection.get_objects(limit=self.items_per_page, next=envelope.get("next", ""), **taxii_filters_dict) all_data.extend(envelope.get("objects", [])) else: - for bundle in tcv20.as_pages(self.collection.get_objects, per_request=self.items_per_page): + for bundle in tcv20.as_pages(self.collection.get_objects, per_request=self.items_per_page, **taxii_filters_dict): all_data.extend(bundle.get("objects", [])) # deduplicate data (before filtering as reduces wasted filtering) diff --git a/stix2/test/v20/test_datastore_taxii.py b/stix2/test/v20/test_datastore_taxii.py index 0b21981..cd051f1 100644 --- a/stix2/test/v20/test_datastore_taxii.py +++ b/stix2/test/v20/test_datastore_taxii.py @@ -5,7 +5,7 @@ import pytest from requests.models import Response import six from taxii2client.common import _filter_kwargs_to_query_params -from taxii2client.v20 import Collection +from taxii2client.v20 import MEDIA_TYPE_STIX_V20, Collection import stix2 from stix2.datastore import DataSourceError @@ -35,12 +35,12 @@ class MockTAXIICollectionEndpoint(Collection): { "date_added": get_timestamp(), "id": object["id"], - "media_type": "application/stix+json;version=2.1", + "media_type": "application/stix+json;version=2.0", "version": object.get("modified", object.get("created", get_timestamp())), }, ) - def get_objects(self, **filter_kwargs): + def get_objects(self, accept=MEDIA_TYPE_STIX_V20, start=0, per_request=0, **filter_kwargs): self._verify_can_read() query_params = _filter_kwargs_to_query_params(filter_kwargs) assert isinstance(query_params, dict) @@ -52,7 +52,10 @@ class MockTAXIICollectionEndpoint(Collection): 100, )[0] if objs: - return stix2.v20.Bundle(objects=objs) + resp = Response() + resp.encoding = "utf-8" + resp._content = six.ensure_binary(stix2.v20.Bundle(objects=objs).serialize(ensure_ascii=False)) + return resp else: resp = Response() resp.status_code = 404