Fix allow_custom in functions calling new_version

stix2.0
Chris Lenk 2018-03-02 11:32:07 -05:00
parent 1eab9b2832
commit 5a71ef2e64
4 changed files with 37 additions and 29 deletions

View File

@ -116,9 +116,9 @@ def remove_markings(obj, marking, selectors):
granular_markings = utils.compress_markings(granular_markings)
if granular_markings:
return new_version(obj, granular_markings=granular_markings)
return new_version(obj, granular_markings=granular_markings, allow_custom=True)
else:
return new_version(obj, granular_markings=None)
return new_version(obj, granular_markings=None, allow_custom=True)
def add_markings(obj, marking, selectors):
@ -152,7 +152,7 @@ def add_markings(obj, marking, selectors):
granular_marking = utils.expand_markings(granular_marking)
granular_marking = utils.compress_markings(granular_marking)
return new_version(obj, granular_markings=granular_marking)
return new_version(obj, granular_markings=granular_marking, allow_custom=True)
def clear_markings(obj, selectors):
@ -207,9 +207,9 @@ def clear_markings(obj, selectors):
granular_markings = utils.compress_markings(granular_markings)
if granular_markings:
return new_version(obj, granular_markings=granular_markings)
return new_version(obj, granular_markings=granular_markings, allow_custom=True)
else:
return new_version(obj, granular_markings=None)
return new_version(obj, granular_markings=None, allow_custom=True)
def is_marked(obj, marking=None, selectors=None, inherited=False, descendants=False):

View File

@ -69,9 +69,9 @@ def remove_markings(obj, marking):
new_markings = [x for x in object_markings if x not in marking]
if new_markings:
return new_version(obj, object_marking_refs=new_markings)
return new_version(obj, object_marking_refs=new_markings, allow_custom=True)
else:
return new_version(obj, object_marking_refs=None)
return new_version(obj, object_marking_refs=None, allow_custom=True)
def set_markings(obj, marking):
@ -103,7 +103,7 @@ def clear_markings(obj):
A new version of the given SDO or SRO with object_marking_refs cleared.
"""
return new_version(obj, object_marking_refs=None)
return new_version(obj, object_marking_refs=None, allow_custom=True)
def is_marked(obj, marking=None):

View File

@ -4,6 +4,13 @@ import stix2
from .constants import FAKE_TIME, MARKING_DEFINITION_ID
IDENTITY_CUSTOM_PROP = stix2.Identity(
name="John Smith",
identity_class="individual",
x_foo="bar",
allow_custom=True,
)
def test_identity_custom_property():
with pytest.raises(ValueError) as excinfo:
@ -82,35 +89,36 @@ def test_parse_identity_custom_property(data):
def test_custom_property_in_bundled_object():
identity = stix2.Identity(
name="John Smith",
identity_class="individual",
x_foo="bar",
allow_custom=True,
)
bundle = stix2.Bundle(identity, allow_custom=True)
bundle = stix2.Bundle(IDENTITY_CUSTOM_PROP, allow_custom=True)
assert bundle.objects[0].x_foo == "bar"
assert '"x_foo": "bar"' in str(bundle)
def test_identity_custom_property_add_marking():
identity = stix2.Identity(
id="identity--311b2d2d-f010-5473-83ec-1edf84858f4c",
created="2015-12-21T19:59:11Z",
modified="2015-12-21T19:59:11Z",
name="John Smith",
identity_class="individual",
x_foo="bar",
allow_custom=True,
)
marking_definition = stix2.MarkingDefinition(
def test_identity_custom_property_revoke():
identity = IDENTITY_CUSTOM_PROP.revoke()
assert identity.x_foo == "bar"
def test_identity_custom_property_edit_markings():
marking_obj = stix2.MarkingDefinition(
id=MARKING_DEFINITION_ID,
definition_type="statement",
definition=stix2.StatementMarking(statement="Copyright 2016, Example Corp")
)
identity2 = identity.add_markings(marking_definition)
assert identity2.x_foo == "bar"
marking_obj2 = stix2.MarkingDefinition(
id=MARKING_DEFINITION_ID,
definition_type="statement",
definition=stix2.StatementMarking(statement="Another one")
)
# None of the following should throw exceptions
identity = IDENTITY_CUSTOM_PROP.add_markings(marking_obj)
identity2 = identity.add_markings(marking_obj2, ['x_foo'])
identity2.remove_markings(marking_obj.id)
identity2.remove_markings(marking_obj2.id, ['x_foo'])
identity2.clear_markings()
identity2.clear_markings('x_foo')
def test_custom_marking_no_init_1():

View File

@ -251,7 +251,7 @@ def revoke(data):
if data.get("revoked"):
raise RevokeError("revoke")
return new_version(data, revoked=True)
return new_version(data, revoked=True, allow_custom=True)
def get_class_hierarchy_names(obj):