Allow passing add'l filters to related_to()

stix2.0
Chris Lenk 2018-03-19 15:56:20 -04:00
parent 61733ad899
commit 4fb24f14de
2 changed files with 26 additions and 3 deletions

View File

@ -313,7 +313,7 @@ class DataSource(with_metaclass(ABCMeta)):
filter_list = [Filter('type', '=', obj_type)]
if filters:
if isinstance(filters, list):
filter_list += filters
filter_list.extend(filters)
else:
filter_list.append(filters)
@ -380,7 +380,7 @@ class DataSource(with_metaclass(ABCMeta)):
return results
def related_to(self, obj, relationship_type=None, source_only=False, target_only=False):
def related_to(self, obj, relationship_type=None, source_only=False, target_only=False, filters=None):
"""Retrieve STIX Objects that have a Relationship involving the given
STIX object.
@ -396,6 +396,8 @@ class DataSource(with_metaclass(ABCMeta)):
object is the source_ref. Default: False.
target_only (bool): Only examine Relationships for which this
object is the target_ref. Default: False.
filters (list): list of additional filters the related objects must
match.
Returns:
list: The STIX objects related to the given STIX object.
@ -416,8 +418,16 @@ class DataSource(with_metaclass(ABCMeta)):
ids.update((r.source_ref, r.target_ref))
ids.remove(obj_id)
# Assemble filters
filter_list = []
if filters:
if isinstance(filters, list):
filter_list.extend(filters)
else:
filter_list.append(filters)
for i in ids:
results.append(self.get(i))
results.extend(self.query(filter_list + [Filter('id', '=', i)]))
return results

View File

@ -181,6 +181,19 @@ def test_workbench_related():
assert len(resp) == 1
def test_workbench_related_with_filters():
malware = Malware(labels=["ransomware"], name="CryptorBit", created_by_ref=IDENTITY_ID)
rel = stix2.Relationship(malware.id, 'variant-of', MALWARE_ID)
add([malware, rel])
filters = [stix2.Filter('created_by_ref', '=', IDENTITY_ID)]
resp = get(MALWARE_ID).related(filters=filters)
assert len(resp) == 1
assert resp[0].name == malware.name
assert resp[0].created_by_ref == IDENTITY_ID
def test_add_data_source():
fs_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "stix2_data")
fs = stix2.FileSystemSource(fs_path)