provide pagination support for requests in the TAXIICollectionSource

pull/1/head
Emmanuelle Vargas-Gonzalez 2020-12-21 17:53:53 -05:00
parent 0866df0546
commit ace64c4042
1 changed file with 19 additions and 2 deletions

View File

@ -12,6 +12,8 @@ from stix2.parsing import parse
from stix2.utils import deduplicate
try:
from taxii2client import v20 as tcv20
from taxii2client import v21 as tcv21
from taxii2client.exceptions import ValidationError
_taxii2_client = True
except ImportError:
@ -144,9 +146,12 @@ class TAXIICollectionSource(DataSource):
collection (taxii2.Collection): TAXII Collection instance
allow_custom (bool): Whether to allow custom STIX content in objects
retrieved from the TAXII Collection. Default: True
items_per_page (int): How many STIX objects to request per call
to the TAXII Server. This value is tunable, but servers may override
it if it exceeds their internal limit. Default: 5000
"""
def __init__(self, collection, allow_custom=True):
def __init__(self, collection, allow_custom=True, items_per_page=5000):
super(TAXIICollectionSource, self).__init__()
if not _taxii2_client:
raise ImportError("taxii2client library is required for usage of TAXIICollectionSource")
@ -167,6 +172,7 @@ class TAXIICollectionSource(DataSource):
)
self.allow_custom = allow_custom
self.items_per_page = items_per_page
def get(self, stix_id, version=None, _composite_filters=None):
"""Retrieve STIX object from local/remote STIX Collection
@ -286,8 +292,19 @@ class TAXIICollectionSource(DataSource):
taxii_filters_dict = dict((f.property, f.value) for f in taxii_filters)
# query TAXII collection
all_data = []
try:
all_data = self.collection.get_objects(**taxii_filters_dict).get('objects', [])
if isinstance(self.collection, tcv21.Collection):
envelope = self.collection.get_objects(**taxii_filters_dict)
all_data.extend(envelope.get("objects", []))
# The while loop will not be executed if the response is received in full.
while envelope.get("more", False):
envelope = self.collection.get_objects(limit=self.items_per_page, next=envelope.get("next", ""))
all_data.extend(envelope.get("objects", []))
else:
for bundle in tcv20.as_pages(self.collection.get_objects, per_request=self.items_per_page):
all_data.extend(bundle.get("objects", []))
# deduplicate data (before filtering as reduces wasted filtering)
all_data = deduplicate(all_data)