expose **taxii_filters_dict on requests
parent
ace64c4042
commit
76eebeb549
|
@ -300,10 +300,10 @@ class TAXIICollectionSource(DataSource):
|
||||||
|
|
||||||
# The while loop will not be executed if the response is received in full.
|
# The while loop will not be executed if the response is received in full.
|
||||||
while envelope.get("more", False):
|
while envelope.get("more", False):
|
||||||
envelope = self.collection.get_objects(limit=self.items_per_page, next=envelope.get("next", ""))
|
envelope = self.collection.get_objects(limit=self.items_per_page, next=envelope.get("next", ""), **taxii_filters_dict)
|
||||||
all_data.extend(envelope.get("objects", []))
|
all_data.extend(envelope.get("objects", []))
|
||||||
else:
|
else:
|
||||||
for bundle in tcv20.as_pages(self.collection.get_objects, per_request=self.items_per_page):
|
for bundle in tcv20.as_pages(self.collection.get_objects, per_request=self.items_per_page, **taxii_filters_dict):
|
||||||
all_data.extend(bundle.get("objects", []))
|
all_data.extend(bundle.get("objects", []))
|
||||||
|
|
||||||
# deduplicate data (before filtering as reduces wasted filtering)
|
# deduplicate data (before filtering as reduces wasted filtering)
|
||||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
||||||
from requests.models import Response
|
from requests.models import Response
|
||||||
import six
|
import six
|
||||||
from taxii2client.common import _filter_kwargs_to_query_params
|
from taxii2client.common import _filter_kwargs_to_query_params
|
||||||
from taxii2client.v20 import Collection
|
from taxii2client.v20 import MEDIA_TYPE_STIX_V20, Collection
|
||||||
|
|
||||||
import stix2
|
import stix2
|
||||||
from stix2.datastore import DataSourceError
|
from stix2.datastore import DataSourceError
|
||||||
|
@ -35,12 +35,12 @@ class MockTAXIICollectionEndpoint(Collection):
|
||||||
{
|
{
|
||||||
"date_added": get_timestamp(),
|
"date_added": get_timestamp(),
|
||||||
"id": object["id"],
|
"id": object["id"],
|
||||||
"media_type": "application/stix+json;version=2.1",
|
"media_type": "application/stix+json;version=2.0",
|
||||||
"version": object.get("modified", object.get("created", get_timestamp())),
|
"version": object.get("modified", object.get("created", get_timestamp())),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_objects(self, **filter_kwargs):
|
def get_objects(self, accept=MEDIA_TYPE_STIX_V20, start=0, per_request=0, **filter_kwargs):
|
||||||
self._verify_can_read()
|
self._verify_can_read()
|
||||||
query_params = _filter_kwargs_to_query_params(filter_kwargs)
|
query_params = _filter_kwargs_to_query_params(filter_kwargs)
|
||||||
assert isinstance(query_params, dict)
|
assert isinstance(query_params, dict)
|
||||||
|
@ -52,7 +52,10 @@ class MockTAXIICollectionEndpoint(Collection):
|
||||||
100,
|
100,
|
||||||
)[0]
|
)[0]
|
||||||
if objs:
|
if objs:
|
||||||
return stix2.v20.Bundle(objects=objs)
|
resp = Response()
|
||||||
|
resp.encoding = "utf-8"
|
||||||
|
resp._content = six.ensure_binary(stix2.v20.Bundle(objects=objs).serialize(ensure_ascii=False))
|
||||||
|
return resp
|
||||||
else:
|
else:
|
||||||
resp = Response()
|
resp = Response()
|
||||||
resp.status_code = 404
|
resp.status_code = 404
|
||||||
|
|
Loading…
Reference in New Issue