From 15bff530beadd05b517d4778a832fb3b31d2452c Mon Sep 17 00:00:00 2001 From: Emmanuelle Vargas-Gonzalez Date: Wed, 23 Aug 2017 13:06:51 -0400 Subject: [PATCH] Refactor granular_markings.py code and remove unnecessary code in utils.py --- stix2/markings/granular_markings.py | 76 ++++++++++++++++------------ stix2/markings/utils.py | 77 +++++++++-------------------- 2 files changed, 69 insertions(+), 84 deletions(-) diff --git a/stix2/markings/granular_markings.py b/stix2/markings/granular_markings.py index 903aae9..5c223a3 100644 --- a/stix2/markings/granular_markings.py +++ b/stix2/markings/granular_markings.py @@ -53,8 +53,8 @@ def set_markings(obj, selectors, marking): field(s) selected by `selectors`. """ - clear_markings(obj, selectors) - add_markings(obj, selectors, marking) + obj = clear_markings(obj, selectors) + return add_markings(obj, selectors, marking) def remove_markings(obj, selectors, marking): @@ -76,16 +76,22 @@ def remove_markings(obj, selectors, marking): selectors = utils.fix_value(selectors) utils.validate(obj, selectors, marking) - utils.expand_markings(obj) - granular_markings = obj.get("granular_markings") if not granular_markings: - return + return obj - tlo = utils.build_granular_marking( - {"selectors": selectors, "marking_ref": marking} - ) + granular_markings = utils.expand_markings(granular_markings) + + if isinstance(marking, list): + to_remove = [] + for m in marking: + to_remove.append({"marking_ref": m, "selectors": selectors}) + else: + to_remove = [{"marking_ref": marking, "selectors": selectors}] + + to_remove = utils.expand_markings(to_remove) + tlo = utils.build_granular_marking(to_remove) remove = tlo.get("granular_markings", []) @@ -93,14 +99,16 @@ def remove_markings(obj, selectors, marking): raise AssertionError("Unable to remove Granular Marking(s) from" " internal collection. Marking(s) not found...") - obj["granular_markings"] = [ + granular_markings = [ m for m in granular_markings if m not in remove ] - utils.compress_markings(obj) + granular_markings = utils.compress_markings(granular_markings) - if not obj.get("granular_markings"): - obj.pop("granular_markings") + if not granular_markings: + return obj.new_version(granular_markings=None) + else: + return obj.new_version(granular_markings=granular_markings) def add_markings(obj, selectors, marking): @@ -109,10 +117,10 @@ def add_markings(obj, selectors, marking): Args: obj: A TLO object. - selectors: string or list of selectors strings relative to the TLO in - which the field(s) appear(s). - marking: identifier or list of marking identifiers that apply to the - field(s) selected by `selectors`. + selectors: list of type string, selectors must be relative to the TLO + in which the properties appear. + marking: identifier that apply to the properties selected by + `selectors`. Raises: AssertionError: If `selectors` or `marking` fail data validation. @@ -121,15 +129,19 @@ def add_markings(obj, selectors, marking): selectors = utils.fix_value(selectors) utils.validate(obj, selectors, marking) - granular_marking = {"selectors": sorted(selectors), "marking_ref": marking} + if isinstance(marking, list): + granular_marking = [] + for m in marking: + granular_marking.append({"marking_ref": m, "selectors": sorted(selectors)}) + else: + granular_marking = [{"marking_ref": marking, "selectors": sorted(selectors)}] - if not obj.get("granular_markings"): - obj["granular_markings"] = list() + if obj.get("granular_markings"): + granular_marking.extend(obj.get("granular_markings")) - obj["granular_markings"].append(granular_marking) - - utils.expand_markings(obj) - utils.compress_markings(obj) + granular_marking = utils.expand_markings(granular_marking) + granular_marking = utils.compress_markings(granular_marking) + return obj.new_version(granular_markings=granular_marking) def clear_markings(obj, selectors): @@ -149,15 +161,15 @@ def clear_markings(obj, selectors): selectors = utils.fix_value(selectors) utils.validate(obj, selectors) - utils.expand_markings(obj) - granular_markings = obj.get("granular_markings") if not granular_markings: - return + return obj + + granular_markings = utils.expand_markings(granular_markings) tlo = utils.build_granular_marking( - {"selectors": selectors, "marking_ref": ["N/A"]} + [{"selectors": selectors, "marking_ref": "N/A"}] ) clear = tlo.get("granular_markings", []) @@ -176,12 +188,14 @@ def clear_markings(obj, selectors): marking_refs = granular_marking.get("marking_ref") if marking_refs: - granular_marking["marking_ref"] = list() + granular_marking["marking_ref"] = "" - utils.compress_markings(obj) + granular_markings = utils.compress_markings(granular_markings) - if not obj.get("granular_markings"): - obj.pop("granular_markings") + if not granular_markings: + return obj.new_version(granular_markings=None) + else: + return obj.new_version(granular_markings=granular_markings) def is_marked(obj, selectors, marking=None, inherited=False, descendants=False): diff --git a/stix2/markings/utils.py b/stix2/markings/utils.py index 9286695..adf3069 100644 --- a/stix2/markings/utils.py +++ b/stix2/markings/utils.py @@ -67,81 +67,51 @@ def fix_value(data): return data -def _fix_markings(markings): +def compress_markings(granular_markings): - for granular_marking in markings: - refs = granular_marking.get("marking_ref", []) - selectors = granular_marking.get("selectors", []) - - if not isinstance(refs, list): - granular_marking["marking_ref"] = [refs] - - if not isinstance(selectors, list): - granular_marking["selectors"] = [selectors] - - -def _group_by(markings): - - key = "marking_ref" - retrieve = "selectors" + if not granular_markings: + return map_ = collections.defaultdict(set) - for granular_marking in markings: - for data in granular_marking.get(key, []): - map_[data].update(granular_marking.get(retrieve)) + for granular_marking in granular_markings: + if granular_marking.get("marking_ref"): + map_[granular_marking.get("marking_ref")].update(granular_marking.get("selectors")) - granular_markings = \ + compressed = \ [ - {"selectors": sorted(selectors), "marking_ref": ref} - for ref, selectors in six.iteritems(map_) + {"marking_ref": marking_ref, "selectors": sorted(selectors)} + for marking_ref, selectors in six.iteritems(map_) ] - return granular_markings + return compressed -def compress_markings(tlo): +def expand_markings(granular_markings): - if not tlo.get("granular_markings"): + if not granular_markings: return - granular_markings = tlo.get("granular_markings") - - _fix_markings(granular_markings) - - tlo["granular_markings"] = _group_by(granular_markings) - - -def expand_markings(tlo): - - if not tlo.get("granular_markings"): - return - - granular_markings = tlo.get("granular_markings") - - _fix_markings(granular_markings) - - expanded = list() + expanded = [] for marking in granular_markings: - selectors = marking.get("selectors", []) - marking_ref = marking.get("marking_ref", []) + selectors = marking.get("selectors") + marking_ref = marking.get("marking_ref") expanded.extend( [ - {"selectors": [sel], "marking_ref": ref} - for sel in selectors - for ref in marking_ref + {"marking_ref": marking_ref, "selectors": [selector]} + for selector in selectors ] ) - tlo["granular_markings"] = expanded + return expanded def build_granular_marking(granular_marking): - tlo = {"granular_markings": [granular_marking]} + tlo = {"granular_markings": granular_marking} - expand_markings(tlo) + expand_markings(tlo["granular_markings"]) return tlo @@ -156,14 +126,15 @@ def iterpath(obj, path=None): path: None, used recursively to store ancestors. Example: - >>> for item in iterpath(tlo): + >>> for item in iterpath(obj): >>> print(item) (['type'], 'campaign') ... (['cybox', 'objects', '[0]', 'hashes', 'sha1'], 'cac35ec206d868b7d7cb0b55f31d9425b075082b') Returns: - tuple: Containing two items: a list of ancestors and the property value. + tuple: Containing two items: a list of ancestors and the + property value. """ if path is None: @@ -209,7 +180,7 @@ def get_selector(obj, prop): location is for now the option to assert the data. Example: - >>> selector = get_selector(tlo, tlo["cybox"]["objects"][0]["file_name"]) + >>> selector = get_selector(obj, obj["cybox"]["objects"][0]["file_name"]) >>> print(selector) ["cybox.objects.[0].file_name"]