diff --git a/stix2/test/test_workbench.py b/stix2/test/test_workbench.py index 5e2809b..7f3c9fc 100644 --- a/stix2/test/test_workbench.py +++ b/stix2/test/test_workbench.py @@ -137,3 +137,34 @@ def test_workbench_get_all_vulnerabilities(): resp = vulnerabilities() assert len(resp) == 1 assert resp[0].id == VULNERABILITY_ID + + +def test_workbench_relationships(): + rel = stix2.Relationship(INDICATOR_ID, 'indicates', MALWARE_ID) + add(rel) + + ind = get(INDICATOR_ID) + resp = ind.relationships() + assert len(resp) == 1 + assert resp[0].relationship_type == 'indicates' + assert resp[0].source_ref == INDICATOR_ID + assert resp[0].target_ref == MALWARE_ID + + +def test_workbench_created_by(): + intset = stix2.IntrusionSet(name="Breach 123", created_by_ref=IDENTITY_ID) + add(intset) + creator = intset.created_by() + assert creator.id == IDENTITY_ID + + +def test_workbench_related(): + rel1 = stix2.Relationship(MALWARE_ID, 'targets', IDENTITY_ID) + rel2 = stix2.Relationship(CAMPAIGN_ID, 'uses', MALWARE_ID) + add([rel1, rel2]) + + resp = get(MALWARE_ID).related() + assert len(resp) == 3 + assert any(x['id'] == CAMPAIGN_ID for x in resp) + assert any(x['id'] == INDICATOR_ID for x in resp) + assert any(x['id'] == IDENTITY_ID for x in resp) diff --git a/stix2/workbench.py b/stix2/workbench.py index 55c8009..4069309 100644 --- a/stix2/workbench.py +++ b/stix2/workbench.py @@ -1,6 +1,9 @@ """Functions and class wrappers for interacting with STIX data at a high level. """ +from . import (AttackPattern, Campaign, CourseOfAction, CustomObject, Identity, + Indicator, IntrusionSet, Malware, ObservedData, Report, + ThreatActor, Tool, Vulnerability) from .environment import Environment from .sources.filters import Filter from .sources.memory import MemoryStore @@ -21,6 +24,31 @@ parse = _environ.parse add_data_source = _environ.source.add_data_source +# Wrap SDOs with helper functions + + +def created_by_wrapper(self, *args, **kwargs): + return _environ.creator_of(self, *args, **kwargs) + + +def relationships_wrapper(self, *args, **kwargs): + return _environ.relationships(self, *args, **kwargs) + + +def related_wrapper(self, *args, **kwargs): + return _environ.related_to(self, *args, **kwargs) + + +STIX_OBJS = [AttackPattern, Campaign, CourseOfAction, CustomObject, Identity, + Indicator, IntrusionSet, Malware, ObservedData, Report, + ThreatActor, Tool, Vulnerability] + +for obj_type in STIX_OBJS: + obj_type.created_by = created_by_wrapper + obj_type.relationships = relationships_wrapper + obj_type.related = related_wrapper + + # Functions to get all objects of a specific type