Fix bug with mutable default parameter

stix2.0
Chris Lenk 2018-04-05 10:07:35 -04:00
parent 589c00064b
commit e3bbc39353
4 changed files with 39 additions and 22 deletions

View File

@ -359,7 +359,7 @@ class DataSource(with_metaclass(ABCMeta)):
return results return results
def related_to(self, obj, relationship_type=None, source_only=False, target_only=False, filters=[]): def related_to(self, obj, relationship_type=None, source_only=False, target_only=False, filters=None):
"""Retrieve STIX Objects that have a Relationship involving the given """Retrieve STIX Objects that have a Relationship involving the given
STIX object. STIX object.

View File

@ -44,27 +44,35 @@ def _check_filter_components(prop, op, value):
return True return True
def _assemble_filters(filter_arg, filters=[]): def _assemble_filters(filters1=None, filters2=None):
"""Assemble a list of filters. """Assemble a list of filters.
This can be used to allow certain functions to work correctly no matter if This can be used to allow certain functions to work correctly no matter if
the user provides a single filter or a list of them. the user provides a single filter or a list of them.
Args: Args:
filter_arg (Filter or list): The single Filter or list of Filters to be filters1 (Filter or list, optional): The single Filter or list of Filters to
coerced into a list of Filters. coerce into a list of Filters.
filters (list, optional): A list of Filters to be automatically appended. filters2 (Filter or list, optional): The single Filter or list of Filters to
append to the list of Filters.
Returns: Returns:
List of Filters. List of Filters.
""" """
if isinstance(filter_arg, list): if filters1 is None:
filters.extend(filter_arg) filter_list = []
elif not isinstance(filters1, list):
filter_list = [filters1]
else: else:
filters.append(filter_arg) filter_list = filters1
return filters if isinstance(filters2, list):
filter_list.extend(filters2)
elif filters2 is not None:
filter_list.append(filters2)
return filter_list
class Filter(collections.namedtuple("Filter", ['property', 'op', 'value'])): class Filter(collections.namedtuple("Filter", ['property', 'op', 'value'])):

View File

@ -4,7 +4,7 @@ from taxii2client import Collection
from stix2 import Filter, MemorySink, MemorySource from stix2 import Filter, MemorySink, MemorySource
from stix2.datastore import (CompositeDataSource, DataSink, DataSource, from stix2.datastore import (CompositeDataSource, DataSink, DataSource,
make_id, taxii) make_id, taxii)
from stix2.datastore.filters import apply_common_filters from stix2.datastore.filters import _assemble_filters, apply_common_filters
from stix2.utils import deduplicate from stix2.utils import deduplicate
COLLECTION_URL = 'https://example.com/api1/collections/91a7b528-80eb-42ed-a74d-c6fbd5a26116/' COLLECTION_URL = 'https://example.com/api1/collections/91a7b528-80eb-42ed-a74d-c6fbd5a26116/'
@ -473,6 +473,15 @@ def test_filters7():
assert len(resp) == 1 assert len(resp) == 1
def test_assemble_filters():
filter1 = Filter("name", "=", "Malicious site hosting downloader")
filter2 = Filter("modified", ">", "2017-01-28T13:49:53.935Z")
result = _assemble_filters(filter1, filter2)
assert len(result) == 2
assert result[0].property == 'name'
assert result[1].property == 'modified'
def test_deduplicate(): def test_deduplicate():
unique = deduplicate(STIX_OBJS1) unique = deduplicate(STIX_OBJS1)

View File

@ -148,7 +148,7 @@ _setup_workbench()
# Functions to get all objects of a specific type # Functions to get all objects of a specific type
def attack_patterns(filters=[]): def attack_patterns(filters=None):
"""Retrieve all Attack Pattern objects. """Retrieve all Attack Pattern objects.
Args: Args:
@ -160,7 +160,7 @@ def attack_patterns(filters=[]):
return query(filter_list) return query(filter_list)
def campaigns(filters=[]): def campaigns(filters=None):
"""Retrieve all Campaign objects. """Retrieve all Campaign objects.
Args: Args:
@ -172,7 +172,7 @@ def campaigns(filters=[]):
return query(filter_list) return query(filter_list)
def courses_of_action(filters=[]): def courses_of_action(filters=None):
"""Retrieve all Course of Action objects. """Retrieve all Course of Action objects.
Args: Args:
@ -184,7 +184,7 @@ def courses_of_action(filters=[]):
return query(filter_list) return query(filter_list)
def identities(filters=[]): def identities(filters=None):
"""Retrieve all Identity objects. """Retrieve all Identity objects.
Args: Args:
@ -196,7 +196,7 @@ def identities(filters=[]):
return query(filter_list) return query(filter_list)
def indicators(filters=[]): def indicators(filters=None):
"""Retrieve all Indicator objects. """Retrieve all Indicator objects.
Args: Args:
@ -208,7 +208,7 @@ def indicators(filters=[]):
return query(filter_list) return query(filter_list)
def intrusion_sets(filters=[]): def intrusion_sets(filters=None):
"""Retrieve all Intrusion Set objects. """Retrieve all Intrusion Set objects.
Args: Args:
@ -220,7 +220,7 @@ def intrusion_sets(filters=[]):
return query(filter_list) return query(filter_list)
def malware(filters=[]): def malware(filters=None):
"""Retrieve all Malware objects. """Retrieve all Malware objects.
Args: Args:
@ -232,7 +232,7 @@ def malware(filters=[]):
return query(filter_list) return query(filter_list)
def observed_data(filters=[]): def observed_data(filters=None):
"""Retrieve all Observed Data objects. """Retrieve all Observed Data objects.
Args: Args:
@ -244,7 +244,7 @@ def observed_data(filters=[]):
return query(filter_list) return query(filter_list)
def reports(filters=[]): def reports(filters=None):
"""Retrieve all Report objects. """Retrieve all Report objects.
Args: Args:
@ -256,7 +256,7 @@ def reports(filters=[]):
return query(filter_list) return query(filter_list)
def threat_actors(filters=[]): def threat_actors(filters=None):
"""Retrieve all Threat Actor objects. """Retrieve all Threat Actor objects.
Args: Args:
@ -268,7 +268,7 @@ def threat_actors(filters=[]):
return query(filter_list) return query(filter_list)
def tools(filters=[]): def tools(filters=None):
"""Retrieve all Tool objects. """Retrieve all Tool objects.
Args: Args:
@ -280,7 +280,7 @@ def tools(filters=[]):
return query(filter_list) return query(filter_list)
def vulnerabilities(filters=[]): def vulnerabilities(filters=None):
"""Retrieve all Vulnerability objects. """Retrieve all Vulnerability objects.
Args: Args: