expose **taxii_filters_dict on requests

pull/1/head
Emmanuelle Vargas-Gonzalez 2020-12-22 16:52:27 -05:00
parent ace64c4042
commit 76eebeb549
2 changed files with 9 additions and 6 deletions

View File

@ -300,10 +300,10 @@ class TAXIICollectionSource(DataSource):
# The while loop will not be executed if the response is received in full.
while envelope.get("more", False):
envelope = self.collection.get_objects(limit=self.items_per_page, next=envelope.get("next", ""))
envelope = self.collection.get_objects(limit=self.items_per_page, next=envelope.get("next", ""), **taxii_filters_dict)
all_data.extend(envelope.get("objects", []))
else:
for bundle in tcv20.as_pages(self.collection.get_objects, per_request=self.items_per_page):
for bundle in tcv20.as_pages(self.collection.get_objects, per_request=self.items_per_page, **taxii_filters_dict):
all_data.extend(bundle.get("objects", []))
# deduplicate data (before filtering as reduces wasted filtering)

View File

@ -5,7 +5,7 @@ import pytest
from requests.models import Response
import six
from taxii2client.common import _filter_kwargs_to_query_params
from taxii2client.v20 import Collection
from taxii2client.v20 import MEDIA_TYPE_STIX_V20, Collection
import stix2
from stix2.datastore import DataSourceError
@ -35,12 +35,12 @@ class MockTAXIICollectionEndpoint(Collection):
{
"date_added": get_timestamp(),
"id": object["id"],
"media_type": "application/stix+json;version=2.1",
"media_type": "application/stix+json;version=2.0",
"version": object.get("modified", object.get("created", get_timestamp())),
},
)
def get_objects(self, **filter_kwargs):
def get_objects(self, accept=MEDIA_TYPE_STIX_V20, start=0, per_request=0, **filter_kwargs):
self._verify_can_read()
query_params = _filter_kwargs_to_query_params(filter_kwargs)
assert isinstance(query_params, dict)
@ -52,7 +52,10 @@ class MockTAXIICollectionEndpoint(Collection):
100,
)[0]
if objs:
return stix2.v20.Bundle(objects=objs)
resp = Response()
resp.encoding = "utf-8"
resp._content = six.ensure_binary(stix2.v20.Bundle(objects=objs).serialize(ensure_ascii=False))
return resp
else:
resp = Response()
resp.status_code = 404