diff --git a/stix2/equivalence/patterns/transform/comparison.py b/stix2/equivalence/patterns/transform/comparison.py
index 35cd8a8..2848598 100644
--- a/stix2/equivalence/patterns/transform/comparison.py
+++ b/stix2/equivalence/patterns/transform/comparison.py
@@ -4,6 +4,9 @@ Transformation utilities for STIX pattern comparison expressions.
 import functools
 import itertools
 from stix2.equivalence.patterns.transform import Transformer
+from stix2.equivalence.patterns.transform.specials import (
+    windows_reg_key, ipv4_addr, ipv6_addr
+)
 from stix2.patterns import (
     _BooleanExpression, _ComparisonExpression, AndBooleanExpression,
     OrBooleanExpression, ParentheticalExpression
@@ -57,6 +60,7 @@ class ComparisonExpressionTransformer(Transformer):
     Specifically, subclasses can implement methods:

         "transform_or" for OR nodes
         "transform_and" for AND nodes
+        "transform_comparison" for plain comparison nodes (<prop> <op> <value>)
         "transform_default" for both types of nodes

     "transform_default" is a fallback, if a type-specific callback is not
@@ -69,8 +73,7 @@ class ComparisonExpressionTransformer(Transformer):
     method: a 2-tuple with the transformed AST and a boolean for change
     detection. See doc for the superclass' method.

-    This process currently silently drops parenthetical nodes, and "leaf"
-    comparison expression nodes are left unchanged.
+    This process currently silently drops parenthetical nodes.
     """

     def transform(self, ast):
@@ -88,9 +91,7 @@ class ComparisonExpressionTransformer(Transformer):
             changed = True

         elif isinstance(ast, _ComparisonExpression):
-            # Terminates recursion; we don't change these nodes
-            result = ast
-            changed = False
+            result, changed = self.__dispatch_transform(ast)

         elif isinstance(ast, ParentheticalExpression):
             # Drop these
@@ -116,6 +117,11 @@ class ComparisonExpressionTransformer(Transformer):
         elif isinstance(ast, OrBooleanExpression):
             meth = getattr(self, "transform_or", self.transform_default)

+        elif isinstance(ast, _ComparisonExpression):
+            meth = getattr(
+                self, "transform_comparison", self.transform_default
+            )
+
         else:
             meth = self.transform_default

@@ -142,7 +148,7 @@ class OrderDedupeTransformer(
         A or A => A
     """

-    def transform_default(self, ast):
+    def __transform(self, ast):
         """
         Sort/dedupe children. AND and OR can be treated identically.

@@ -172,6 +178,12 @@ class OrderDedupeTransformer(

         return ast, changed

+    def transform_or(self, ast):
+        return self.__transform(ast)
+
+    def transform_and(self, ast):
+        return self.__transform(ast)
+

 class FlattenTransformer(ComparisonExpressionTransformer):
     """
@@ -182,7 +194,7 @@ class FlattenTransformer(ComparisonExpressionTransformer):
         (A) => A
     """

-    def transform_default(self, ast):
+    def __transform(self, ast):
         """
         Flatten children. AND and OR can be treated mostly identically. The
         little difference is that we can absorb AND children if we're an AND
@@ -192,14 +204,14 @@ class FlattenTransformer(ComparisonExpressionTransformer):
         :return: The same AST node, but with flattened children
         """

-        if isinstance(ast, _BooleanExpression) and len(ast.operands) == 1:
+        changed = False
+        if len(ast.operands) == 1:
             # Replace an AND/OR with one child, with the child itself.
             ast = ast.operands[0]
             changed = True

         else:
             flat_operands = []
-            changed = False
             for operand in ast.operands:
                 if isinstance(operand, _BooleanExpression) \
                         and ast.operator == operand.operator:
@@ -213,6 +225,12 @@ class FlattenTransformer(ComparisonExpressionTransformer):

         return ast, changed

+    def transform_or(self, ast):
+        return self.__transform(ast)
+
+    def transform_and(self, ast):
+        return self.__transform(ast)
+

 class AbsorptionTransformer(
     ComparisonExpressionTransformer
@@ -224,57 +242,62 @@ class AbsorptionTransformer(
         A or (A and B) = A
     """

-    def transform_default(self, ast):
+    def __transform(self, ast):
         changed = False
-        if isinstance(ast, _BooleanExpression):
-            secondary_op = "AND" if ast.operator == "OR" else "OR"
+        secondary_op = "AND" if ast.operator == "OR" else "OR"

-            to_delete = set()
+        to_delete = set()

-            # Check i (child1) against j to see if we can delete j.
-            for i, child1 in enumerate(ast.operands):
-                if i in to_delete:
+        # Check i (child1) against j to see if we can delete j.
+        for i, child1 in enumerate(ast.operands):
+            if i in to_delete:
+                continue
+
+            for j, child2 in enumerate(ast.operands):
+                if i == j or j in to_delete:
                     continue

-                for j, child2 in enumerate(ast.operands):
-                    if i == j or j in to_delete:
-                        continue
+                # We're checking if child1 is contained in child2, so
+                # child2 has to be a compound object, not just a simple
+                # comparison expression. We also require the right operator
+                # for child2: "AND" if ast is "OR" and vice versa.
+                if not isinstance(child2, _BooleanExpression) \
+                        or child2.operator != secondary_op:
+                    continue

-                    # We're checking if child1 is contained in child2, so
-                    # child2 has to be a compound object, not just a simple
-                    # comparison expression. We also require the right operator
-                    # for child2: "AND" if ast is "OR" and vice versa.
-                    if not isinstance(child2, _BooleanExpression) \
-                            or child2.operator != secondary_op:
-                        continue
+                # The simple check: is child1 contained in child2?
+                if iter_in(
+                    child1, child2.operands, comparison_expression_cmp
+                ):
+                    to_delete.add(j)

-                    # The simple check: is child1 contained in child2?
-                    if iter_in(
-                        child1, child2.operands, comparison_expression_cmp
+                # A more complicated check: does child1 occur in child2
+                # in a "flattened" form?
+                elif child1.operator == child2.operator:
+                    if all(
+                        iter_in(
+                            child1_operand, child2.operands,
+                            comparison_expression_cmp
+                        )
+                        for child1_operand in child1.operands
                     ):
                         to_delete.add(j)

-                    # A more complicated check: does child1 occur in child2
-                    # in a "flattened" form?
-                    elif child1.operator == child2.operator:
-                        if all(
-                            iter_in(
-                                child1_operand, child2.operands,
-                                comparison_expression_cmp
-                            )
-                            for child1_operand in child1.operands
-                        ):
-                            to_delete.add(j)
+        if to_delete:
+            changed = True

-            if to_delete:
-                changed = True
-
-                for i in reversed(sorted(to_delete)):
-                    del ast.operands[i]
+            for i in reversed(sorted(to_delete)):
+                del ast.operands[i]

         return ast, changed

+    def transform_or(self, ast):
+        return self.__transform(ast)
+
+    def transform_and(self, ast):
+        return self.__transform(ast)
+

 class DNFTransformer(ComparisonExpressionTransformer):
     """
@@ -329,3 +352,26 @@ class DNFTransformer(ComparisonExpressionTransformer):
             result = ast

         return result, changed
+
+
+class SpecialValueCanonicalization(ComparisonExpressionTransformer):
+    """
+    Try to find particular leaf-node comparison expressions whose rhs (i.e.
+    the constant) can be canonicalized.
+    This is an idiosyncratic transformation based on some ideas people had
+    for context-sensitive semantic equivalence in constant values.
+    """
+    def transform_comparison(self, ast):
+        if ast.lhs.object_type_name == "windows-registry-key":
+            windows_reg_key(ast)
+
+        elif ast.lhs.object_type_name == "ipv4-addr":
+            ipv4_addr(ast)
+
+        elif ast.lhs.object_type_name == "ipv6-addr":
+            ipv6_addr(ast)
+
+        # Hard-code False here since this particular canonicalization is never
+        # worth doing more than once. I think it's okay to pretend nothing has
+        # changed.
+        return ast, False
diff --git a/stix2/equivalence/patterns/transform/observation.py b/stix2/equivalence/patterns/transform/observation.py
index 122a219..4470706 100644
--- a/stix2/equivalence/patterns/transform/observation.py
+++ b/stix2/equivalence/patterns/transform/observation.py
@@ -15,7 +15,8 @@ from stix2.equivalence.patterns.transform.comparison import (
     FlattenTransformer as CFlattenTransformer,
     OrderDedupeTransformer as COrderDedupeTransformer,
     AbsorptionTransformer as CAbsorptionTransformer,
-    DNFTransformer as CDNFTransformer
+    DNFTransformer as CDNFTransformer,
+    SpecialValueCanonicalization
 )
 from stix2.equivalence.patterns.compare import iter_lex_cmp, iter_in
 from stix2.equivalence.patterns.compare.observation import observation_expression_cmp
@@ -473,9 +474,10 @@ class CanonicalizeComparisonExpressionsTransformer(
         simplify = ChainTransformer(comp_flatten, comp_order, comp_absorb)
         settle_simplify = SettleTransformer(simplify)

+        comp_special = SpecialValueCanonicalization()
         comp_dnf = CDNFTransformer()
         self.__comp_canonicalize = ChainTransformer(
-            settle_simplify, comp_dnf, settle_simplify
+            comp_special, settle_simplify, comp_dnf, settle_simplify
         )

     def transform_observation(self, ast):
diff --git a/stix2/equivalence/patterns/transform/specials.py b/stix2/equivalence/patterns/transform/specials.py
new file mode 100644
index 0000000..c565e27
--- /dev/null
+++ b/stix2/equivalence/patterns/transform/specials.py
@@ -0,0 +1,215 @@
+"""
+Some simple comparison expression canonicalization functions.
+"""
+import socket
+from stix2.equivalence.patterns.compare.comparison import (
+    object_path_to_raw_values
+)
+
+
+# Values we can use as wildcards in path patterns
+_ANY_IDX = object()
+_ANY_KEY = object()
+_ANY = object()
+
+
+def _path_is(object_path, path_pattern):
+    """
+    Compare an object path against a pattern. This enables simple path
+    recognition based on a pattern, which is slightly more flexible than
+    exact equality: it supports some simple wildcards.
+
+    The path pattern must be an iterable of values: strings for key path
+    steps, ints or "*" for index path steps, or wildcards. Exact matches are
+    required for non-wildcards in the pattern. For the wildcards, _ANY_IDX
+    matches any index path step; _ANY_KEY matches any key path step, and
+    _ANY matches any path step.
+
+    :param object_path: An ObjectPath instance
+    :param path_pattern: An iterable giving the pattern path steps
+    :return: True if the path matches the pattern; False if not
+    """
+    path_values = object_path_to_raw_values(object_path)
+
+    path_iter = iter(path_values)
+    patt_iter = iter(path_pattern)
+
+    result = True
+    while True:
+        path_val = next(path_iter, None)
+        patt_val = next(patt_iter, None)
+
+        if path_val is None and patt_val is None:
+            # equal length sequences; no differences found
+            break
+
+        elif path_val is None or patt_val is None:
+            # unequal length sequences
+            result = False
+            break
+
+        elif patt_val is _ANY_IDX:
+            if not isinstance(path_val, int) and path_val != "*":
+                result = False
+                break
+
+        elif patt_val is _ANY_KEY:
+            if not isinstance(path_val, str):
+                result = False
+                break
+
+        elif patt_val is not _ANY and patt_val != path_val:
+            result = False
+            break
+
+    return result
+
+
+def _mask_bytes(ip_bytes, prefix_size):
+    """
+    Retain the high-order 'prefix_size' bits from ip_bytes, and zero out the
+    remaining low-order bits. This side-effects ip_bytes.
+
+    :param ip_bytes: A mutable byte sequence (e.g. a bytearray)
+    :param prefix_size: An integer prefix size
+    """
+    addr_size_bytes = len(ip_bytes)
+    addr_size_bits = 8 * addr_size_bytes
+
+    assert 0 <= prefix_size <= addr_size_bits
+
+    num_fixed_bytes = prefix_size // 8
+    num_zero_bytes = (addr_size_bits - prefix_size) // 8
+
+    if num_zero_bytes > 0:
+        ip_bytes[addr_size_bytes - num_zero_bytes:] = b"\x00" * num_zero_bytes
+
+    if num_fixed_bytes + num_zero_bytes != addr_size_bytes:
+        # The address boundary doesn't fall on a byte boundary.
+        # So we have a byte for which we have to zero out some
+        # bits.
+        num_1_bits = prefix_size % 8
+        mask = ((1 << num_1_bits) - 1) << (8 - num_1_bits)
+        ip_bytes[num_fixed_bytes] &= mask
+
+
+def windows_reg_key(comp_expr):
+    """
+    Lower-cases the rhs, depending on the windows-registry-key property
+    being compared. This enables case-insensitive comparisons between two
+    patterns, for those values. This side-effects the given AST.
+
+    :param comp_expr: A _ComparisonExpression object whose type is
+        windows-registry-key
+    """
+    if _path_is(comp_expr.lhs, ("key",)) \
+            or _path_is(comp_expr.lhs, ("values", _ANY_IDX, "name")):
+        comp_expr.rhs.value = comp_expr.rhs.value.lower()
+
+
+def ipv4_addr(comp_expr):
+    """
+    Canonicalizes a CIDR IPv4 address by zeroing out low-order bits, according
+    to the prefix size. This affects the rhs when the "value" property of an
+    ipv4-addr is being compared. If the prefix size is 32, the size suffix is
+    simply dropped since it's redundant. If the value is not a valid CIDR
+    address, then no change is made. This also runs the address through the
+    platform's IPv4 address processing functions (inet_aton() and
+    inet_ntoa()), which can adjust the format.
+
+    This side-effects the given AST.
+
+    :param comp_expr: A _ComparisonExpression object whose type is ipv4-addr.
+    """
+    if _path_is(comp_expr.lhs, ("value",)):
+        value = comp_expr.rhs.value
+        slash_idx = value.find("/")
+
+        if 0 <= slash_idx < len(value)-1:
+            ip_str = value[:slash_idx]
+            try:
+                ip_bytes = socket.inet_aton(ip_str)
+            except OSError:
+                # illegal IPv4 address string
+                return
+
+            try:
+                prefix_size = int(value[slash_idx+1:])
+            except ValueError:
+                # illegal prefix size
+                return
+
+            if prefix_size < 0 or prefix_size > 32:
+                # illegal prefix size
+                return
+
+            if prefix_size == 32:
+                # Drop the "32" since it's redundant.
+                # Run the address bytes through inet_ntoa() in case it would
+                # adjust the format (e.g. drop leading zeros:
+                # 1.2.3.004 => 1.2.3.4).
+                value = socket.inet_ntoa(ip_bytes)
+
+            else:
+                # inet_aton() gives an immutable 'bytes' value; we need a
+                # value we can change.
+                ip_bytes = bytearray(ip_bytes)
+                _mask_bytes(ip_bytes, prefix_size)
+
+                ip_str = socket.inet_ntoa(ip_bytes)
+                value = ip_str + "/" + str(prefix_size)
+
+            comp_expr.rhs.value = value
+
+
+def ipv6_addr(comp_expr):
+    """
+    Canonicalizes a CIDR IPv6 address by zeroing out low-order bits, according
+    to the prefix size. This affects the rhs when the "value" property of an
+    ipv6-addr is being compared. If the prefix size is 128, the size suffix
+    is simply dropped since it's redundant. If the value is not a valid CIDR
+    address, then no change is made. This also runs the address through the
+    platform's IPv6 address processing functions (inet_pton() and
+    inet_ntop()), which can adjust the format.
+
+    This side-effects the given AST.
+
+    :param comp_expr: A _ComparisonExpression object whose type is ipv6-addr.
+    """
+    if _path_is(comp_expr.lhs, ("value",)):
+        value = comp_expr.rhs.value
+        slash_idx = value.find("/")
+
+        if 0 <= slash_idx < len(value)-1:
+            ip_str = value[:slash_idx]
+            try:
+                ip_bytes = socket.inet_pton(socket.AF_INET6, ip_str)
+            except OSError:
+                # illegal IPv6 address string
+                return
+
+            try:
+                prefix_size = int(value[slash_idx+1:])
+            except ValueError:
+                # illegal prefix size
+                return
+
+            if prefix_size < 0 or prefix_size > 128:
+                # illegal prefix size
+                return
+
+            if prefix_size == 128:
+                # Drop the "128" since it's redundant. Run the IP address
+                # through inet_ntop() so it can reformat with the
+                # double-colons (and make any other adjustments) if necessary.
+                value = socket.inet_ntop(socket.AF_INET6, ip_bytes)
+
+            else:
+                # inet_pton() gives an immutable 'bytes' value; we need a
+                # value we can change.
+                ip_bytes = bytearray(ip_bytes)
+                _mask_bytes(ip_bytes, prefix_size)
+
+                ip_str = socket.inet_ntop(socket.AF_INET6, ip_bytes)
+                value = ip_str + "/" + str(prefix_size)
+
+            comp_expr.rhs.value = value
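
Note (illustrative, not part of the patch above): the net effect of the new ipv4-addr handling can be sketched outside the AST machinery. The snippet below mirrors the masking that ipv4_addr() and _mask_bytes() apply to a comparison expression's rhs, but works on a bare string; the helper name canonicalize_ipv4_cidr is hypothetical and used only for illustration.

    import socket

    def canonicalize_ipv4_cidr(value):
        # Mirror specials.ipv4_addr(): "1.2.3.4/24" -> "1.2.3.0/24" and
        # "1.2.3.4/32" -> "1.2.3.4"; anything unparseable is returned as-is.
        slash_idx = value.find("/")
        if not (0 <= slash_idx < len(value) - 1):
            return value

        try:
            ip_bytes = bytearray(socket.inet_aton(value[:slash_idx]))
            prefix_size = int(value[slash_idx + 1:])
        except (OSError, ValueError):
            return value

        if not 0 <= prefix_size <= 32:
            return value

        if prefix_size == 32:
            # Redundant suffix: drop it; inet_ntoa() also normalizes the format.
            return socket.inet_ntoa(bytes(ip_bytes))

        # Zero out whole low-order bytes, then any partial byte, the same way
        # _mask_bytes() does.
        num_fixed_bytes = prefix_size // 8
        num_zero_bytes = (32 - prefix_size) // 8
        if num_zero_bytes > 0:
            ip_bytes[4 - num_zero_bytes:] = b"\x00" * num_zero_bytes
        if prefix_size % 8:
            num_1_bits = prefix_size % 8
            ip_bytes[num_fixed_bytes] &= ((1 << num_1_bits) - 1) << (8 - num_1_bits)

        return socket.inet_ntoa(bytes(ip_bytes)) + "/" + str(prefix_size)

    # For example: canonicalize_ipv4_cidr("11.22.33.44/10") == "11.0.0.0/10"

Canonicalizing the constants up front is presumably what lets the existing order/dedupe/absorption passes treat two spellings of the same address as equal, which matches where comp_special is chained ahead of settle_simplify and comp_dnf in observation.py.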