Support MarkingDefinition objs, not just ID strs

(when passed into `add_markings()` or `remove_markings()`)
stix2.1
Chris Lenk 2017-10-02 12:28:47 -04:00
parent e2151659d7
commit 035f1a6c73
5 changed files with 96 additions and 48 deletions

View File

@ -88,6 +88,7 @@ def remove_markings(obj, marking, selectors):
"""
selectors = utils.convert_to_list(selectors)
marking = utils.convert_to_marking_list(marking)
utils.validate(obj, selectors)
granular_markings = obj.get("granular_markings")
@ -97,12 +98,9 @@ def remove_markings(obj, marking, selectors):
granular_markings = utils.expand_markings(granular_markings)
if isinstance(marking, list):
to_remove = []
for m in marking:
to_remove.append({"marking_ref": m, "selectors": selectors})
else:
to_remove = [{"marking_ref": marking, "selectors": selectors}]
to_remove = []
for m in marking:
to_remove.append({"marking_ref": m, "selectors": selectors})
remove = utils.build_granular_marking(to_remove).get("granular_markings")
@ -140,14 +138,12 @@ def add_markings(obj, marking, selectors):
"""
selectors = utils.convert_to_list(selectors)
marking = utils.convert_to_marking_list(marking)
utils.validate(obj, selectors)
if isinstance(marking, list):
granular_marking = []
for m in marking:
granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)})
else:
granular_marking = [{"marking_ref": marking, "selectors": sorted(selectors)}]
granular_marking = []
for m in marking:
granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)})
if obj.get("granular_markings"):
granular_marking.extend(obj.get("granular_markings"))
@ -244,7 +240,7 @@ def is_marked(obj, marking=None, selectors=None, inherited=False, descendants=Fa
raise TypeError("Required argument 'selectors' must be provided")
selectors = utils.convert_to_list(selectors)
marking = utils.convert_to_list(marking)
marking = utils.convert_to_marking_list(marking)
utils.validate(obj, selectors)
granular_markings = obj.get("granular_markings", [])

View File

@ -31,7 +31,7 @@ def add_markings(obj, marking):
A new version of the given SDO or SRO with specified markings added.
"""
marking = utils.convert_to_list(marking)
marking = utils.convert_to_marking_list(marking)
object_markings = set(obj.get("object_marking_refs", []) + marking)
@ -55,7 +55,7 @@ def remove_markings(obj, marking):
A new version of the given SDO or SRO with specified markings removed.
"""
marking = utils.convert_to_list(marking)
marking = utils.convert_to_marking_list(marking)
object_markings = obj.get("object_marking_refs", [])
@ -121,7 +121,7 @@ def is_marked(obj, marking=None):
provided marking refs match, True is returned.
"""
marking = utils.convert_to_list(marking)
marking = utils.convert_to_marking_list(marking)
object_markings = obj.get("object_marking_refs", [])
if marking:

View File

@ -37,6 +37,12 @@ def _validate_selector(obj, selector):
return True
def _get_marking_id(marking):
if type(marking).__name__ is 'MarkingDefinition': # avoid circular import
return marking.id
return marking
def validate(obj, selectors):
"""Given an SDO or SRO, check that each selector is valid."""
if selectors:
@ -57,6 +63,15 @@ def convert_to_list(data):
return [data]
def convert_to_marking_list(data):
"""Convert input into a list of marking identifiers."""
if data is not None:
if isinstance(data, list):
return [_get_marking_id(x) for x in data]
else:
return [_get_marking_id(data)]
def compress_markings(granular_markings):
"""
Compress granular markings list. If there is more than one marking

View File

@ -1,7 +1,7 @@
import pytest
from stix2 import Malware, markings
from stix2 import TLP_RED, Malware, markings
from .constants import MALWARE_MORE_KWARGS as MALWARE_KWARGS_CONST
from .constants import MARKING_IDS
@ -45,6 +45,7 @@ def test_add_marking_mark_one_selector_multiple_refs():
},
],
**MALWARE_KWARGS),
MARKING_IDS[0],
),
(
MALWARE_KWARGS,
@ -56,13 +57,26 @@ def test_add_marking_mark_one_selector_multiple_refs():
},
],
**MALWARE_KWARGS),
MARKING_IDS[0],
),
(
Malware(**MALWARE_KWARGS),
Malware(
granular_markings=[
{
"selectors": ["description", "name"],
"marking_ref": TLP_RED.id,
},
],
**MALWARE_KWARGS),
TLP_RED,
),
])
def test_add_marking_mark_multiple_selector_one_refs(data):
before = data[0]
after = data[1]
before = markings.add_markings(before, [MARKING_IDS[0]], ["description", "name"])
before = markings.add_markings(before, data[2], ["description", "name"])
for m in before["granular_markings"]:
assert m in after["granular_markings"]
@ -347,36 +361,42 @@ def test_get_markings_positional_arguments_combinations(data):
assert set(markings.get_markings(data, "x.z.foo2", False, True)) == set(["10"])
@pytest.mark.parametrize("before", [
Malware(
granular_markings=[
{
"selectors": ["description"],
"marking_ref": MARKING_IDS[0]
},
{
"selectors": ["description"],
"marking_ref": MARKING_IDS[1]
},
],
**MALWARE_KWARGS
@pytest.mark.parametrize("data", [
(
Malware(
granular_markings=[
{
"selectors": ["description"],
"marking_ref": MARKING_IDS[0]
},
{
"selectors": ["description"],
"marking_ref": MARKING_IDS[1]
},
],
**MALWARE_KWARGS
),
[MARKING_IDS[0], MARKING_IDS[1]],
),
dict(
granular_markings=[
{
"selectors": ["description"],
"marking_ref": MARKING_IDS[0]
},
{
"selectors": ["description"],
"marking_ref": MARKING_IDS[1]
},
],
**MALWARE_KWARGS
(
dict(
granular_markings=[
{
"selectors": ["description"],
"marking_ref": MARKING_IDS[0]
},
{
"selectors": ["description"],
"marking_ref": MARKING_IDS[1]
},
],
**MALWARE_KWARGS
),
[MARKING_IDS[0], MARKING_IDS[1]],
),
])
def test_remove_marking_remove_one_selector_with_multiple_refs(before):
before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description"])
def test_remove_marking_remove_one_selector_with_multiple_refs(data):
before = markings.remove_markings(data[0], data[1], ["description"])
assert "granular_markings" not in before

View File

@ -1,7 +1,7 @@
import pytest
from stix2 import Malware, exceptions, markings
from stix2 import TLP_AMBER, Malware, exceptions, markings
from .constants import FAKE_TIME, MALWARE_ID, MARKING_IDS
from .constants import MALWARE_KWARGS as MALWARE_KWARGS_CONST
@ -21,18 +21,26 @@ MALWARE_KWARGS.update({
Malware(**MALWARE_KWARGS),
Malware(object_marking_refs=[MARKING_IDS[0]],
**MALWARE_KWARGS),
MARKING_IDS[0],
),
(
MALWARE_KWARGS,
dict(object_marking_refs=[MARKING_IDS[0]],
**MALWARE_KWARGS),
MARKING_IDS[0],
),
(
Malware(**MALWARE_KWARGS),
Malware(object_marking_refs=[TLP_AMBER.id],
**MALWARE_KWARGS),
TLP_AMBER,
),
])
def test_add_markings_one_marking(data):
before = data[0]
after = data[1]
before = markings.add_markings(before, MARKING_IDS[0], None)
before = markings.add_markings(before, data[2], None)
for m in before["object_marking_refs"]:
assert m in after["object_marking_refs"]
@ -280,19 +288,28 @@ def test_remove_markings_object_level(data):
**MALWARE_KWARGS),
Malware(object_marking_refs=[MARKING_IDS[1]],
**MALWARE_KWARGS),
[MARKING_IDS[0], MARKING_IDS[2]],
),
(
dict(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]],
**MALWARE_KWARGS),
dict(object_marking_refs=[MARKING_IDS[1]],
**MALWARE_KWARGS),
[MARKING_IDS[0], MARKING_IDS[2]],
),
(
Malware(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], TLP_AMBER.id],
**MALWARE_KWARGS),
Malware(object_marking_refs=[MARKING_IDS[1]],
**MALWARE_KWARGS),
[MARKING_IDS[0], TLP_AMBER],
),
])
def test_remove_markings_multiple(data):
before = data[0]
after = data[1]
before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[2]], None)
before = markings.remove_markings(before, data[2], None)
assert before['object_marking_refs'] == after['object_marking_refs']