Support MarkingDefinition objs, not just ID strs
(when passed into `add_markings()` or `remove_markings()`)stix2.1
parent
e2151659d7
commit
035f1a6c73
|
@ -88,6 +88,7 @@ def remove_markings(obj, marking, selectors):
|
|||
|
||||
"""
|
||||
selectors = utils.convert_to_list(selectors)
|
||||
marking = utils.convert_to_marking_list(marking)
|
||||
utils.validate(obj, selectors)
|
||||
|
||||
granular_markings = obj.get("granular_markings")
|
||||
|
@ -97,12 +98,9 @@ def remove_markings(obj, marking, selectors):
|
|||
|
||||
granular_markings = utils.expand_markings(granular_markings)
|
||||
|
||||
if isinstance(marking, list):
|
||||
to_remove = []
|
||||
for m in marking:
|
||||
to_remove.append({"marking_ref": m, "selectors": selectors})
|
||||
else:
|
||||
to_remove = [{"marking_ref": marking, "selectors": selectors}]
|
||||
to_remove = []
|
||||
for m in marking:
|
||||
to_remove.append({"marking_ref": m, "selectors": selectors})
|
||||
|
||||
remove = utils.build_granular_marking(to_remove).get("granular_markings")
|
||||
|
||||
|
@ -140,14 +138,12 @@ def add_markings(obj, marking, selectors):
|
|||
|
||||
"""
|
||||
selectors = utils.convert_to_list(selectors)
|
||||
marking = utils.convert_to_marking_list(marking)
|
||||
utils.validate(obj, selectors)
|
||||
|
||||
if isinstance(marking, list):
|
||||
granular_marking = []
|
||||
for m in marking:
|
||||
granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)})
|
||||
else:
|
||||
granular_marking = [{"marking_ref": marking, "selectors": sorted(selectors)}]
|
||||
granular_marking = []
|
||||
for m in marking:
|
||||
granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)})
|
||||
|
||||
if obj.get("granular_markings"):
|
||||
granular_marking.extend(obj.get("granular_markings"))
|
||||
|
@ -244,7 +240,7 @@ def is_marked(obj, marking=None, selectors=None, inherited=False, descendants=Fa
|
|||
raise TypeError("Required argument 'selectors' must be provided")
|
||||
|
||||
selectors = utils.convert_to_list(selectors)
|
||||
marking = utils.convert_to_list(marking)
|
||||
marking = utils.convert_to_marking_list(marking)
|
||||
utils.validate(obj, selectors)
|
||||
|
||||
granular_markings = obj.get("granular_markings", [])
|
||||
|
|
|
@ -31,7 +31,7 @@ def add_markings(obj, marking):
|
|||
A new version of the given SDO or SRO with specified markings added.
|
||||
|
||||
"""
|
||||
marking = utils.convert_to_list(marking)
|
||||
marking = utils.convert_to_marking_list(marking)
|
||||
|
||||
object_markings = set(obj.get("object_marking_refs", []) + marking)
|
||||
|
||||
|
@ -55,7 +55,7 @@ def remove_markings(obj, marking):
|
|||
A new version of the given SDO or SRO with specified markings removed.
|
||||
|
||||
"""
|
||||
marking = utils.convert_to_list(marking)
|
||||
marking = utils.convert_to_marking_list(marking)
|
||||
|
||||
object_markings = obj.get("object_marking_refs", [])
|
||||
|
||||
|
@ -121,7 +121,7 @@ def is_marked(obj, marking=None):
|
|||
provided marking refs match, True is returned.
|
||||
|
||||
"""
|
||||
marking = utils.convert_to_list(marking)
|
||||
marking = utils.convert_to_marking_list(marking)
|
||||
object_markings = obj.get("object_marking_refs", [])
|
||||
|
||||
if marking:
|
||||
|
|
|
@ -37,6 +37,12 @@ def _validate_selector(obj, selector):
|
|||
return True
|
||||
|
||||
|
||||
def _get_marking_id(marking):
|
||||
if type(marking).__name__ is 'MarkingDefinition': # avoid circular import
|
||||
return marking.id
|
||||
return marking
|
||||
|
||||
|
||||
def validate(obj, selectors):
|
||||
"""Given an SDO or SRO, check that each selector is valid."""
|
||||
if selectors:
|
||||
|
@ -57,6 +63,15 @@ def convert_to_list(data):
|
|||
return [data]
|
||||
|
||||
|
||||
def convert_to_marking_list(data):
|
||||
"""Convert input into a list of marking identifiers."""
|
||||
if data is not None:
|
||||
if isinstance(data, list):
|
||||
return [_get_marking_id(x) for x in data]
|
||||
else:
|
||||
return [_get_marking_id(data)]
|
||||
|
||||
|
||||
def compress_markings(granular_markings):
|
||||
"""
|
||||
Compress granular markings list. If there is more than one marking
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from stix2 import Malware, markings
|
||||
from stix2 import TLP_RED, Malware, markings
|
||||
|
||||
from .constants import MALWARE_MORE_KWARGS as MALWARE_KWARGS_CONST
|
||||
from .constants import MARKING_IDS
|
||||
|
@ -45,6 +45,7 @@ def test_add_marking_mark_one_selector_multiple_refs():
|
|||
},
|
||||
],
|
||||
**MALWARE_KWARGS),
|
||||
MARKING_IDS[0],
|
||||
),
|
||||
(
|
||||
MALWARE_KWARGS,
|
||||
|
@ -56,13 +57,26 @@ def test_add_marking_mark_one_selector_multiple_refs():
|
|||
},
|
||||
],
|
||||
**MALWARE_KWARGS),
|
||||
MARKING_IDS[0],
|
||||
),
|
||||
(
|
||||
Malware(**MALWARE_KWARGS),
|
||||
Malware(
|
||||
granular_markings=[
|
||||
{
|
||||
"selectors": ["description", "name"],
|
||||
"marking_ref": TLP_RED.id,
|
||||
},
|
||||
],
|
||||
**MALWARE_KWARGS),
|
||||
TLP_RED,
|
||||
),
|
||||
])
|
||||
def test_add_marking_mark_multiple_selector_one_refs(data):
|
||||
before = data[0]
|
||||
after = data[1]
|
||||
|
||||
before = markings.add_markings(before, [MARKING_IDS[0]], ["description", "name"])
|
||||
before = markings.add_markings(before, data[2], ["description", "name"])
|
||||
|
||||
for m in before["granular_markings"]:
|
||||
assert m in after["granular_markings"]
|
||||
|
@ -347,36 +361,42 @@ def test_get_markings_positional_arguments_combinations(data):
|
|||
assert set(markings.get_markings(data, "x.z.foo2", False, True)) == set(["10"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("before", [
|
||||
Malware(
|
||||
granular_markings=[
|
||||
{
|
||||
"selectors": ["description"],
|
||||
"marking_ref": MARKING_IDS[0]
|
||||
},
|
||||
{
|
||||
"selectors": ["description"],
|
||||
"marking_ref": MARKING_IDS[1]
|
||||
},
|
||||
],
|
||||
**MALWARE_KWARGS
|
||||
@pytest.mark.parametrize("data", [
|
||||
(
|
||||
Malware(
|
||||
granular_markings=[
|
||||
{
|
||||
"selectors": ["description"],
|
||||
"marking_ref": MARKING_IDS[0]
|
||||
},
|
||||
{
|
||||
"selectors": ["description"],
|
||||
"marking_ref": MARKING_IDS[1]
|
||||
},
|
||||
],
|
||||
**MALWARE_KWARGS
|
||||
),
|
||||
[MARKING_IDS[0], MARKING_IDS[1]],
|
||||
),
|
||||
dict(
|
||||
granular_markings=[
|
||||
{
|
||||
"selectors": ["description"],
|
||||
"marking_ref": MARKING_IDS[0]
|
||||
},
|
||||
{
|
||||
"selectors": ["description"],
|
||||
"marking_ref": MARKING_IDS[1]
|
||||
},
|
||||
],
|
||||
**MALWARE_KWARGS
|
||||
(
|
||||
dict(
|
||||
granular_markings=[
|
||||
{
|
||||
"selectors": ["description"],
|
||||
"marking_ref": MARKING_IDS[0]
|
||||
},
|
||||
{
|
||||
"selectors": ["description"],
|
||||
"marking_ref": MARKING_IDS[1]
|
||||
},
|
||||
],
|
||||
**MALWARE_KWARGS
|
||||
),
|
||||
[MARKING_IDS[0], MARKING_IDS[1]],
|
||||
),
|
||||
])
|
||||
def test_remove_marking_remove_one_selector_with_multiple_refs(before):
|
||||
before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[1]], ["description"])
|
||||
def test_remove_marking_remove_one_selector_with_multiple_refs(data):
|
||||
before = markings.remove_markings(data[0], data[1], ["description"])
|
||||
assert "granular_markings" not in before
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from stix2 import Malware, exceptions, markings
|
||||
from stix2 import TLP_AMBER, Malware, exceptions, markings
|
||||
|
||||
from .constants import FAKE_TIME, MALWARE_ID, MARKING_IDS
|
||||
from .constants import MALWARE_KWARGS as MALWARE_KWARGS_CONST
|
||||
|
@ -21,18 +21,26 @@ MALWARE_KWARGS.update({
|
|||
Malware(**MALWARE_KWARGS),
|
||||
Malware(object_marking_refs=[MARKING_IDS[0]],
|
||||
**MALWARE_KWARGS),
|
||||
MARKING_IDS[0],
|
||||
),
|
||||
(
|
||||
MALWARE_KWARGS,
|
||||
dict(object_marking_refs=[MARKING_IDS[0]],
|
||||
**MALWARE_KWARGS),
|
||||
MARKING_IDS[0],
|
||||
),
|
||||
(
|
||||
Malware(**MALWARE_KWARGS),
|
||||
Malware(object_marking_refs=[TLP_AMBER.id],
|
||||
**MALWARE_KWARGS),
|
||||
TLP_AMBER,
|
||||
),
|
||||
])
|
||||
def test_add_markings_one_marking(data):
|
||||
before = data[0]
|
||||
after = data[1]
|
||||
|
||||
before = markings.add_markings(before, MARKING_IDS[0], None)
|
||||
before = markings.add_markings(before, data[2], None)
|
||||
|
||||
for m in before["object_marking_refs"]:
|
||||
assert m in after["object_marking_refs"]
|
||||
|
@ -280,19 +288,28 @@ def test_remove_markings_object_level(data):
|
|||
**MALWARE_KWARGS),
|
||||
Malware(object_marking_refs=[MARKING_IDS[1]],
|
||||
**MALWARE_KWARGS),
|
||||
[MARKING_IDS[0], MARKING_IDS[2]],
|
||||
),
|
||||
(
|
||||
dict(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], MARKING_IDS[2]],
|
||||
**MALWARE_KWARGS),
|
||||
dict(object_marking_refs=[MARKING_IDS[1]],
|
||||
**MALWARE_KWARGS),
|
||||
[MARKING_IDS[0], MARKING_IDS[2]],
|
||||
),
|
||||
(
|
||||
Malware(object_marking_refs=[MARKING_IDS[0], MARKING_IDS[1], TLP_AMBER.id],
|
||||
**MALWARE_KWARGS),
|
||||
Malware(object_marking_refs=[MARKING_IDS[1]],
|
||||
**MALWARE_KWARGS),
|
||||
[MARKING_IDS[0], TLP_AMBER],
|
||||
),
|
||||
])
|
||||
def test_remove_markings_multiple(data):
|
||||
before = data[0]
|
||||
after = data[1]
|
||||
|
||||
before = markings.remove_markings(before, [MARKING_IDS[0], MARKING_IDS[2]], None)
|
||||
before = markings.remove_markings(before, data[2], None)
|
||||
|
||||
assert before['object_marking_refs'] == after['object_marking_refs']
|
||||
|
||||
|
|
Loading…
Reference in New Issue