diff --git a/stix2/datastore/__init__.py b/stix2/datastore/__init__.py index 4e53b17..e0de6fe 100644 --- a/stix2/datastore/__init__.py +++ b/stix2/datastore/__init__.py @@ -359,7 +359,7 @@ class DataSource(with_metaclass(ABCMeta)): return results - def related_to(self, obj, relationship_type=None, source_only=False, target_only=False, filters=[]): + def related_to(self, obj, relationship_type=None, source_only=False, target_only=False, filters=None): """Retrieve STIX Objects that have a Relationship involving the given STIX object. diff --git a/stix2/datastore/filters.py b/stix2/datastore/filters.py index a116507..10bbeee 100644 --- a/stix2/datastore/filters.py +++ b/stix2/datastore/filters.py @@ -44,27 +44,35 @@ def _check_filter_components(prop, op, value): return True -def _assemble_filters(filter_arg, filters=[]): +def _assemble_filters(filters1=None, filters2=None): """Assemble a list of filters. This can be used to allow certain functions to work correctly no matter if the user provides a single filter or a list of them. Args: - filter_arg (Filter or list): The single Filter or list of Filters to be - coerced into a list of Filters. - filters (list, optional): A list of Filters to be automatically appended. + filters1 (Filter or list, optional): The single Filter or list of Filters to + coerce into a list of Filters. + filters2 (Filter or list, optional): The single Filter or list of Filters to + append to the list of Filters. Returns: List of Filters. """ - if isinstance(filter_arg, list): - filters.extend(filter_arg) + if filters1 is None: + filter_list = [] + elif not isinstance(filters1, list): + filter_list = [filters1] else: - filters.append(filter_arg) + filter_list = filters1 - return filters + if isinstance(filters2, list): + filter_list.extend(filters2) + elif filters2 is not None: + filter_list.append(filters2) + + return filter_list class Filter(collections.namedtuple("Filter", ['property', 'op', 'value'])): diff --git a/stix2/test/test_datastore.py b/stix2/test/test_datastore.py index e80e8d8..8f40401 100644 --- a/stix2/test/test_datastore.py +++ b/stix2/test/test_datastore.py @@ -4,7 +4,7 @@ from taxii2client import Collection from stix2 import Filter, MemorySink, MemorySource from stix2.datastore import (CompositeDataSource, DataSink, DataSource, make_id, taxii) -from stix2.datastore.filters import apply_common_filters +from stix2.datastore.filters import _assemble_filters, apply_common_filters from stix2.utils import deduplicate COLLECTION_URL = 'https://example.com/api1/collections/91a7b528-80eb-42ed-a74d-c6fbd5a26116/' @@ -473,6 +473,15 @@ def test_filters7(): assert len(resp) == 1 +def test_assemble_filters(): + filter1 = Filter("name", "=", "Malicious site hosting downloader") + filter2 = Filter("modified", ">", "2017-01-28T13:49:53.935Z") + result = _assemble_filters(filter1, filter2) + assert len(result) == 2 + assert result[0].property == 'name' + assert result[1].property == 'modified' + + def test_deduplicate(): unique = deduplicate(STIX_OBJS1) diff --git a/stix2/workbench.py b/stix2/workbench.py index 61e99af..9e31b50 100644 --- a/stix2/workbench.py +++ b/stix2/workbench.py @@ -148,7 +148,7 @@ _setup_workbench() # Functions to get all objects of a specific type -def attack_patterns(filters=[]): +def attack_patterns(filters=None): """Retrieve all Attack Pattern objects. Args: @@ -160,7 +160,7 @@ def attack_patterns(filters=[]): return query(filter_list) -def campaigns(filters=[]): +def campaigns(filters=None): """Retrieve all Campaign objects. Args: @@ -172,7 +172,7 @@ def campaigns(filters=[]): return query(filter_list) -def courses_of_action(filters=[]): +def courses_of_action(filters=None): """Retrieve all Course of Action objects. Args: @@ -184,7 +184,7 @@ def courses_of_action(filters=[]): return query(filter_list) -def identities(filters=[]): +def identities(filters=None): """Retrieve all Identity objects. Args: @@ -196,7 +196,7 @@ def identities(filters=[]): return query(filter_list) -def indicators(filters=[]): +def indicators(filters=None): """Retrieve all Indicator objects. Args: @@ -208,7 +208,7 @@ def indicators(filters=[]): return query(filter_list) -def intrusion_sets(filters=[]): +def intrusion_sets(filters=None): """Retrieve all Intrusion Set objects. Args: @@ -220,7 +220,7 @@ def intrusion_sets(filters=[]): return query(filter_list) -def malware(filters=[]): +def malware(filters=None): """Retrieve all Malware objects. Args: @@ -232,7 +232,7 @@ def malware(filters=[]): return query(filter_list) -def observed_data(filters=[]): +def observed_data(filters=None): """Retrieve all Observed Data objects. Args: @@ -244,7 +244,7 @@ def observed_data(filters=[]): return query(filter_list) -def reports(filters=[]): +def reports(filters=None): """Retrieve all Report objects. Args: @@ -256,7 +256,7 @@ def reports(filters=[]): return query(filter_list) -def threat_actors(filters=[]): +def threat_actors(filters=None): """Retrieve all Threat Actor objects. Args: @@ -268,7 +268,7 @@ def threat_actors(filters=[]): return query(filter_list) -def tools(filters=[]): +def tools(filters=None): """Retrieve all Tool objects. Args: @@ -280,7 +280,7 @@ def tools(filters=[]): return query(filter_list) -def vulnerabilities(filters=[]): +def vulnerabilities(filters=None): """Retrieve all Vulnerability objects. Args: