Refactor granular_markings.py code and remove unnecessary code in utils.py

stix2.1
Emmanuelle Vargas-Gonzalez 2017-08-23 13:06:51 -04:00
parent 8687521111
commit 15bff530be
2 changed files with 69 additions and 84 deletions

View File

@ -53,8 +53,8 @@ def set_markings(obj, selectors, marking):
field(s) selected by `selectors`. field(s) selected by `selectors`.
""" """
clear_markings(obj, selectors) obj = clear_markings(obj, selectors)
add_markings(obj, selectors, marking) return add_markings(obj, selectors, marking)
def remove_markings(obj, selectors, marking): def remove_markings(obj, selectors, marking):
@ -76,16 +76,22 @@ def remove_markings(obj, selectors, marking):
selectors = utils.fix_value(selectors) selectors = utils.fix_value(selectors)
utils.validate(obj, selectors, marking) utils.validate(obj, selectors, marking)
utils.expand_markings(obj)
granular_markings = obj.get("granular_markings") granular_markings = obj.get("granular_markings")
if not granular_markings: if not granular_markings:
return return obj
tlo = utils.build_granular_marking( granular_markings = utils.expand_markings(granular_markings)
{"selectors": selectors, "marking_ref": marking}
) if isinstance(marking, list):
to_remove = []
for m in marking:
to_remove.append({"marking_ref": m, "selectors": selectors})
else:
to_remove = [{"marking_ref": marking, "selectors": selectors}]
to_remove = utils.expand_markings(to_remove)
tlo = utils.build_granular_marking(to_remove)
remove = tlo.get("granular_markings", []) remove = tlo.get("granular_markings", [])
@ -93,14 +99,16 @@ def remove_markings(obj, selectors, marking):
raise AssertionError("Unable to remove Granular Marking(s) from" raise AssertionError("Unable to remove Granular Marking(s) from"
" internal collection. Marking(s) not found...") " internal collection. Marking(s) not found...")
obj["granular_markings"] = [ granular_markings = [
m for m in granular_markings if m not in remove m for m in granular_markings if m not in remove
] ]
utils.compress_markings(obj) granular_markings = utils.compress_markings(granular_markings)
if not obj.get("granular_markings"): if not granular_markings:
obj.pop("granular_markings") return obj.new_version(granular_markings=None)
else:
return obj.new_version(granular_markings=granular_markings)
def add_markings(obj, selectors, marking): def add_markings(obj, selectors, marking):
@ -109,10 +117,10 @@ def add_markings(obj, selectors, marking):
Args: Args:
obj: A TLO object. obj: A TLO object.
selectors: string or list of selectors strings relative to the TLO in selectors: list of type string, selectors must be relative to the TLO
which the field(s) appear(s). in which the properties appear.
marking: identifier or list of marking identifiers that apply to the marking: identifier that apply to the properties selected by
field(s) selected by `selectors`. `selectors`.
Raises: Raises:
AssertionError: If `selectors` or `marking` fail data validation. AssertionError: If `selectors` or `marking` fail data validation.
@ -121,15 +129,19 @@ def add_markings(obj, selectors, marking):
selectors = utils.fix_value(selectors) selectors = utils.fix_value(selectors)
utils.validate(obj, selectors, marking) utils.validate(obj, selectors, marking)
granular_marking = {"selectors": sorted(selectors), "marking_ref": marking} if isinstance(marking, list):
granular_marking = []
for m in marking:
granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)})
else:
granular_marking = [{"marking_ref": marking, "selectors": sorted(selectors)}]
if not obj.get("granular_markings"): if obj.get("granular_markings"):
obj["granular_markings"] = list() granular_marking.extend(obj.get("granular_markings"))
obj["granular_markings"].append(granular_marking) granular_marking = utils.expand_markings(granular_marking)
granular_marking = utils.compress_markings(granular_marking)
utils.expand_markings(obj) return obj.new_version(granular_markings=granular_marking)
utils.compress_markings(obj)
def clear_markings(obj, selectors): def clear_markings(obj, selectors):
@ -149,15 +161,15 @@ def clear_markings(obj, selectors):
selectors = utils.fix_value(selectors) selectors = utils.fix_value(selectors)
utils.validate(obj, selectors) utils.validate(obj, selectors)
utils.expand_markings(obj)
granular_markings = obj.get("granular_markings") granular_markings = obj.get("granular_markings")
if not granular_markings: if not granular_markings:
return return obj
granular_markings = utils.expand_markings(granular_markings)
tlo = utils.build_granular_marking( tlo = utils.build_granular_marking(
{"selectors": selectors, "marking_ref": ["N/A"]} [{"selectors": selectors, "marking_ref": "N/A"}]
) )
clear = tlo.get("granular_markings", []) clear = tlo.get("granular_markings", [])
@ -176,12 +188,14 @@ def clear_markings(obj, selectors):
marking_refs = granular_marking.get("marking_ref") marking_refs = granular_marking.get("marking_ref")
if marking_refs: if marking_refs:
granular_marking["marking_ref"] = list() granular_marking["marking_ref"] = ""
utils.compress_markings(obj) granular_markings = utils.compress_markings(granular_markings)
if not obj.get("granular_markings"): if not granular_markings:
obj.pop("granular_markings") return obj.new_version(granular_markings=None)
else:
return obj.new_version(granular_markings=granular_markings)
def is_marked(obj, selectors, marking=None, inherited=False, descendants=False): def is_marked(obj, selectors, marking=None, inherited=False, descendants=False):

View File

@ -67,81 +67,51 @@ def fix_value(data):
return data return data
def _fix_markings(markings): def compress_markings(granular_markings):
for granular_marking in markings: if not granular_markings:
refs = granular_marking.get("marking_ref", []) return
selectors = granular_marking.get("selectors", [])
if not isinstance(refs, list):
granular_marking["marking_ref"] = [refs]
if not isinstance(selectors, list):
granular_marking["selectors"] = [selectors]
def _group_by(markings):
key = "marking_ref"
retrieve = "selectors"
map_ = collections.defaultdict(set) map_ = collections.defaultdict(set)
for granular_marking in markings: for granular_marking in granular_markings:
for data in granular_marking.get(key, []): if granular_marking.get("marking_ref"):
map_[data].update(granular_marking.get(retrieve)) map_[granular_marking.get("marking_ref")].update(granular_marking.get("selectors"))
granular_markings = \ compressed = \
[ [
{"selectors": sorted(selectors), "marking_ref": ref} {"marking_ref": marking_ref, "selectors": sorted(selectors)}
for ref, selectors in six.iteritems(map_) for marking_ref, selectors in six.iteritems(map_)
] ]
return granular_markings return compressed
def compress_markings(tlo): def expand_markings(granular_markings):
if not tlo.get("granular_markings"): if not granular_markings:
return return
granular_markings = tlo.get("granular_markings") expanded = []
_fix_markings(granular_markings)
tlo["granular_markings"] = _group_by(granular_markings)
def expand_markings(tlo):
if not tlo.get("granular_markings"):
return
granular_markings = tlo.get("granular_markings")
_fix_markings(granular_markings)
expanded = list()
for marking in granular_markings: for marking in granular_markings:
selectors = marking.get("selectors", []) selectors = marking.get("selectors")
marking_ref = marking.get("marking_ref", []) marking_ref = marking.get("marking_ref")
expanded.extend( expanded.extend(
[ [
{"selectors": [sel], "marking_ref": ref} {"marking_ref": marking_ref, "selectors": [selector]}
for sel in selectors for selector in selectors
for ref in marking_ref
] ]
) )
tlo["granular_markings"] = expanded return expanded
def build_granular_marking(granular_marking): def build_granular_marking(granular_marking):
tlo = {"granular_markings": [granular_marking]} tlo = {"granular_markings": granular_marking}
expand_markings(tlo) expand_markings(tlo["granular_markings"])
return tlo return tlo
@ -156,14 +126,15 @@ def iterpath(obj, path=None):
path: None, used recursively to store ancestors. path: None, used recursively to store ancestors.
Example: Example:
>>> for item in iterpath(tlo): >>> for item in iterpath(obj):
>>> print(item) >>> print(item)
(['type'], 'campaign') (['type'], 'campaign')
... ...
(['cybox', 'objects', '[0]', 'hashes', 'sha1'], 'cac35ec206d868b7d7cb0b55f31d9425b075082b') (['cybox', 'objects', '[0]', 'hashes', 'sha1'], 'cac35ec206d868b7d7cb0b55f31d9425b075082b')
Returns: Returns:
tuple: Containing two items: a list of ancestors and the property value. tuple: Containing two items: a list of ancestors and the
property value.
""" """
if path is None: if path is None:
@ -209,7 +180,7 @@ def get_selector(obj, prop):
location is for now the option to assert the data. location is for now the option to assert the data.
Example: Example:
>>> selector = get_selector(tlo, tlo["cybox"]["objects"][0]["file_name"]) >>> selector = get_selector(obj, obj["cybox"]["objects"][0]["file_name"])
>>> print(selector) >>> print(selector)
["cybox.objects.[0].file_name"] ["cybox.objects.[0].file_name"]