From 035f1a6c7354a7a60f9a82eb9792378a37fda346 Mon Sep 17 00:00:00 2001 From: Chris Lenk Date: Mon, 2 Oct 2017 12:28:47 -0400 Subject: [PATCH] Support MarkingDefinition objs, not just ID strs (when passed into `add_markings()` or `remove_markings()`) --- stix2/markings/granular_markings.py | 22 ++++---- stix2/markings/object_markings.py | 6 +-- stix2/markings/utils.py | 15 ++++++ stix2/test/test_granular_markings.py | 78 +++++++++++++++++----------- stix2/test/test_object_markings.py | 23 ++++++-- 5 files changed, 96 insertions(+), 48 deletions(-) diff --git a/stix2/markings/granular_markings.py b/stix2/markings/granular_markings.py index 7e9ccc7..5afd1cc 100644 --- a/stix2/markings/granular_markings.py +++ b/stix2/markings/granular_markings.py @@ -88,6 +88,7 @@ def remove_markings(obj, marking, selectors): """ selectors = utils.convert_to_list(selectors) + marking = utils.convert_to_marking_list(marking) utils.validate(obj, selectors) granular_markings = obj.get("granular_markings") @@ -97,12 +98,9 @@ def remove_markings(obj, marking, selectors): granular_markings = utils.expand_markings(granular_markings) - if isinstance(marking, list): - to_remove = [] - for m in marking: - to_remove.append({"marking_ref": m, "selectors": selectors}) - else: - to_remove = [{"marking_ref": marking, "selectors": selectors}] + to_remove = [] + for m in marking: + to_remove.append({"marking_ref": m, "selectors": selectors}) remove = utils.build_granular_marking(to_remove).get("granular_markings") @@ -140,14 +138,12 @@ def add_markings(obj, marking, selectors): """ selectors = utils.convert_to_list(selectors) + marking = utils.convert_to_marking_list(marking) utils.validate(obj, selectors) - if isinstance(marking, list): - granular_marking = [] - for m in marking: - granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)}) - else: - granular_marking = [{"marking_ref": marking, "selectors": sorted(selectors)}] + granular_marking = [] + for m in marking: + granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)}) if obj.get("granular_markings"): granular_marking.extend(obj.get("granular_markings")) @@ -244,7 +240,7 @@ def is_marked(obj, marking=None, selectors=None, inherited=False, descendants=Fa raise TypeError("Required argument 'selectors' must be provided") selectors = utils.convert_to_list(selectors) - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) utils.validate(obj, selectors) granular_markings = obj.get("granular_markings", []) diff --git a/stix2/markings/object_markings.py b/stix2/markings/object_markings.py index c39c036..a775ddc 100644 --- a/stix2/markings/object_markings.py +++ b/stix2/markings/object_markings.py @@ -31,7 +31,7 @@ def add_markings(obj, marking): A new version of the given SDO or SRO with specified markings added. """ - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) object_markings = set(obj.get("object_marking_refs", []) + marking) @@ -55,7 +55,7 @@ def remove_markings(obj, marking): A new version of the given SDO or SRO with specified markings removed. """ - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) object_markings = obj.get("object_marking_refs", []) @@ -121,7 +121,7 @@ def is_marked(obj, marking=None): provided marking refs match, True is returned. """ - marking = utils.convert_to_list(marking) + marking = utils.convert_to_marking_list(marking) object_markings = obj.get("object_marking_refs", []) if marking: diff --git a/stix2/markings/utils.py b/stix2/markings/utils.py index d0d38bb..1154d19 100644 --- a/stix2/markings/utils.py +++ b/stix2/markings/utils.py @@ -37,6 +37,12 @@ def _validate_selector(obj, selector): return True +def _get_marking_id(marking): + if type(marking).__name__ is 'MarkingDefinition': # avoid circular import + return marking.id + return marking + + def validate(obj, selectors): """Given an SDO or SRO, check that each selector is valid.""" if selectors: @@ -57,6 +63,15 @@ def convert_to_list(data): return [data] +def convert_to_marking_list(data): + """Convert input into a list of marking identifiers.""" + if data is not None: + if isinstance(data, list): + return [_get_marking_id(x) for x in data] + else: + return [_get_marking_id(data)] + + def compress_markings(granular_markings): """ Compress granular markings list. If there is more than one marking diff --git a/stix2/test/test_granular_markings.py b/stix2/test/test_granular_markings.py index e910ad3..f8fc803 100644 --- a/stix2/test/test_granular_markings.py +++ b/stix2/test/test_granular_markings.py @@ -1,7 +1,7 @@ import pytest -from stix2 import Malware, markings +from stix2 import TLP_RED, Malware, markings from .constants import MALWARE_MORE_KWARGS as MALWARE_KWARGS_CONST from .constants import MARKING_IDS @@ -45,6 +45,7 @@ def test_add_marking_mark_one_selector_multiple_refs(): }, ], **MALWARE_KWARGS), + MARKING_IDS[0], ), ( MALWARE_KWARGS, @@ -56,13 +57,26 @@ def test_add_marking_mark_one_selector_multiple_refs(): }, ], **MALWARE_KWARGS), + MARKING_IDS[0], + ), + ( + Malware(**MALWARE_KWARGS), + Malware( + granular_markings=[ + { + "selectors": ["description", "name"], + "marking_ref": TLP_RED.id, + }, + ], + **MALWARE_KWARGS), + TLP_RED, ), ]) def test_add_marking_mark_multiple_selector_one_refs(data): before = data[0] after = data[1] - before = markings.add_markings(before, [MARKING_IDS[0]], ["description", "name"]) + before = markings.add_markings(before, data[2], ["description", "name"]) for m in before["granular_markings"]: assert m in after["granular_markings"] @@ -347,36 +361,42 @@ def test_get_markings_positional_arguments_combinations(data): assert set(markings.get_markings(data, "x.z.foo2", False, True)) == set(["10"]) -@pytest.mark.parametrize("before", [ - Malware( - granular_markings=[ - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[0] - }, - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[1] - }, - ], - **MALWARE_KWARGS +@pytest.mark.parametrize("data", [ + ( + Malware( + granular_markings=[ + { + "selectors": ["description"], + "marking_ref": MARKING_IDS[0] + }, + { + "selectors": ["description"], + "marking_ref": MARKING_IDS[1] + }, + ], + **MALWARE_KWARGS + ), + [MARKING_IDS[0], MARKING_IDS[1]], ), - dict( - granular_markings=[ - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[0] - }, - { - "selectors": ["description"], - "marking_ref": MARKING_IDS[1] - }, - ], - **MALWARE_KWARGS + ( + dict( + granular_markings=[ + { + "selectors": ["description"], + "marking_ref": MARKING_IDS[0] + }, + { + "selectors": ["description"], + "marking_ref": MARKING_IDS[1] + }, + ], + **MALWARE_KWARGS + ), + [MARKING_IDS[0], MARKING_IDS[1]], ), ]) -def test_remove_marking_remove_one_selector_with_multiple_refs(before): - before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description"]) +def test_remove_marking_remove_one_selector_with_multiple_refs(data): + before = markings.remove_markings(data[0], data[1], ["description"]) assert "granular_markings" not in before diff --git a/stix2/test/test_object_markings.py b/stix2/test/test_object_markings.py index 36e8e4d..10949ab 100644 --- a/stix2/test/test_object_markings.py +++ b/stix2/test/test_object_markings.py @@ -1,7 +1,7 @@ import pytest -from stix2 import Malware, exceptions, markings +from stix2 import TLP_AMBER, Malware, exceptions, markings from .constants import FAKE_TIME, MALWARE_ID, MARKING_IDS from .constants import MALWARE_KWARGS as MALWARE_KWARGS_CONST @@ -21,18 +21,26 @@ MALWARE_KWARGS.update({ Malware(**MALWARE_KWARGS), Malware(object_marking_refs=[MARKING_IDS[0]], **MALWARE_KWARGS), + MARKING_IDS[0], ), ( MALWARE_KWARGS, dict(object_marking_refs=[MARKING_IDS[0]], **MALWARE_KWARGS), + MARKING_IDS[0], + ), + ( + Malware(**MALWARE_KWARGS), + Malware(object_marking_refs=[TLP_AMBER.id], + **MALWARE_KWARGS), + TLP_AMBER, ), ]) def test_add_markings_one_marking(data): before = data[0] after = data[1] - before = markings.add_markings(before, MARKING_IDS[0], None) + before = markings.add_markings(before, data[2], None) for m in before["object_marking_refs"]: assert m in after["object_marking_refs"] @@ -280,19 +288,28 @@ def test_remove_markings_object_level(data): **MALWARE_KWARGS), Malware(object_marking_refs=[MARKING_IDS[1]], **MALWARE_KWARGS), + [MARKING_IDS[0], MARKING_IDS[2]], ), ( dict(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]], **MALWARE_KWARGS), dict(object_marking_refs=[MARKING_IDS[1]], **MALWARE_KWARGS), + [MARKING_IDS[0], MARKING_IDS[2]], + ), + ( + Malware(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], TLP_AMBER.id], + **MALWARE_KWARGS), + Malware(object_marking_refs=[MARKING_IDS[1]], + **MALWARE_KWARGS), + [MARKING_IDS[0], TLP_AMBER], ), ]) def test_remove_markings_multiple(data): before = data[0] after = data[1] - before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[2]], None) + before = markings.remove_markings(before, data[2], None) assert before['object_marking_refs'] == after['object_marking_refs']