diff --git a/stix2/datastore/__init__.py b/stix2/datastore/__init__.py index e482288..f02d773 100644 --- a/stix2/datastore/__init__.py +++ b/stix2/datastore/__init__.py @@ -313,7 +313,7 @@ class DataSource(with_metaclass(ABCMeta)): filter_list = [Filter('type', '=', obj_type)] if filters: if isinstance(filters, list): - filter_list += filters + filter_list.extend(filters) else: filter_list.append(filters) @@ -380,7 +380,7 @@ class DataSource(with_metaclass(ABCMeta)): return results - def related_to(self, obj, relationship_type=None, source_only=False, target_only=False): + def related_to(self, obj, relationship_type=None, source_only=False, target_only=False, filters=None): """Retrieve STIX Objects that have a Relationship involving the given STIX object. @@ -396,6 +396,8 @@ class DataSource(with_metaclass(ABCMeta)): object is the source_ref. Default: False. target_only (bool): Only examine Relationships for which this object is the target_ref. Default: False. + filters (list): list of additional filters the related objects must + match. Returns: list: The STIX objects related to the given STIX object. @@ -416,8 +418,16 @@ class DataSource(with_metaclass(ABCMeta)): ids.update((r.source_ref, r.target_ref)) ids.remove(obj_id) + # Assemble filters + filter_list = [] + if filters: + if isinstance(filters, list): + filter_list.extend(filters) + else: + filter_list.append(filters) + for i in ids: - results.append(self.get(i)) + results.extend(self.query(filter_list + [Filter('id', '=', i)])) return results diff --git a/stix2/test/test_workbench.py b/stix2/test/test_workbench.py index c25cf1a..a8edfbc 100644 --- a/stix2/test/test_workbench.py +++ b/stix2/test/test_workbench.py @@ -181,6 +181,19 @@ def test_workbench_related(): assert len(resp) == 1 +def test_workbench_related_with_filters(): + malware = Malware(labels=["ransomware"], name="CryptorBit", created_by_ref=IDENTITY_ID) + rel = stix2.Relationship(malware.id, 'variant-of', MALWARE_ID) + add([malware, rel]) + + filters = [stix2.Filter('created_by_ref', '=', IDENTITY_ID)] + resp = get(MALWARE_ID).related(filters=filters) + + assert len(resp) == 1 + assert resp[0].name == malware.name + assert resp[0].created_by_ref == IDENTITY_ID + + def test_add_data_source(): fs_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "stix2_data") fs = stix2.FileSystemSource(fs_path)