Fix bug with mutable default parameter

stix2.0
Chris Lenk 2018-04-05 10:07:35 -04:00
parent 589c00064b
commit e3bbc39353
4 changed files with 39 additions and 22 deletions

View File

@ -359,7 +359,7 @@ class DataSource(with_metaclass(ABCMeta)):
return results
def related_to(self, obj, relationship_type=None, source_only=False, target_only=False, filters=[]):
def related_to(self, obj, relationship_type=None, source_only=False, target_only=False, filters=None):
"""Retrieve STIX Objects that have a Relationship involving the given
STIX object.

View File

@ -44,27 +44,35 @@ def _check_filter_components(prop, op, value):
return True
def _assemble_filters(filter_arg, filters=[]):
def _assemble_filters(filters1=None, filters2=None):
"""Assemble a list of filters.
This can be used to allow certain functions to work correctly no matter if
the user provides a single filter or a list of them.
Args:
filter_arg (Filter or list): The single Filter or list of Filters to be
coerced into a list of Filters.
filters (list, optional): A list of Filters to be automatically appended.
filters1 (Filter or list, optional): The single Filter or list of Filters to
coerce into a list of Filters.
filters2 (Filter or list, optional): The single Filter or list of Filters to
append to the list of Filters.
Returns:
List of Filters.
"""
if isinstance(filter_arg, list):
filters.extend(filter_arg)
if filters1 is None:
filter_list = []
elif not isinstance(filters1, list):
filter_list = [filters1]
else:
filters.append(filter_arg)
filter_list = filters1
return filters
if isinstance(filters2, list):
filter_list.extend(filters2)
elif filters2 is not None:
filter_list.append(filters2)
return filter_list
class Filter(collections.namedtuple("Filter", ['property', 'op', 'value'])):

View File

@ -4,7 +4,7 @@ from taxii2client import Collection
from stix2 import Filter, MemorySink, MemorySource
from stix2.datastore import (CompositeDataSource, DataSink, DataSource,
make_id, taxii)
from stix2.datastore.filters import apply_common_filters
from stix2.datastore.filters import _assemble_filters, apply_common_filters
from stix2.utils import deduplicate
COLLECTION_URL = 'https://example.com/api1/collections/91a7b528-80eb-42ed-a74d-c6fbd5a26116/'
@ -473,6 +473,15 @@ def test_filters7():
assert len(resp) == 1
def test_assemble_filters():
filter1 = Filter("name", "=", "Malicious site hosting downloader")
filter2 = Filter("modified", ">", "2017-01-28T13:49:53.935Z")
result = _assemble_filters(filter1, filter2)
assert len(result) == 2
assert result[0].property == 'name'
assert result[1].property == 'modified'
def test_deduplicate():
unique = deduplicate(STIX_OBJS1)

View File

@ -148,7 +148,7 @@ _setup_workbench()
# Functions to get all objects of a specific type
def attack_patterns(filters=[]):
def attack_patterns(filters=None):
"""Retrieve all Attack Pattern objects.
Args:
@ -160,7 +160,7 @@ def attack_patterns(filters=[]):
return query(filter_list)
def campaigns(filters=[]):
def campaigns(filters=None):
"""Retrieve all Campaign objects.
Args:
@ -172,7 +172,7 @@ def campaigns(filters=[]):
return query(filter_list)
def courses_of_action(filters=[]):
def courses_of_action(filters=None):
"""Retrieve all Course of Action objects.
Args:
@ -184,7 +184,7 @@ def courses_of_action(filters=[]):
return query(filter_list)
def identities(filters=[]):
def identities(filters=None):
"""Retrieve all Identity objects.
Args:
@ -196,7 +196,7 @@ def identities(filters=[]):
return query(filter_list)
def indicators(filters=[]):
def indicators(filters=None):
"""Retrieve all Indicator objects.
Args:
@ -208,7 +208,7 @@ def indicators(filters=[]):
return query(filter_list)
def intrusion_sets(filters=[]):
def intrusion_sets(filters=None):
"""Retrieve all Intrusion Set objects.
Args:
@ -220,7 +220,7 @@ def intrusion_sets(filters=[]):
return query(filter_list)
def malware(filters=[]):
def malware(filters=None):
"""Retrieve all Malware objects.
Args:
@ -232,7 +232,7 @@ def malware(filters=[]):
return query(filter_list)
def observed_data(filters=[]):
def observed_data(filters=None):
"""Retrieve all Observed Data objects.
Args:
@ -244,7 +244,7 @@ def observed_data(filters=[]):
return query(filter_list)
def reports(filters=[]):
def reports(filters=None):
"""Retrieve all Report objects.
Args:
@ -256,7 +256,7 @@ def reports(filters=[]):
return query(filter_list)
def threat_actors(filters=[]):
def threat_actors(filters=None):
"""Retrieve all Threat Actor objects.
Args:
@ -268,7 +268,7 @@ def threat_actors(filters=[]):
return query(filter_list)
def tools(filters=[]):
def tools(filters=None):
"""Retrieve all Tool objects.
Args:
@ -280,7 +280,7 @@ def tools(filters=[]):
return query(filter_list)
def vulnerabilities(filters=[]):
def vulnerabilities(filters=None):
"""Retrieve all Vulnerability objects.
Args: