Refactor granular_markings.py code and remove unnecessary code in utils.py

stix2.1
Emmanuelle Vargas-Gonzalez 2017-08-23 13:06:51 -04:00
parent 8687521111
commit 15bff530be
2 changed files with 69 additions and 84 deletions

View File

@ -53,8 +53,8 @@ def set_markings(obj, selectors, marking):
field(s) selected by `selectors`.
"""
clear_markings(obj, selectors)
add_markings(obj, selectors, marking)
obj = clear_markings(obj, selectors)
return add_markings(obj, selectors, marking)
def remove_markings(obj, selectors, marking):
@ -76,16 +76,22 @@ def remove_markings(obj, selectors, marking):
selectors = utils.fix_value(selectors)
utils.validate(obj, selectors, marking)
utils.expand_markings(obj)
granular_markings = obj.get("granular_markings")
if not granular_markings:
return
return obj
tlo = utils.build_granular_marking(
{"selectors": selectors, "marking_ref": marking}
)
granular_markings = utils.expand_markings(granular_markings)
if isinstance(marking, list):
to_remove = []
for m in marking:
to_remove.append({"marking_ref": m, "selectors": selectors})
else:
to_remove = [{"marking_ref": marking, "selectors": selectors}]
to_remove = utils.expand_markings(to_remove)
tlo = utils.build_granular_marking(to_remove)
remove = tlo.get("granular_markings", [])
@ -93,14 +99,16 @@ def remove_markings(obj, selectors, marking):
raise AssertionError("Unable to remove Granular Marking(s) from"
" internal collection. Marking(s) not found...")
obj["granular_markings"] = [
granular_markings = [
m for m in granular_markings if m not in remove
]
utils.compress_markings(obj)
granular_markings = utils.compress_markings(granular_markings)
if not obj.get("granular_markings"):
obj.pop("granular_markings")
if not granular_markings:
return obj.new_version(granular_markings=None)
else:
return obj.new_version(granular_markings=granular_markings)
def add_markings(obj, selectors, marking):
@ -109,10 +117,10 @@ def add_markings(obj, selectors, marking):
Args:
obj: A TLO object.
selectors: string or list of selectors strings relative to the TLO in
which the field(s) appear(s).
marking: identifier or list of marking identifiers that apply to the
field(s) selected by `selectors`.
selectors: list of type string, selectors must be relative to the TLO
in which the properties appear.
marking: identifier that apply to the properties selected by
`selectors`.
Raises:
AssertionError: If `selectors` or `marking` fail data validation.
@ -121,15 +129,19 @@ def add_markings(obj, selectors, marking):
selectors = utils.fix_value(selectors)
utils.validate(obj, selectors, marking)
granular_marking = {"selectors": sorted(selectors), "marking_ref": marking}
if isinstance(marking, list):
granular_marking = []
for m in marking:
granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)})
else:
granular_marking = [{"marking_ref": marking, "selectors": sorted(selectors)}]
if not obj.get("granular_markings"):
obj["granular_markings"] = list()
if obj.get("granular_markings"):
granular_marking.extend(obj.get("granular_markings"))
obj["granular_markings"].append(granular_marking)
utils.expand_markings(obj)
utils.compress_markings(obj)
granular_marking = utils.expand_markings(granular_marking)
granular_marking = utils.compress_markings(granular_marking)
return obj.new_version(granular_markings=granular_marking)
def clear_markings(obj, selectors):
@ -149,15 +161,15 @@ def clear_markings(obj, selectors):
selectors = utils.fix_value(selectors)
utils.validate(obj, selectors)
utils.expand_markings(obj)
granular_markings = obj.get("granular_markings")
if not granular_markings:
return
return obj
granular_markings = utils.expand_markings(granular_markings)
tlo = utils.build_granular_marking(
{"selectors": selectors, "marking_ref": ["N/A"]}
[{"selectors": selectors, "marking_ref": "N/A"}]
)
clear = tlo.get("granular_markings", [])
@ -176,12 +188,14 @@ def clear_markings(obj, selectors):
marking_refs = granular_marking.get("marking_ref")
if marking_refs:
granular_marking["marking_ref"] = list()
granular_marking["marking_ref"] = ""
utils.compress_markings(obj)
granular_markings = utils.compress_markings(granular_markings)
if not obj.get("granular_markings"):
obj.pop("granular_markings")
if not granular_markings:
return obj.new_version(granular_markings=None)
else:
return obj.new_version(granular_markings=granular_markings)
def is_marked(obj, selectors, marking=None, inherited=False, descendants=False):

View File

@ -67,81 +67,51 @@ def fix_value(data):
return data
def _fix_markings(markings):
def compress_markings(granular_markings):
for granular_marking in markings:
refs = granular_marking.get("marking_ref", [])
selectors = granular_marking.get("selectors", [])
if not isinstance(refs, list):
granular_marking["marking_ref"] = [refs]
if not isinstance(selectors, list):
granular_marking["selectors"] = [selectors]
def _group_by(markings):
key = "marking_ref"
retrieve = "selectors"
if not granular_markings:
return
map_ = collections.defaultdict(set)
for granular_marking in markings:
for data in granular_marking.get(key, []):
map_[data].update(granular_marking.get(retrieve))
for granular_marking in granular_markings:
if granular_marking.get("marking_ref"):
map_[granular_marking.get("marking_ref")].update(granular_marking.get("selectors"))
granular_markings = \
compressed = \
[
{"selectors": sorted(selectors), "marking_ref": ref}
for ref, selectors in six.iteritems(map_)
{"marking_ref": marking_ref, "selectors": sorted(selectors)}
for marking_ref, selectors in six.iteritems(map_)
]
return granular_markings
return compressed
def compress_markings(tlo):
def expand_markings(granular_markings):
if not tlo.get("granular_markings"):
if not granular_markings:
return
granular_markings = tlo.get("granular_markings")
_fix_markings(granular_markings)
tlo["granular_markings"] = _group_by(granular_markings)
def expand_markings(tlo):
if not tlo.get("granular_markings"):
return
granular_markings = tlo.get("granular_markings")
_fix_markings(granular_markings)
expanded = list()
expanded = []
for marking in granular_markings:
selectors = marking.get("selectors", [])
marking_ref = marking.get("marking_ref", [])
selectors = marking.get("selectors")
marking_ref = marking.get("marking_ref")
expanded.extend(
[
{"selectors": [sel], "marking_ref": ref}
for sel in selectors
for ref in marking_ref
{"marking_ref": marking_ref, "selectors": [selector]}
for selector in selectors
]
)
tlo["granular_markings"] = expanded
return expanded
def build_granular_marking(granular_marking):
tlo = {"granular_markings": [granular_marking]}
tlo = {"granular_markings": granular_marking}
expand_markings(tlo)
expand_markings(tlo["granular_markings"])
return tlo
@ -156,14 +126,15 @@ def iterpath(obj, path=None):
path: None, used recursively to store ancestors.
Example:
>>> for item in iterpath(tlo):
>>> for item in iterpath(obj):
>>> print(item)
(['type'], 'campaign')
...
(['cybox', 'objects', '[0]', 'hashes', 'sha1'], 'cac35ec206d868b7d7cb0b55f31d9425b075082b')
Returns:
tuple: Containing two items: a list of ancestors and the property value.
tuple: Containing two items: a list of ancestors and the
property value.
"""
if path is None:
@ -209,7 +180,7 @@ def get_selector(obj, prop):
location is for now the option to assert the data.
Example:
>>> selector = get_selector(tlo, tlo["cybox"]["objects"][0]["file_name"])
>>> selector = get_selector(obj, obj["cybox"]["objects"][0]["file_name"])
>>> print(selector)
["cybox.objects.[0].file_name"]