diff --git a/stix2/__init__.py b/stix2/__init__.py index 770f8cb..55911a4 100644 --- a/stix2/__init__.py +++ b/stix2/__init__.py @@ -25,6 +25,8 @@ from .common import (TLP_AMBER, TLP_GREEN, TLP_RED, TLP_WHITE, CustomMarking, MarkingDefinition, StatementMarking, TLPMarking) from .core import Bundle, _register_type, parse from .environment import Environment, ObjectFactory +from .markings import (add_markings, clear_markings, get_markings, is_marked, + remove_markings, set_markings) from .observables import (URL, AlternateDataStream, ArchiveExt, Artifact, AutonomousSystem, CustomExtension, CustomObservable, Directory, DomainName, EmailAddress, EmailMessage, diff --git a/stix2/common.py b/stix2/common.py index bd177fd..602e43f 100644 --- a/stix2/common.py +++ b/stix2/common.py @@ -3,6 +3,7 @@ from collections import OrderedDict from .base import _STIXBase +from .markings import MarkingsMixin from .properties import (HashesProperty, IDProperty, ListProperty, Property, ReferenceProperty, SelectorProperty, StringProperty, TimestampProperty, TypeProperty) @@ -76,7 +77,7 @@ class MarkingProperty(Property): raise ValueError("must be a Statement, TLP Marking or a registered marking.") -class MarkingDefinition(_STIXBase): +class MarkingDefinition(_STIXBase, MarkingsMixin): _type = 'marking-definition' _properties = OrderedDict() _properties.update([ diff --git a/stix2/markings/__init__.py b/stix2/markings/__init__.py index 9194ee1..69f3ace 100644 --- a/stix2/markings/__init__.py +++ b/stix2/markings/__init__.py @@ -224,3 +224,16 @@ def is_marked(obj, marking=None, selectors=None, inherited=False, descendants=Fa result = result or object_markings.is_marked(obj, object_marks) return result + + +class MarkingsMixin(): + pass + + +# Note that all of these methods will return a new object because of immutability +MarkingsMixin.get_markings = get_markings +MarkingsMixin.set_markings = set_markings +MarkingsMixin.remove_markings = remove_markings +MarkingsMixin.add_markings = add_markings +MarkingsMixin.clear_markings = clear_markings +MarkingsMixin.is_marked = is_marked diff --git a/stix2/markings/granular_markings.py b/stix2/markings/granular_markings.py index 893428e..3fe3a48 100644 --- a/stix2/markings/granular_markings.py +++ b/stix2/markings/granular_markings.py @@ -90,6 +90,7 @@ def remove_markings(obj, marking, selectors): """ selectors = utils.convert_to_list(selectors) + marking = utils.convert_to_marking_list(marking) utils.validate(obj, selectors) granular_markings = obj.get("granular_markings") @@ -99,12 +100,9 @@ def remove_markings(obj, marking, selectors): granular_markings = utils.expand_markings(granular_markings) - if isinstance(marking, list): - to_remove = [] - for m in marking: - to_remove.append({"marking_ref": m, "selectors": selectors}) - else: - to_remove = [{"marking_ref": marking, "selectors": selectors}] + to_remove = [] + for m in marking: + to_remove.append({"marking_ref": m, "selectors": selectors}) remove = utils.build_granular_marking(to_remove).get("granular_markings") @@ -142,14 +140,12 @@ def add_markings(obj, marking, selectors): """ selectors = utils.convert_to_list(selectors) + marking = utils.convert_to_marking_list(marking) utils.validate(obj, selectors) - if isinstance(marking, list): - granular_marking = [] - for m in marking: - granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)}) - else: - granular_marking = [{"marking_ref": marking, "selectors": sorted(selectors)}] + granular_marking = [] + for m in marking: + granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)}) if obj.get("granular_markings"): granular_marking.extend(obj.get("granular_markings")) @@ -246,7 +242,7 @@ def is_marked(obj, marking=None, selectors=None, inherited=False, descendants=Fa raise TypeError("Required argument 'selectors' must be provided") selectors = utils.convert_to_list(selectors) - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) utils.validate(obj, selectors) granular_markings = obj.get("granular_markings", []) diff --git a/stix2/markings/object_markings.py b/stix2/markings/object_markings.py index ce77966..cb78294 100644 --- a/stix2/markings/object_markings.py +++ b/stix2/markings/object_markings.py @@ -33,7 +33,7 @@ def add_markings(obj, marking): A new version of the given SDO or SRO with specified markings added. """ - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) object_markings = set(obj.get("object_marking_refs", []) + marking) @@ -57,7 +57,7 @@ def remove_markings(obj, marking): A new version of the given SDO or SRO with specified markings removed. """ - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) object_markings = obj.get("object_marking_refs", []) @@ -123,7 +123,7 @@ def is_marked(obj, marking=None): provided marking refs match, True is returned. """ - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) object_markings = obj.get("object_marking_refs", []) if marking: diff --git a/stix2/markings/utils.py b/stix2/markings/utils.py index 6fed976..3024fe8 100644 --- a/stix2/markings/utils.py +++ b/stix2/markings/utils.py @@ -39,6 +39,12 @@ def _validate_selector(obj, selector): return True +def _get_marking_id(marking): + if type(marking).__name__ is 'MarkingDefinition': # avoid circular import + return marking.id + return marking + + def validate(obj, selectors): """Given an SDO or SRO, check that each selector is valid.""" if selectors: @@ -59,6 +65,15 @@ def convert_to_list(data): return [data] +def convert_to_marking_list(data): + """Convert input into a list of marking identifiers.""" + if data is not None: + if isinstance(data, list): + return [_get_marking_id(x) for x in data] + else: + return [_get_marking_id(data)] + + def compress_markings(granular_markings): """Compress granular markings list. diff --git a/stix2/sdo.py b/stix2/sdo.py index cdbb9ca..c3eb619 100644 --- a/stix2/sdo.py +++ b/stix2/sdo.py @@ -6,6 +6,7 @@ import stix2 from .base import _STIXBase from .common import ExternalReference, GranularMarking, KillChainPhase +from .markings import MarkingsMixin from .observables import ObservableProperty from .properties import (BooleanProperty, IDProperty, IntegerProperty, ListProperty, PatternProperty, ReferenceProperty, @@ -13,7 +14,11 @@ from .properties import (BooleanProperty, IDProperty, IntegerProperty, from .utils import NOW -class AttackPattern(_STIXBase): +class STIXDomainObject(_STIXBase, MarkingsMixin): + pass + + +class AttackPattern(STIXDomainObject): _type = 'attack-pattern' _properties = OrderedDict() @@ -34,7 +39,7 @@ class AttackPattern(_STIXBase): ]) -class Campaign(_STIXBase): +class Campaign(STIXDomainObject): _type = 'campaign' _properties = OrderedDict() @@ -58,7 +63,7 @@ class Campaign(_STIXBase): ]) -class CourseOfAction(_STIXBase): +class CourseOfAction(STIXDomainObject): _type = 'course-of-action' _properties = OrderedDict() @@ -78,7 +83,7 @@ class CourseOfAction(_STIXBase): ]) -class Identity(_STIXBase): +class Identity(STIXDomainObject): _type = 'identity' _properties = OrderedDict() @@ -101,7 +106,7 @@ class Identity(_STIXBase): ]) -class Indicator(_STIXBase): +class Indicator(STIXDomainObject): _type = 'indicator' _properties = OrderedDict() @@ -125,7 +130,7 @@ class Indicator(_STIXBase): ]) -class IntrusionSet(_STIXBase): +class IntrusionSet(STIXDomainObject): _type = 'intrusion-set' _properties = OrderedDict() @@ -152,7 +157,7 @@ class IntrusionSet(_STIXBase): ]) -class Malware(_STIXBase): +class Malware(STIXDomainObject): _type = 'malware' _properties = OrderedDict() @@ -173,7 +178,7 @@ class Malware(_STIXBase): ]) -class ObservedData(_STIXBase): +class ObservedData(STIXDomainObject): _type = 'observed-data' _properties = OrderedDict() @@ -195,7 +200,7 @@ class ObservedData(_STIXBase): ]) -class Report(_STIXBase): +class Report(STIXDomainObject): _type = 'report' _properties = OrderedDict() @@ -217,7 +222,7 @@ class Report(_STIXBase): ]) -class ThreatActor(_STIXBase): +class ThreatActor(STIXDomainObject): _type = 'threat-actor' _properties = OrderedDict() @@ -245,7 +250,7 @@ class ThreatActor(_STIXBase): ]) -class Tool(_STIXBase): +class Tool(STIXDomainObject): _type = 'tool' _properties = OrderedDict() @@ -267,7 +272,7 @@ class Tool(_STIXBase): ]) -class Vulnerability(_STIXBase): +class Vulnerability(STIXDomainObject): _type = 'vulnerability' _properties = OrderedDict() @@ -314,7 +319,7 @@ def CustomObject(type='x-custom-type', properties=None): def custom_builder(cls): - class _Custom(cls, _STIXBase): + class _Custom(cls, STIXDomainObject): _type = type _properties = OrderedDict() _properties.update([ diff --git a/stix2/sro.py b/stix2/sro.py index af483bc..4fa0465 100644 --- a/stix2/sro.py +++ b/stix2/sro.py @@ -4,13 +4,18 @@ from collections import OrderedDict from .base import _STIXBase from .common import ExternalReference, GranularMarking +from .markings import MarkingsMixin from .properties import (BooleanProperty, IDProperty, IntegerProperty, ListProperty, ReferenceProperty, StringProperty, TimestampProperty, TypeProperty) from .utils import NOW -class Relationship(_STIXBase): +class STIXRelationshipObject(_STIXBase, MarkingsMixin): + pass + + +class Relationship(STIXRelationshipObject): _type = 'relationship' _properties = OrderedDict() @@ -45,7 +50,7 @@ class Relationship(_STIXBase): super(Relationship, self).__init__(**kwargs) -class Sighting(_STIXBase): +class Sighting(STIXRelationshipObject): _type = 'sighting' _properties = OrderedDict() _properties.update([ diff --git a/stix2/test/test_granular_markings.py b/stix2/test/test_granular_markings.py index e910ad3..f8fc803 100644 --- a/stix2/test/test_granular_markings.py +++ b/stix2/test/test_granular_markings.py @@ -1,7 +1,7 @@ import pytest -from stix2 import Malware, markings +from stix2 import TLP_RED, Malware, markings from .constants import MALWARE_MORE_KWARGS as MALWARE_KWARGS_CONST from .constants import MARKING_IDS @@ -45,6 +45,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): }, ], **MALWARE_KWARGS), + MARKING_IDS[0], ), ( MALWARE_KWARGS, @@ -56,13 +57,26 @@ def test_add_marking_mark_one_selector_multiple_refs(): }, ], **MALWARE_KWARGS), + MARKING_IDS[0], + ), + ( + Malware(**MALWARE_KWARGS), + Malware( + granular_markings=[ + { + "selectors": ["description", "name"], + "marking_ref": TLP_RED.id, + }, + ], + **MALWARE_KWARGS), + TLP_RED, ), ]) def test_add_marking_mark_multiple_selector_one_refs(data): before = data[0] after = data[1] - before = markings.add_markings(before, [MARKING_IDS[0]], ["description", "name"]) + before = markings.add_markings(before, data[2], ["description", "name"]) for m in before["granular_markings"]: assert m in after["granular_markings"] @@ -347,36 +361,42 @@ def test_get_markings_positional_arguments_combinations(data): assert set(markings.get_markings(data, "x.z.foo2", False, True)) == set(["10"]) -@pytest.mark.parametrize("before", [ - Malware( - granular_markings=[ - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[0] - }, - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[1] - }, - ], - **MALWARE_KWARGS +@pytest.mark.parametrize("data", [ + ( + Malware( + granular_markings=[ + { + "selectors": ["description"], + "marking_ref": MARKING_IDS[0] + }, + { + "selectors": ["description"], + "marking_ref": MARKING_IDS[1] + }, + ], + **MALWARE_KWARGS + ), + [MARKING_IDS[0], MARKING_IDS[1]], ), - dict( - granular_markings=[ - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[0] - }, - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[1] - }, - ], - **MALWARE_KWARGS + ( + dict( + granular_markings=[ + { + "selectors": ["description"], + "marking_ref": MARKING_IDS[0] + }, + { + "selectors": ["description"], + "marking_ref": MARKING_IDS[1] + }, + ], + **MALWARE_KWARGS + ), + [MARKING_IDS[0], MARKING_IDS[1]], ), ]) -def test_remove_marking_remove_one_selector_with_multiple_refs(before): - before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description"]) +def test_remove_marking_remove_one_selector_with_multiple_refs(data): + before = markings.remove_markings(data[0], data[1], ["description"]) assert "granular_markings" not in before diff --git a/stix2/test/test_markings.py b/stix2/test/test_markings.py index 0c6069a..456bf92 100644 --- a/stix2/test/test_markings.py +++ b/stix2/test/test_markings.py @@ -241,4 +241,14 @@ def test_marking_wrong_type_construction(): assert str(excinfo.value) == "Must supply a list, containing tuples. For example, [('property1', IntegerProperty())]" -# TODO: Add other examples +def test_campaign_add_markings(): + campaign = stix2.Campaign( + id="campaign--8e2e2d2b-17d4-4cbf-938f-98ee46b3cd3f", + created_by_ref="identity--f431f809-377b-45e0-aa1c-6a4751cae5ff", + created="2016-04-06T20:03:00Z", + modified="2016-04-06T20:03:00Z", + name="Green Group Attacks Against Finance", + description="Campaign by Green Group against a series of targets in the financial services sector.", + ) + campaign = campaign.add_markings(TLP_WHITE) + assert campaign.object_marking_refs[0] == TLP_WHITE.id diff --git a/stix2/test/test_object_markings.py b/stix2/test/test_object_markings.py index 36e8e4d..10949ab 100644 --- a/stix2/test/test_object_markings.py +++ b/stix2/test/test_object_markings.py @@ -1,7 +1,7 @@ import pytest -from stix2 import Malware, exceptions, markings +from stix2 import TLP_AMBER, Malware, exceptions, markings from .constants import FAKE_TIME, MALWARE_ID, MARKING_IDS from .constants import MALWARE_KWARGS as MALWARE_KWARGS_CONST @@ -21,18 +21,26 @@ MALWARE_KWARGS.update({ Malware(**MALWARE_KWARGS), Malware(object_marking_refs=[MARKING_IDS[0]], **MALWARE_KWARGS), + MARKING_IDS[0], ), ( MALWARE_KWARGS, dict(object_marking_refs=[MARKING_IDS[0]], **MALWARE_KWARGS), + MARKING_IDS[0], + ), + ( + Malware(**MALWARE_KWARGS), + Malware(object_marking_refs=[TLP_AMBER.id], + **MALWARE_KWARGS), + TLP_AMBER, ), ]) def test_add_markings_one_marking(data): before = data[0] after = data[1] - before = markings.add_markings(before, MARKING_IDS[0], None) + before = markings.add_markings(before, data[2], None) for m in before["object_marking_refs"]: assert m in after["object_marking_refs"] @@ -280,19 +288,28 @@ def test_remove_markings_object_level(data): **MALWARE_KWARGS), Malware(object_marking_refs=[MARKING_IDS[1]], **MALWARE_KWARGS), + [MARKING_IDS[0], MARKING_IDS[2]], ), ( dict(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], **MALWARE_KWARGS), dict(object_marking_refs=[MARKING_IDS[1]], **MALWARE_KWARGS), + [MARKING_IDS[0], MARKING_IDS[2]], + ), + ( + Malware(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], TLP_AMBER.id], + **MALWARE_KWARGS), + Malware(object_marking_refs=[MARKING_IDS[1]], + **MALWARE_KWARGS), + [MARKING_IDS[0], TLP_AMBER], ), ]) def test_remove_markings_multiple(data): before = data[0] after = data[1] - before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[2]], None) + before = markings.remove_markings(before, data[2], None) assert before['object_marking_refs'] == after['object_marking_refs']