Add first cut of a pattern equivalence capability
							parent
							
								
									1948b38eec
								
							
						
					
					
						commit
						311fe38cea
					
				|  | @ -0,0 +1,72 @@ | |||
| import stix2.pattern_visitor | ||||
| from stix2.equivalence.patterns.transform import ( | ||||
|     ChainTransformer, SettleTransformer | ||||
| ) | ||||
| from stix2.equivalence.patterns.compare.observation import ( | ||||
|     observation_expression_cmp | ||||
| ) | ||||
| from stix2.equivalence.patterns.transform.observation import ( | ||||
|     CanonicalizeComparisonExpressionsTransformer, | ||||
|     AbsorptionTransformer, | ||||
|     FlattenTransformer, | ||||
|     DNFTransformer, | ||||
|     OrderDedupeTransformer | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| # Lazy-initialize | ||||
| _pattern_canonicalizer = None | ||||
| 
 | ||||
| 
 | ||||
| def _get_pattern_canonicalizer(): | ||||
|     """ | ||||
|     Get a canonicalization transformer for STIX patterns. | ||||
| 
 | ||||
|     :return: The transformer | ||||
|     """ | ||||
| 
 | ||||
|     # The transformers are either stateless or contain no state which changes | ||||
|     # with each use.  So we can setup the transformers once and keep reusing | ||||
|     # them. | ||||
|     global _pattern_canonicalizer | ||||
| 
 | ||||
|     if not _pattern_canonicalizer: | ||||
|         canonicalize_comp_expr = \ | ||||
|             CanonicalizeComparisonExpressionsTransformer() | ||||
| 
 | ||||
|         obs_expr_flatten = FlattenTransformer() | ||||
|         obs_expr_order = OrderDedupeTransformer() | ||||
|         obs_expr_absorb = AbsorptionTransformer() | ||||
|         obs_simplify = ChainTransformer( | ||||
|             obs_expr_flatten, obs_expr_order, obs_expr_absorb | ||||
|         ) | ||||
|         obs_settle_simplify = SettleTransformer(obs_simplify) | ||||
| 
 | ||||
|         obs_dnf = DNFTransformer() | ||||
| 
 | ||||
|         _pattern_canonicalizer = ChainTransformer( | ||||
|             canonicalize_comp_expr, | ||||
|             obs_settle_simplify, obs_dnf, obs_settle_simplify | ||||
|         ) | ||||
| 
 | ||||
|     return _pattern_canonicalizer | ||||
| 
 | ||||
| 
 | ||||
| def equivalent_patterns(pattern1, pattern2): | ||||
|     """ | ||||
|     Determine whether two STIX patterns are semantically equivalent. | ||||
| 
 | ||||
|     :param pattern1: The first STIX pattern | ||||
|     :param pattern2: The second STIX pattern | ||||
|     :return: True if the patterns are semantically equivalent; False if not | ||||
|     """ | ||||
|     patt_ast1 = stix2.pattern_visitor.create_pattern_object(pattern1) | ||||
|     patt_ast2 = stix2.pattern_visitor.create_pattern_object(pattern2) | ||||
| 
 | ||||
|     pattern_canonicalizer = _get_pattern_canonicalizer() | ||||
|     canon_patt1, _ = pattern_canonicalizer.transform(patt_ast1) | ||||
|     canon_patt2, _ = pattern_canonicalizer.transform(patt_ast2) | ||||
| 
 | ||||
|     result = observation_expression_cmp(canon_patt1, canon_patt2) | ||||
| 
 | ||||
|     return result == 0 | ||||
|  | @ -0,0 +1,90 @@ | |||
| """ | ||||
| Some generic comparison utility functions. | ||||
| """ | ||||
| 
 | ||||
| def generic_cmp(value1, value2): | ||||
|     """ | ||||
|     Generic comparator of values which uses the builtin '<' and '>' operators. | ||||
|     Assumes the values can be compared that way. | ||||
| 
 | ||||
|     :param value1: The first value | ||||
|     :param value2: The second value | ||||
|     :return: -1, 0, or 1 depending on whether value1 is less, equal, or greater | ||||
|         than value2 | ||||
|     """ | ||||
| 
 | ||||
|     return -1 if value1 < value2 else 1 if value1 > value2 else 0 | ||||
| 
 | ||||
| 
 | ||||
| def iter_lex_cmp(seq1, seq2, cmp): | ||||
|     """ | ||||
|     Generic lexicographical compare function, which works on two iterables and | ||||
|     a comparator function. | ||||
| 
 | ||||
|     :param seq1: The first iterable | ||||
|     :param seq2: The second iterable | ||||
|     :param cmp: a two-arg callable comparator for values iterated over.  It | ||||
|         must behave analogously to this function, returning <0, 0, or >0 to | ||||
|         express the ordering of the two values. | ||||
|     :return: <0 if seq1 < seq2; >0 if seq1 > seq2; 0 if they're equal | ||||
|     """ | ||||
| 
 | ||||
|     it1 = iter(seq1) | ||||
|     it2 = iter(seq2) | ||||
| 
 | ||||
|     it1_exhausted = it2_exhausted = False | ||||
|     while True: | ||||
|         try: | ||||
|             val1 = next(it1) | ||||
|         except StopIteration: | ||||
|             it1_exhausted = True | ||||
| 
 | ||||
|         try: | ||||
|             val2 = next(it2) | ||||
|         except StopIteration: | ||||
|             it2_exhausted = True | ||||
| 
 | ||||
|         # same length, all elements equal | ||||
|         if it1_exhausted and it2_exhausted: | ||||
|             result = 0 | ||||
|             break | ||||
| 
 | ||||
|         # one is a prefix of the other; the shorter one is less | ||||
|         elif it1_exhausted: | ||||
|             result = -1 | ||||
|             break | ||||
| 
 | ||||
|         elif it2_exhausted: | ||||
|             result = 1 | ||||
|             break | ||||
| 
 | ||||
|         # neither is exhausted; check values | ||||
|         else: | ||||
|             val_cmp = cmp(val1, val2) | ||||
| 
 | ||||
|             if val_cmp != 0: | ||||
|                 result = val_cmp | ||||
|                 break | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| def iter_in(value, seq, cmp): | ||||
|     """ | ||||
|     A function behaving like the "in" Python operator, but which works with a | ||||
|     a comparator function.  This function checks whether the given value is | ||||
|     contained in the given iterable. | ||||
| 
 | ||||
|     :param value: A value | ||||
|     :param seq: An iterable | ||||
|     :param cmp: A 2-arg comparator function which must return 0 if the args | ||||
|         are equal | ||||
|     :return: True if the value is found in the iterable, False if it is not | ||||
|     """ | ||||
|     result = False | ||||
|     for seq_val in seq: | ||||
|         if cmp(value, seq_val) == 0: | ||||
|             result = True | ||||
|             break | ||||
| 
 | ||||
|     return result | ||||
|  | @ -0,0 +1,351 @@ | |||
| """ | ||||
| Comparison utilities for STIX pattern comparison expressions. | ||||
| """ | ||||
| import base64 | ||||
| import functools | ||||
| from stix2.patterns import ( | ||||
|     _ComparisonExpression, AndBooleanExpression, OrBooleanExpression, | ||||
|     ListObjectPathComponent, IntegerConstant, FloatConstant, StringConstant, | ||||
|     BooleanConstant, TimestampConstant, HexConstant, BinaryConstant, | ||||
|     ListConstant | ||||
| ) | ||||
| from stix2.equivalence.patterns.compare import generic_cmp, iter_lex_cmp | ||||
| 
 | ||||
| 
 | ||||
| _COMPARISON_OP_ORDER = ( | ||||
|     "=", "!=", "<>", "<", "<=", ">", ">=", | ||||
|     "IN", "LIKE", "MATCHES", "ISSUBSET", "ISSUPERSET" | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| _CONSTANT_TYPE_ORDER = ( | ||||
|     # ints/floats come first, but have special handling since the types are | ||||
|     # treated equally as a generic "number" type.  So they aren't in this list. | ||||
|     # See constant_cmp(). | ||||
|     StringConstant, BooleanConstant, | ||||
|     TimestampConstant, HexConstant, BinaryConstant, ListConstant | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| def generic_constant_cmp(const1, const2): | ||||
|     """ | ||||
|     Generic comparator for most _Constant instances.  They must have a "value" | ||||
|     attribute whose value supports the builtin comparison operators. | ||||
| 
 | ||||
|     :param const1: The first _Constant instance | ||||
|     :param const2: The second _Constant instance | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
|     return generic_cmp(const1.value, const2.value) | ||||
| 
 | ||||
| 
 | ||||
| def bool_cmp(value1, value2): | ||||
|     """ | ||||
|     Compare two boolean constants. | ||||
| 
 | ||||
|     :param value1: The first BooleanConstant instance | ||||
|     :param value2: The second BooleanConstant instance | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
| 
 | ||||
|     # unwrap from _Constant instances | ||||
|     value1 = value1.value | ||||
|     value2 = value2.value | ||||
| 
 | ||||
|     if (value1 and value2) or (not value1 and not value2): | ||||
|         result = 0 | ||||
| 
 | ||||
|     # Let's say... True < False? | ||||
|     elif value1: | ||||
|         result = -1 | ||||
| 
 | ||||
|     else: | ||||
|         result = 1 | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| def hex_cmp(value1, value2): | ||||
|     """ | ||||
|     Compare two STIX "hex" values.  This decodes to bytes and compares that. | ||||
|     It does *not* do a string compare on the hex representations. | ||||
| 
 | ||||
|     :param value1: The first HexConstant | ||||
|     :param value2: The second HexConstant | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
|     bytes1 = bytes.fromhex(value1.value) | ||||
|     bytes2 = bytes.fromhex(value2.value) | ||||
| 
 | ||||
|     return generic_cmp(bytes1, bytes2) | ||||
| 
 | ||||
| 
 | ||||
| def bin_cmp(value1, value2): | ||||
|     """ | ||||
|     Compare two STIX "binary" values.  This decodes to bytes and compares that. | ||||
|     It does *not* do a string compare on the base64 representations. | ||||
| 
 | ||||
|     :param value1: The first BinaryConstant | ||||
|     :param value2: The second BinaryConstant | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
|     bytes1 = base64.standard_b64decode(value1.value) | ||||
|     bytes2 = base64.standard_b64decode(value2.value) | ||||
| 
 | ||||
|     return generic_cmp(bytes1, bytes2) | ||||
| 
 | ||||
| 
 | ||||
| def list_cmp(value1, value2): | ||||
|     """ | ||||
|     Compare lists order-insensitively. | ||||
| 
 | ||||
|     :param value1: The first ListConstant | ||||
|     :param value2: The second ListConstant | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
| 
 | ||||
|     # Achieve order-independence by sorting the lists first. | ||||
|     sorted_value1 = sorted( | ||||
|         value1.value, key=functools.cmp_to_key(constant_cmp) | ||||
|     ) | ||||
| 
 | ||||
|     sorted_value2 = sorted( | ||||
|         value2.value, key=functools.cmp_to_key(constant_cmp) | ||||
|     ) | ||||
| 
 | ||||
|     result = iter_lex_cmp(sorted_value1, sorted_value2, constant_cmp) | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| _CONSTANT_COMPARATORS = { | ||||
|     # We have special handling for ints/floats, so no entries for those AST | ||||
|     # classes here.  See constant_cmp(). | ||||
|     StringConstant: generic_constant_cmp, | ||||
|     BooleanConstant: bool_cmp, | ||||
|     TimestampConstant: generic_constant_cmp, | ||||
|     HexConstant: hex_cmp, | ||||
|     BinaryConstant: bin_cmp, | ||||
|     ListConstant: list_cmp | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| def object_path_component_cmp(comp1, comp2): | ||||
|     """ | ||||
|     Compare a string/int to another string/int; this induces an ordering over | ||||
|     all strings and ints.  It is used to perform a lexicographical sort on | ||||
|     object paths. | ||||
| 
 | ||||
|     Ints and strings compare as usual to each other; ints compare less than | ||||
|     strings. | ||||
| 
 | ||||
|     :param comp1: An object path component (string or int) | ||||
|     :param comp2: An object path component (string or int) | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
| 
 | ||||
|     # both ints or both strings: use builtin comparison operators | ||||
|     if (isinstance(comp1, int) and isinstance(comp2, int)) \ | ||||
|             or (isinstance(comp1, str) and isinstance(comp2, str)): | ||||
|         result = generic_cmp(comp1, comp2) | ||||
| 
 | ||||
|     # one is int, one is string.  Let's say ints come before strings. | ||||
|     elif isinstance(comp1, int): | ||||
|         result = -1 | ||||
| 
 | ||||
|     else: | ||||
|         result = 1 | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| def object_path_to_raw_values(path): | ||||
|     """ | ||||
|     Converts the given ObjectPath instance to a list of strings and ints. | ||||
|     All property names become strings, regardless of whether they're *_ref | ||||
|     properties; "*" index steps become that string; and numeric index steps | ||||
|     become integers. | ||||
| 
 | ||||
|     :param path: An ObjectPath instance | ||||
|     :return: A generator iterator over the values | ||||
|     """ | ||||
| 
 | ||||
|     for comp in path.property_path: | ||||
|         if isinstance(comp, ListObjectPathComponent): | ||||
|             yield comp.property_name | ||||
| 
 | ||||
|             if comp.index == "*" or isinstance(comp.index, int): | ||||
|                 yield comp.index | ||||
|             else: | ||||
|                 # in case the index is a stringified int; convert to an actual | ||||
|                 # int | ||||
|                 yield int(comp.index) | ||||
| 
 | ||||
|         else: | ||||
|             yield comp.property_name | ||||
| 
 | ||||
| 
 | ||||
| def object_path_cmp(path1, path2): | ||||
|     """ | ||||
|     Compare two object paths. | ||||
| 
 | ||||
|     :param path1: The first ObjectPath instance | ||||
|     :param path2: The second ObjectPath instance | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
|     if path1.object_type_name < path2.object_type_name: | ||||
|         result = -1 | ||||
| 
 | ||||
|     elif path1.object_type_name > path2.object_type_name: | ||||
|         result = 1 | ||||
| 
 | ||||
|     else: | ||||
|         # I always thought of key and index path steps as separate.  The AST | ||||
|         # lumps indices in with the previous key as a single path component. | ||||
|         # The following splits the path components into individual comparable | ||||
|         # values again.  Maybe I should not do this... | ||||
|         path_vals1 = object_path_to_raw_values(path1) | ||||
|         path_vals2 = object_path_to_raw_values(path2) | ||||
|         result = iter_lex_cmp( | ||||
|             path_vals1, path_vals2, object_path_component_cmp | ||||
|         ) | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| def comparison_operator_cmp(op1, op2): | ||||
|     """ | ||||
|     Compare two comparison operators. | ||||
| 
 | ||||
|     :param op1: The first comparison operator (a string) | ||||
|     :param op2: The second comparison operator (a string) | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
|     op1_idx = _COMPARISON_OP_ORDER.index(op1) | ||||
|     op2_idx = _COMPARISON_OP_ORDER.index(op2) | ||||
| 
 | ||||
|     result = generic_cmp(op1_idx, op2_idx) | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| def constant_cmp(value1, value2): | ||||
|     """ | ||||
|     Compare two constants. | ||||
| 
 | ||||
|     :param value1: The first _Constant instance | ||||
|     :param value2: The second _Constant instance | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
| 
 | ||||
|     # Special handling for ints/floats: treat them generically as numbers, | ||||
|     # ordered before all other types. | ||||
|     if isinstance(value1, (IntegerConstant, FloatConstant)) \ | ||||
|             and isinstance(value2, (IntegerConstant, FloatConstant)): | ||||
|         result = generic_constant_cmp(value1, value2) | ||||
| 
 | ||||
|     elif isinstance(value1, (IntegerConstant, FloatConstant)): | ||||
|         result = -1 | ||||
| 
 | ||||
|     elif isinstance(value2, (IntegerConstant, FloatConstant)): | ||||
|         result = 1 | ||||
| 
 | ||||
|     else: | ||||
| 
 | ||||
|         type1 = type(value1) | ||||
|         type2 = type(value2) | ||||
| 
 | ||||
|         type1_idx = _CONSTANT_TYPE_ORDER.index(type1) | ||||
|         type2_idx = _CONSTANT_TYPE_ORDER.index(type2) | ||||
| 
 | ||||
|         result = generic_cmp(type1_idx, type2_idx) | ||||
|         if result == 0: | ||||
|             # Types are the same; must compare values | ||||
|             cmp_func = _CONSTANT_COMPARATORS.get(type1) | ||||
|             if not cmp_func: | ||||
|                 raise TypeError("Don't know how to compare " + type1.__name__) | ||||
| 
 | ||||
|             result = cmp_func(value1, value2) | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| def simple_comparison_expression_cmp(expr1, expr2): | ||||
|     """ | ||||
|     Compare "simple" comparison expressions: those which aren't AND/OR | ||||
|     combinations, just <path> <op> <value> comparisons. | ||||
| 
 | ||||
|     :param expr1: first _ComparisonExpression instance | ||||
|     :param expr2: second _ComparisonExpression instance | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
| 
 | ||||
|     result = object_path_cmp(expr1.lhs, expr2.lhs) | ||||
| 
 | ||||
|     if result == 0: | ||||
|         result = comparison_operator_cmp(expr1.operator, expr2.operator) | ||||
| 
 | ||||
|     if result == 0: | ||||
|         # _ComparisonExpression's have a "negated" attribute.  Umm... | ||||
|         # non-negated < negated? | ||||
|         if not expr1.negated and expr2.negated: | ||||
|             result = -1 | ||||
|         elif expr1.negated and not expr2.negated: | ||||
|             result = 1 | ||||
| 
 | ||||
|     if result == 0: | ||||
|         result = constant_cmp(expr1.rhs, expr2.rhs) | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| def comparison_expression_cmp(expr1, expr2): | ||||
|     """ | ||||
|     Compare two comparison expressions.  This is sensitive to the order of the | ||||
|     expressions' sub-components.  To achieve an order-insensitive comparison, | ||||
|     the ASTs must be canonically ordered first. | ||||
| 
 | ||||
|     :param expr1: The first comparison expression | ||||
|     :param expr2: The second comparison expression | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
|     if isinstance(expr1, _ComparisonExpression) \ | ||||
|             and isinstance(expr2, _ComparisonExpression): | ||||
|         result = simple_comparison_expression_cmp(expr1, expr2) | ||||
| 
 | ||||
|     # One is simple, one is compound.  Let's say... simple ones come first? | ||||
|     elif isinstance(expr1, _ComparisonExpression): | ||||
|         result = -1 | ||||
| 
 | ||||
|     elif isinstance(expr2, _ComparisonExpression): | ||||
|         result = 1 | ||||
| 
 | ||||
|     # Both are compound: AND's before OR's? | ||||
|     elif isinstance(expr1, AndBooleanExpression) \ | ||||
|             and isinstance(expr2, OrBooleanExpression): | ||||
|         result = -1 | ||||
| 
 | ||||
|     elif isinstance(expr1, OrBooleanExpression) \ | ||||
|             and isinstance(expr2, AndBooleanExpression): | ||||
|         result = 1 | ||||
| 
 | ||||
|     else: | ||||
|         # Both compound, same boolean operator: sort according to contents. | ||||
|         # This will order according to recursive invocations of this comparator, | ||||
|         # on sub-expressions. | ||||
|         result = iter_lex_cmp( | ||||
|             expr1.operands, expr2.operands, comparison_expression_cmp | ||||
|         ) | ||||
| 
 | ||||
|     return result | ||||
|  | @ -0,0 +1,124 @@ | |||
| """ | ||||
| Comparison utilities for STIX pattern observation expressions. | ||||
| """ | ||||
| from stix2.equivalence.patterns.compare import generic_cmp, iter_lex_cmp | ||||
| from stix2.equivalence.patterns.compare.comparison import ( | ||||
|     comparison_expression_cmp, generic_constant_cmp | ||||
| ) | ||||
| from stix2.patterns import ( | ||||
|     ObservationExpression, AndObservationExpression, OrObservationExpression, | ||||
|     QualifiedObservationExpression, _CompoundObservationExpression, | ||||
|     RepeatQualifier, WithinQualifier, StartStopQualifier, | ||||
|     FollowedByObservationExpression | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| _OBSERVATION_EXPRESSION_TYPE_ORDER = ( | ||||
|     ObservationExpression, AndObservationExpression, OrObservationExpression, | ||||
|     FollowedByObservationExpression, QualifiedObservationExpression | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| _QUALIFIER_TYPE_ORDER = ( | ||||
|     RepeatQualifier, WithinQualifier, StartStopQualifier | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| def repeats_cmp(qual1, qual2): | ||||
|     """ | ||||
|     Compare REPEATS qualifiers.  This orders by repeat count. | ||||
|     """ | ||||
|     return generic_constant_cmp(qual1.times_to_repeat, qual2.times_to_repeat) | ||||
| 
 | ||||
| 
 | ||||
| def within_cmp(qual1, qual2): | ||||
|     """ | ||||
|     Compare WITHIN qualifiers.  This orders by number of seconds. | ||||
|     """ | ||||
|     return generic_constant_cmp( | ||||
|         qual1.number_of_seconds, qual2.number_of_seconds | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def startstop_cmp(qual1, qual2): | ||||
|     """ | ||||
|     Compare START/STOP qualifiers.  This lexicographically orders by start time, | ||||
|     then stop time. | ||||
|     """ | ||||
|     return iter_lex_cmp( | ||||
|         (qual1.start_time, qual1.stop_time), | ||||
|         (qual2.start_time, qual2.stop_time), | ||||
|         generic_constant_cmp | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| _QUALIFIER_COMPARATORS = { | ||||
|     RepeatQualifier: repeats_cmp, | ||||
|     WithinQualifier: within_cmp, | ||||
|     StartStopQualifier: startstop_cmp | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| def observation_expression_cmp(expr1, expr2): | ||||
|     """ | ||||
|     Compare two observation expression ASTs.  This is sensitive to the order of | ||||
|     the expressions' sub-components.  To achieve an order-insensitive | ||||
|     comparison, the ASTs must be canonically ordered first. | ||||
| 
 | ||||
|     :param expr1: The first observation expression | ||||
|     :param expr2: The second observation expression | ||||
|     :return: <0, 0, or >0 depending on whether the first arg is less, equal or | ||||
|         greater than the second | ||||
|     """ | ||||
|     type1 = type(expr1) | ||||
|     type2 = type(expr2) | ||||
| 
 | ||||
|     type1_idx = _OBSERVATION_EXPRESSION_TYPE_ORDER.index(type1) | ||||
|     type2_idx = _OBSERVATION_EXPRESSION_TYPE_ORDER.index(type2) | ||||
| 
 | ||||
|     if type1_idx != type2_idx: | ||||
|         result = generic_cmp(type1_idx, type2_idx) | ||||
| 
 | ||||
|     # else, both exprs are of same type. | ||||
| 
 | ||||
|     # If they're simple, use contained comparison expression order | ||||
|     elif type1 is ObservationExpression: | ||||
|         result = comparison_expression_cmp( | ||||
|             expr1.operand, expr2.operand | ||||
|         ) | ||||
| 
 | ||||
|     elif isinstance(expr1, _CompoundObservationExpression): | ||||
|         # Both compound, and of same type (and/or/followedby): sort according | ||||
|         # to contents. | ||||
|         result = iter_lex_cmp( | ||||
|             expr1.operands, expr2.operands, observation_expression_cmp | ||||
|         ) | ||||
| 
 | ||||
|     else:  # QualifiedObservationExpression | ||||
|         # Both qualified.  Check qualifiers first; if they are the same, | ||||
|         # use order of the qualified expressions. | ||||
|         qual1_type = type(expr1.qualifier) | ||||
|         qual2_type = type(expr2.qualifier) | ||||
| 
 | ||||
|         qual1_type_idx = _QUALIFIER_TYPE_ORDER.index(qual1_type) | ||||
|         qual2_type_idx = _QUALIFIER_TYPE_ORDER.index(qual2_type) | ||||
| 
 | ||||
|         result = generic_cmp(qual1_type_idx, qual2_type_idx) | ||||
| 
 | ||||
|         if result == 0: | ||||
|             # Same qualifier type; compare qualifier details | ||||
|             qual_cmp = _QUALIFIER_COMPARATORS.get(qual1_type) | ||||
|             if qual_cmp: | ||||
|                 result = qual_cmp(expr1.qualifier, expr2.qualifier) | ||||
|             else: | ||||
|                 raise TypeError( | ||||
|                     "Can't compare qualifier type: " + qual1_type.__name__ | ||||
|                 ) | ||||
| 
 | ||||
|         if result == 0: | ||||
|             # Same qualifier type and details; use qualified expression order | ||||
|             result = observation_expression_cmp( | ||||
|                 expr1.observation_expression, expr2.observation_expression | ||||
|             ) | ||||
| 
 | ||||
|     return result | ||||
|  | @ -0,0 +1,56 @@ | |||
| """ | ||||
| Generic AST transformation classes. | ||||
| """ | ||||
| 
 | ||||
| class Transformer: | ||||
|     """ | ||||
|     Base class for AST transformers. | ||||
|     """ | ||||
|     def transform(self, ast): | ||||
|         """ | ||||
|         Transform the given AST and return the resulting AST. | ||||
| 
 | ||||
|         :param ast: The AST to transform | ||||
|         :return: A 2-tuple: the transformed AST and a boolean indicating whether | ||||
|             the transformation actually changed anything.  The change detection | ||||
|             is useful in situations where a transformation needs to be repeated | ||||
|             until the AST stops changing. | ||||
|         """ | ||||
|         raise NotImplemented("transform") | ||||
| 
 | ||||
| 
 | ||||
| class ChainTransformer(Transformer): | ||||
|     """ | ||||
|     A composite transformer which consists of a sequence of sub-transformers. | ||||
|     Applying this transformer applies all sub-transformers in sequence, as | ||||
|     a group. | ||||
|     """ | ||||
|     def __init__(self, *transformers): | ||||
|         self.__transformers = transformers | ||||
| 
 | ||||
|     def transform(self, ast): | ||||
|         changed = False | ||||
|         for transformer in self.__transformers: | ||||
|             ast, this_changed = transformer.transform(ast) | ||||
|             if this_changed: | ||||
|                 changed = True | ||||
| 
 | ||||
|         return ast, changed | ||||
| 
 | ||||
| 
 | ||||
| class SettleTransformer(Transformer): | ||||
|     """ | ||||
|     A transformer that repeatedly performs a transformation until that | ||||
|     transformation no longer changes the AST.  I.e. the AST has "settled". | ||||
|     """ | ||||
|     def __init__(self, transform): | ||||
|         self.__transformer = transform | ||||
| 
 | ||||
|     def transform(self, ast): | ||||
|         changed = False | ||||
|         ast, this_changed = self.__transformer.transform(ast) | ||||
|         while this_changed: | ||||
|             changed = True | ||||
|             ast, this_changed = self.__transformer.transform(ast) | ||||
| 
 | ||||
|         return ast, changed | ||||
|  | @ -0,0 +1,331 @@ | |||
| """ | ||||
| Transformation utilities for STIX pattern comparison expressions. | ||||
| """ | ||||
| import functools | ||||
| import itertools | ||||
| from stix2.equivalence.patterns.transform import Transformer | ||||
| from stix2.patterns import ( | ||||
|     _BooleanExpression, _ComparisonExpression, AndBooleanExpression, | ||||
|     OrBooleanExpression, ParentheticalExpression | ||||
| ) | ||||
| from stix2.equivalence.patterns.compare.comparison import ( | ||||
|     comparison_expression_cmp | ||||
| ) | ||||
| from stix2.equivalence.patterns.compare import iter_lex_cmp, iter_in | ||||
| 
 | ||||
| 
 | ||||
| def _dupe_ast(ast): | ||||
|     """ | ||||
|     Create a duplicate of the given AST. | ||||
| 
 | ||||
|     Note: the comparison expression "leaves", i.e. simple <path> <op> <value> | ||||
|     comparisons are currently not duplicated.  I don't think it's necessary as | ||||
|     of this writing; they are never changed.  But revisit this if/when | ||||
|     necessary. | ||||
| 
 | ||||
|     :param ast: The AST to duplicate | ||||
|     :return: The duplicate AST | ||||
|     """ | ||||
|     if isinstance(ast, AndBooleanExpression): | ||||
|         result = AndBooleanExpression([ | ||||
|             _dupe_ast(operand) for operand in ast.operands | ||||
|         ]) | ||||
| 
 | ||||
|     elif isinstance(ast, OrBooleanExpression): | ||||
|         result = OrBooleanExpression([ | ||||
|             _dupe_ast(operand) for operand in ast.operands | ||||
|         ]) | ||||
| 
 | ||||
|     elif isinstance(ast, _ComparisonExpression): | ||||
|         # Change this to create a dupe, if we ever need to change simple | ||||
|         # comparison expressions as part of canonicalization. | ||||
|         result = ast | ||||
| 
 | ||||
|     else: | ||||
|         raise TypeError("Can't duplicate " + type(ast).__name__) | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| class ComparisonExpressionTransformer(Transformer): | ||||
|     """ | ||||
|     Transformer base class with special support for transforming comparison | ||||
|     expressions.  The transform method implemented here performs a bottom-up | ||||
|     in-place transformation, with support for some comparison | ||||
|     expression-specific callbacks. | ||||
| 
 | ||||
|     Specifically, subclasses can implement methods: | ||||
|         "transform_or" for OR nodes | ||||
|         "transform_and" for AND nodes | ||||
|         "transform_default" for both types of nodes | ||||
| 
 | ||||
|     "transform_default" is a fallback, if a type-specific callback is not | ||||
|     found.  The default implementation does nothing to the AST.  The | ||||
|     type-specific callbacks are preferred over the default, if both exist. | ||||
| 
 | ||||
|     In all cases, the callbacks are called with an AST for a subtree rooted at | ||||
|     the appropriate node type, where the subtree's children have already been | ||||
|     transformed.  They must return the same thing as the base transform() | ||||
|     method: a 2-tuple with the transformed AST and a boolean for change | ||||
|     detection.  See doc for the superclass' method. | ||||
| 
 | ||||
|     This process currently silently drops parenthetical nodes, and "leaf" | ||||
|     comparison expression nodes are left unchanged. | ||||
|     """ | ||||
| 
 | ||||
|     def transform(self, ast): | ||||
|         if isinstance(ast, _BooleanExpression): | ||||
|             changed = False | ||||
|             for i, operand in enumerate(ast.operands): | ||||
|                 operand_result, this_changed = self.transform(operand) | ||||
|                 if this_changed: | ||||
|                     changed = True | ||||
| 
 | ||||
|                 ast.operands[i] = operand_result | ||||
| 
 | ||||
|             result, this_changed = self.__dispatch_transform(ast) | ||||
|             if this_changed: | ||||
|                 changed = True | ||||
| 
 | ||||
|         elif isinstance(ast, _ComparisonExpression): | ||||
|             # Terminates recursion; we don't change these nodes | ||||
|             result = ast | ||||
|             changed = False | ||||
| 
 | ||||
|         elif isinstance(ast, ParentheticalExpression): | ||||
|             # Drop these | ||||
|             result, changed = self.transform(ast.expression) | ||||
| 
 | ||||
|         else: | ||||
|             raise TypeError("Not a comparison expression: " + str(ast)) | ||||
| 
 | ||||
|         return result, changed | ||||
| 
 | ||||
|     def __dispatch_transform(self, ast): | ||||
|         """ | ||||
|         Invoke a transformer callback method based on the given ast root node | ||||
|         type. | ||||
| 
 | ||||
|         :param ast: The AST | ||||
|         :return: The callback's result | ||||
|         """ | ||||
| 
 | ||||
|         if isinstance(ast, AndBooleanExpression): | ||||
|             meth = getattr(self, "transform_and", self.transform_default) | ||||
| 
 | ||||
|         elif isinstance(ast, OrBooleanExpression): | ||||
|             meth = getattr(self, "transform_or", self.transform_default) | ||||
| 
 | ||||
|         else: | ||||
|             meth = self.transform_default | ||||
| 
 | ||||
|         return meth(ast) | ||||
| 
 | ||||
|     def transform_default(self, ast): | ||||
|         """ | ||||
|         Override to handle transforming AST nodes which don't have a more | ||||
|         specific method implemented. | ||||
|         """ | ||||
|         return ast, False | ||||
| 
 | ||||
| 
 | ||||
| class OrderDedupeTransformer( | ||||
|     ComparisonExpressionTransformer | ||||
| ): | ||||
|     """ | ||||
|     Canonically order the children of all nodes in the AST.  Because the | ||||
|     deduping algorithm is based on sorted data, this transformation also does | ||||
|     deduping. | ||||
| 
 | ||||
|     E.g.: | ||||
|         A and A => A | ||||
|         A or A => A | ||||
|     """ | ||||
| 
 | ||||
|     def transform_default(self, ast): | ||||
|         """ | ||||
|         Sort/dedupe children.  AND and OR can be treated identically. | ||||
| 
 | ||||
|         :param ast: The comparison expression AST | ||||
|         :return: The same AST node, but with sorted children | ||||
|         """ | ||||
|         sorted_children = sorted( | ||||
|             ast.operands, key=functools.cmp_to_key(comparison_expression_cmp) | ||||
|         ) | ||||
| 
 | ||||
|         deduped_children = [ | ||||
|             # Apparently when using a key function, groupby()'s "keys" are the | ||||
|             # key wrappers, not actual sequence values.  Obviously we don't | ||||
|             # need key wrappers in our ASTs! | ||||
|             k.obj for k, _ in itertools.groupby( | ||||
|                 sorted_children, key=functools.cmp_to_key( | ||||
|                     comparison_expression_cmp | ||||
|                 ) | ||||
|             ) | ||||
|         ] | ||||
| 
 | ||||
|         changed = iter_lex_cmp( | ||||
|             ast.operands, deduped_children, comparison_expression_cmp | ||||
|         ) != 0 | ||||
| 
 | ||||
|         ast.operands = deduped_children | ||||
| 
 | ||||
|         return ast, changed | ||||
| 
 | ||||
| 
 | ||||
| class FlattenTransformer(ComparisonExpressionTransformer): | ||||
|     """ | ||||
|     Flatten all nodes of the AST.  E.g.: | ||||
| 
 | ||||
|         A and (B and C) => A and B and C | ||||
|         A or (B or C) => A or B or C | ||||
|         (A) => A | ||||
|     """ | ||||
| 
 | ||||
|     def transform_default(self, ast): | ||||
|         """ | ||||
|         Flatten children.  AND and OR can be treated mostly identically.  The | ||||
|         little difference is that we can absorb AND children if we're an AND | ||||
|         ourselves; and OR for OR. | ||||
| 
 | ||||
|         :param ast: The comparison expression AST | ||||
|         :return: The same AST node, but with flattened children | ||||
|         """ | ||||
| 
 | ||||
|         if isinstance(ast, _BooleanExpression) and len(ast.operands) == 1: | ||||
|             # Replace an AND/OR with one child, with the child itself. | ||||
|             ast = ast.operands[0] | ||||
|             changed = True | ||||
| 
 | ||||
|         else: | ||||
|             flat_operands = [] | ||||
|             changed = False | ||||
|             for operand in ast.operands: | ||||
|                 if isinstance(operand, _BooleanExpression) \ | ||||
|                         and ast.operator == operand.operator: | ||||
|                     flat_operands.extend(operand.operands) | ||||
|                     changed = True | ||||
| 
 | ||||
|                 else: | ||||
|                     flat_operands.append(operand) | ||||
| 
 | ||||
|             ast.operands = flat_operands | ||||
| 
 | ||||
|         return ast, changed | ||||
| 
 | ||||
| 
 | ||||
| class AbsorptionTransformer( | ||||
|     ComparisonExpressionTransformer | ||||
| ): | ||||
|     """ | ||||
|     Applies boolean "absorption" rules for AST simplification.  E.g.: | ||||
| 
 | ||||
|         A and (A or B) = A | ||||
|         A or (A and B) = A | ||||
|     """ | ||||
| 
 | ||||
|     def transform_default(self, ast): | ||||
| 
 | ||||
|         changed = False | ||||
|         if isinstance(ast, _BooleanExpression): | ||||
|             secondary_op = "AND" if ast.operator == "OR" else "OR" | ||||
| 
 | ||||
|             to_delete = set() | ||||
| 
 | ||||
|             # Check i (child1) against j to see if we can delete j. | ||||
|             for i, child1 in enumerate(ast.operands): | ||||
|                 if i in to_delete: | ||||
|                     continue | ||||
| 
 | ||||
|                 for j, child2 in enumerate(ast.operands): | ||||
|                     if i == j or j in to_delete: | ||||
|                         continue | ||||
| 
 | ||||
|                     # We're checking if child1 is contained in child2, so | ||||
|                     # child2 has to be a compound object, not just a simple | ||||
|                     # comparison expression.  We also require the right operator | ||||
|                     # for child2: "AND" if ast is "OR" and vice versa. | ||||
|                     if not isinstance(child2, _BooleanExpression) \ | ||||
|                             or child2.operator != secondary_op: | ||||
|                         continue | ||||
| 
 | ||||
|                     # The simple check: is child1 contained in child2? | ||||
|                     if iter_in( | ||||
|                         child1, child2.operands, comparison_expression_cmp | ||||
|                     ): | ||||
|                         to_delete.add(j) | ||||
| 
 | ||||
|                     # A more complicated check: does child1 occur in child2 | ||||
|                     # in a "flattened" form? | ||||
|                     elif child1.operator == child2.operator: | ||||
|                         if all( | ||||
|                             iter_in( | ||||
|                                 child1_operand, child2.operands, | ||||
|                                 comparison_expression_cmp | ||||
|                             ) | ||||
|                             for child1_operand in child1.operands | ||||
|                         ): | ||||
|                             to_delete.add(j) | ||||
| 
 | ||||
|             if to_delete: | ||||
|                 changed = True | ||||
| 
 | ||||
|                 for i in reversed(sorted(to_delete)): | ||||
|                     del ast.operands[i] | ||||
| 
 | ||||
|         return ast, changed | ||||
| 
 | ||||
| 
 | ||||
| class DNFTransformer(ComparisonExpressionTransformer): | ||||
|     """ | ||||
|     Convert a comparison expression AST to DNF.  E.g.: | ||||
| 
 | ||||
|         A and (B or C) => (A and B) or (A and C) | ||||
|     """ | ||||
|     def transform_and(self, ast): | ||||
|         or_children = [] | ||||
|         other_children = [] | ||||
|         changed = False | ||||
| 
 | ||||
|         # Sort AND children into two piles: the ORs and everything else | ||||
|         for child in ast.operands: | ||||
|             if isinstance(child, _BooleanExpression) and child.operator == "OR": | ||||
|                 # Need a list of operand lists, so we can compute the | ||||
|                 # product below. | ||||
|                 or_children.append(child.operands) | ||||
|             else: | ||||
|                 other_children.append(child) | ||||
| 
 | ||||
|         if or_children: | ||||
|             distributed_children = [ | ||||
|                 AndBooleanExpression([ | ||||
|                     # Make dupes: distribution implies adding repetition, and | ||||
|                     # we should ensure each repetition is independent of the | ||||
|                     # others. | ||||
|                     _dupe_ast(sub_ast) for sub_ast in itertools.chain( | ||||
|                         other_children, prod_seq | ||||
|                     ) | ||||
|                 ]) | ||||
|                 for prod_seq in itertools.product(*or_children) | ||||
|             ] | ||||
| 
 | ||||
|             # Need to recursively continue to distribute AND over OR in | ||||
|             # any of our new sub-expressions which need it.  This causes | ||||
|             # more downward recursion in the midst of this bottom-up transform. | ||||
|             # It's not good for performance.  I wonder if a top-down | ||||
|             # transformation algorithm would make more sense in this phase? | ||||
|             # But then we'd be using two different algorithms for the same | ||||
|             # thing...  Maybe this transform should be completely top-down | ||||
|             # (no bottom-up component at all)? | ||||
|             distributed_children = [ | ||||
|                 self.transform(child)[0] for child in distributed_children | ||||
|             ] | ||||
| 
 | ||||
|             result = OrBooleanExpression(distributed_children) | ||||
|             changed = True | ||||
| 
 | ||||
|         else: | ||||
|             # No AND-over-OR; nothing to do | ||||
|             result = ast | ||||
| 
 | ||||
|         return result, changed | ||||
|  | @ -0,0 +1,486 @@ | |||
| """ | ||||
| Transformation utilities for STIX pattern observation expressions. | ||||
| """ | ||||
| import functools | ||||
| import itertools | ||||
| from stix2.patterns import ( | ||||
|     ObservationExpression, AndObservationExpression, OrObservationExpression, | ||||
|     QualifiedObservationExpression, _CompoundObservationExpression, | ||||
|     ParentheticalExpression, FollowedByObservationExpression | ||||
| ) | ||||
| from stix2.equivalence.patterns.transform import ( | ||||
|     ChainTransformer, SettleTransformer, Transformer | ||||
| ) | ||||
| from stix2.equivalence.patterns.transform.comparison import ( | ||||
|     FlattenTransformer as CFlattenTransformer, | ||||
|     OrderDedupeTransformer as COrderDedupeTransformer, | ||||
|     AbsorptionTransformer as CAbsorptionTransformer, | ||||
|     DNFTransformer as CDNFTransformer | ||||
| ) | ||||
| from stix2.equivalence.patterns.compare import iter_lex_cmp, iter_in | ||||
| from stix2.equivalence.patterns.compare.observation import observation_expression_cmp | ||||
| 
 | ||||
| 
 | ||||
| def _dupe_ast(ast): | ||||
|     """ | ||||
|     Create a duplicate of the given AST.  The AST root must be an observation | ||||
|     expression of some kind (AND/OR/qualified, etc). | ||||
| 
 | ||||
|     Note: the observation expression "leaves", i.e. simple square-bracket | ||||
|     observation expressions are currently not duplicated.  I don't think it's | ||||
|     necessary as of this writing.  But revisit this if/when necessary. | ||||
| 
 | ||||
|     :param ast: The AST to duplicate | ||||
|     :return: The duplicate AST | ||||
|     """ | ||||
|     if isinstance(ast, AndObservationExpression): | ||||
|         result = AndObservationExpression([ | ||||
|             _dupe_ast(child) for child in ast.operands | ||||
|         ]) | ||||
| 
 | ||||
|     elif isinstance(ast, OrObservationExpression): | ||||
|         result = OrObservationExpression([ | ||||
|             _dupe_ast(child) for child in ast.operands | ||||
|         ]) | ||||
| 
 | ||||
|     elif isinstance(ast, FollowedByObservationExpression): | ||||
|         result = FollowedByObservationExpression([ | ||||
|             _dupe_ast(child) for child in ast.operands | ||||
|         ]) | ||||
| 
 | ||||
|     elif isinstance(ast, QualifiedObservationExpression): | ||||
|         # Don't need to dupe the qualifier object at this point | ||||
|         result = QualifiedObservationExpression( | ||||
|             _dupe_ast(ast.observation_expression), ast.qualifier | ||||
|         ) | ||||
| 
 | ||||
|     elif isinstance(ast, ObservationExpression): | ||||
|         result = ast | ||||
| 
 | ||||
|     else: | ||||
|         raise TypeError("Can't duplicate " + type(ast).__name__) | ||||
| 
 | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| class ObservationExpressionTransformer(Transformer): | ||||
|     """ | ||||
|     Transformer base class with special support for transforming observation | ||||
|     expressions.  The transform method implemented here performs a bottom-up | ||||
|     in-place transformation, with support for some observation | ||||
|     expression-specific callbacks.  It recurses down as far as the "leaf node" | ||||
|     observation expressions; it does not go inside of them, to the individual | ||||
|     components of a comparison expression. | ||||
| 
 | ||||
|     Specifically, subclasses can implement methods: | ||||
|         "transform_or" for OR nodes | ||||
|         "transform_and" for AND nodes | ||||
|         "transform_followedby" for FOLLOWEDBY nodes | ||||
|         "transform_qualified" for qualified nodes (all qualifier types) | ||||
|         "transform_observation" for "leaf" observation expression nodes | ||||
|         "transform_default" for all types of nodes | ||||
| 
 | ||||
|     "transform_default" is a fallback, if a type-specific callback is not | ||||
|     found.  The default implementation does nothing to the AST.  The | ||||
|     type-specific callbacks are preferred over the default, if both exist. | ||||
| 
 | ||||
|     In all cases, the callbacks are called with an AST for a subtree rooted at | ||||
|     the appropriate node type, where the AST's children have already been | ||||
|     transformed.  They must return the same thing as the base transform() | ||||
|     method: a 2-tuple with the transformed AST and a boolean for change | ||||
|     detection.  See doc for the superclass' method. | ||||
| 
 | ||||
|     This process currently silently drops parenthetical nodes. | ||||
|     """ | ||||
| 
 | ||||
|     # Determines how AST node types map to callback method names | ||||
|     _DISPATCH_NAME_MAP = { | ||||
|         ObservationExpression: "observation", | ||||
|         AndObservationExpression: "and", | ||||
|         OrObservationExpression: "or", | ||||
|         FollowedByObservationExpression: "followedby", | ||||
|         QualifiedObservationExpression: "qualified" | ||||
|     } | ||||
| 
 | ||||
|     def transform(self, ast): | ||||
| 
 | ||||
|         changed = False | ||||
|         if isinstance(ast, ObservationExpression): | ||||
|             # A "leaf node" for observation expressions.  We don't recurse into | ||||
|             # these. | ||||
|             result, this_changed = self.__dispatch_transform(ast) | ||||
|             if this_changed: | ||||
|                 changed = True | ||||
| 
 | ||||
|         elif isinstance(ast, _CompoundObservationExpression): | ||||
|             for i, operand in enumerate(ast.operands): | ||||
|                 result, this_changed = self.transform(operand) | ||||
|                 if this_changed: | ||||
|                     ast.operands[i] = result | ||||
|                     changed = True | ||||
| 
 | ||||
|             result, this_changed = self.__dispatch_transform(ast) | ||||
|             if this_changed: | ||||
|                 changed = True | ||||
| 
 | ||||
|         elif isinstance(ast, QualifiedObservationExpression): | ||||
|             # I don't think we need to process/transform the qualifier by | ||||
|             # itself, do we? | ||||
|             result, this_changed = self.transform(ast.observation_expression) | ||||
|             if this_changed: | ||||
|                 ast.observation_expression = result | ||||
|                 changed = True | ||||
| 
 | ||||
|             result, this_changed = self.__dispatch_transform(ast) | ||||
|             if this_changed: | ||||
|                 changed = True | ||||
| 
 | ||||
|         elif isinstance(ast, ParentheticalExpression): | ||||
|             result, _ = self.transform(ast.expression) | ||||
|             # Dropping a node is a change, right? | ||||
|             changed = True | ||||
| 
 | ||||
|         else: | ||||
|             raise TypeError("Not an observation expression: {}: {}".format( | ||||
|                 type(ast).__name__, str(ast) | ||||
|             )) | ||||
| 
 | ||||
|         return result, changed | ||||
| 
 | ||||
|     def __dispatch_transform(self, ast): | ||||
|         """ | ||||
|         Invoke a transformer callback method based on the given ast root node | ||||
|         type. | ||||
| 
 | ||||
|         :param ast: The AST | ||||
|         :return: The callback's result | ||||
|         """ | ||||
| 
 | ||||
|         dispatch_name = self._DISPATCH_NAME_MAP.get(type(ast)) | ||||
|         if dispatch_name: | ||||
|             meth_name = "transform_" + dispatch_name | ||||
|             meth = getattr(self, meth_name, self.transform_default) | ||||
|         else: | ||||
|             meth = self.transform_default | ||||
| 
 | ||||
|         return meth(ast) | ||||
| 
 | ||||
|     def transform_default(self, ast): | ||||
|         return ast, False | ||||
| 
 | ||||
| 
 | ||||
| class FlattenTransformer(ObservationExpressionTransformer): | ||||
|     """ | ||||
|     Flatten an observation expression AST.  E.g.: | ||||
| 
 | ||||
|         A and (B and C) => A and B and C | ||||
|         A or (B or C) => A or B or C | ||||
|         A followedby (B followedby C) => A followedby B followedby C | ||||
|         (A) => A | ||||
|     """ | ||||
| 
 | ||||
|     def __transform(self, ast): | ||||
| 
 | ||||
|         changed = False | ||||
| 
 | ||||
|         if len(ast.operands) == 1: | ||||
|             # Replace an AND/OR/FOLLOWEDBY with one child, with the child | ||||
|             # itself. | ||||
|             result = ast.operands[0] | ||||
|             changed = True | ||||
| 
 | ||||
|         else: | ||||
|             flat_children = [] | ||||
|             for operand in ast.operands: | ||||
|                 if isinstance(operand, _CompoundObservationExpression) \ | ||||
|                         and ast.operator == operand.operator: | ||||
|                     flat_children.extend(operand.operands) | ||||
|                     changed = True | ||||
|                 else: | ||||
|                     flat_children.append(operand) | ||||
| 
 | ||||
|             ast.operands = flat_children | ||||
|             result = ast | ||||
| 
 | ||||
|         return result, changed | ||||
| 
 | ||||
|     def transform_and(self, ast): | ||||
|         return self.__transform(ast) | ||||
| 
 | ||||
|     def transform_or(self, ast): | ||||
|         return self.__transform(ast) | ||||
| 
 | ||||
|     def transform_followedby(self, ast): | ||||
|         return self.__transform(ast) | ||||
| 
 | ||||
| 
 | ||||
| class OrderDedupeTransformer( | ||||
|     ObservationExpressionTransformer | ||||
| ): | ||||
|     """ | ||||
|     Canonically order AND/OR expressions, and dedupe ORs.  E.g.: | ||||
| 
 | ||||
|         A or A => A | ||||
|         B or A => A or B | ||||
|         B and A => A and B | ||||
|     """ | ||||
| 
 | ||||
|     def __transform(self, ast): | ||||
|         sorted_children = sorted( | ||||
|             ast.operands, key=functools.cmp_to_key(observation_expression_cmp) | ||||
|         ) | ||||
| 
 | ||||
|         # Deduping only applies to ORs | ||||
|         if ast.operator == "OR": | ||||
|             deduped_children = [ | ||||
|                 key.obj for key, _ in itertools.groupby( | ||||
|                     sorted_children, key=functools.cmp_to_key( | ||||
|                         observation_expression_cmp | ||||
|                     ) | ||||
|                 ) | ||||
|             ] | ||||
|         else: | ||||
|             deduped_children = sorted_children | ||||
| 
 | ||||
|         changed = iter_lex_cmp( | ||||
|             ast.operands, deduped_children, observation_expression_cmp | ||||
|         ) != 0 | ||||
| 
 | ||||
|         ast.operands = deduped_children | ||||
| 
 | ||||
|         return ast, changed | ||||
| 
 | ||||
|     def transform_and(self, ast): | ||||
|         return self.__transform(ast) | ||||
| 
 | ||||
|     def transform_or(self, ast): | ||||
|         return self.__transform(ast) | ||||
| 
 | ||||
| 
 | ||||
| class AbsorptionTransformer( | ||||
|     ObservationExpressionTransformer | ||||
| ): | ||||
|     """ | ||||
|     Applies boolean "absorption" rules for observation expressions, for AST | ||||
|     simplification: | ||||
| 
 | ||||
|         A or (A and B) = A | ||||
|         A or (A followedby B) = A | ||||
| 
 | ||||
|     Other variants do not hold for observation expressions. | ||||
|     """ | ||||
| 
 | ||||
|     def __is_contained_and(self, exprs_containee, exprs_container): | ||||
|         """ | ||||
|         Determine whether the "containee" expressions are contained in the | ||||
|         "container" expressions, with AND semantics (order-independent but need | ||||
|         distinct bindings).  For example (with containee on left and container | ||||
|         on right): | ||||
| 
 | ||||
|             (A and A and B) or (A and B and C) | ||||
| 
 | ||||
|         In the above, all of the lhs vars have a counterpart in the rhs, but | ||||
|         there are two A's on the left and only one on the right.  Therefore, | ||||
|         the right does not "contain" the left.  You would need two A's on the | ||||
|         right. | ||||
| 
 | ||||
|         :param exprs_containee: The expressions we want to check for containment | ||||
|         :param exprs_container: The expressions acting as the "container" | ||||
|         :return: True if the containee is contained in the container; False if | ||||
|             not | ||||
|         """ | ||||
| 
 | ||||
|         # make our own list we are free to manipulate without affecting the | ||||
|         # function args. | ||||
|         container = list(exprs_container) | ||||
| 
 | ||||
|         result = True | ||||
|         for ee in exprs_containee: | ||||
|             for i, er in enumerate(container): | ||||
|                 if observation_expression_cmp(ee, er) == 0: | ||||
|                     # Found a match in the container; delete it so we never try | ||||
|                     # to match a container expr to two different containee | ||||
|                     # expressions. | ||||
|                     del container[i] | ||||
|                     break | ||||
|             else: | ||||
|                 result = False | ||||
|                 break | ||||
| 
 | ||||
|         return result | ||||
| 
 | ||||
|     def __is_contained_followedby(self, exprs_containee, exprs_container): | ||||
|         """ | ||||
|         Determine whether the "containee" expressions are contained in the | ||||
|         "container" expressions, with FOLLOWEDBY semantics (order-sensitive and | ||||
|         need distinct bindings).  For example (with containee on left and | ||||
|         container on right): | ||||
| 
 | ||||
|             (A followedby B) or (B followedby A) | ||||
| 
 | ||||
|         In the above, all of the lhs vars have a counterpart in the rhs, but | ||||
|         the vars on the right are not in the same order.  Therefore, the right | ||||
|         does not "contain" the left.  The container vars don't have to be | ||||
|         contiguous though.  E.g. in: | ||||
| 
 | ||||
|             (A followedby B) or (D followedby A followedby C followedby B) | ||||
| 
 | ||||
|         in the container (rhs), B follows A, so it "contains" the lhs even | ||||
|         though there is other stuff mixed in. | ||||
| 
 | ||||
|         :param exprs_containee: The expressions we want to check for containment | ||||
|         :param exprs_container: The expressions acting as the "container" | ||||
|         :return: True if the containee is contained in the container; False if | ||||
|             not | ||||
|         """ | ||||
| 
 | ||||
|         ee_iter = iter(exprs_containee) | ||||
|         er_iter = iter(exprs_container) | ||||
| 
 | ||||
|         result = True | ||||
|         while True: | ||||
|             ee = next(ee_iter, None) | ||||
|             if not ee: | ||||
|                 break | ||||
| 
 | ||||
|             while True: | ||||
|                 er = next(er_iter, None) | ||||
|                 if er: | ||||
|                     if observation_expression_cmp(ee, er) == 0: | ||||
|                         break | ||||
|                 else: | ||||
|                     break | ||||
| 
 | ||||
|             if not er: | ||||
|                 result = False | ||||
|                 break | ||||
| 
 | ||||
|         return result | ||||
| 
 | ||||
|     def transform_or(self, ast): | ||||
|         changed = False | ||||
|         to_delete = set() | ||||
|         for i, child1 in enumerate(ast.operands): | ||||
|             if i in to_delete: | ||||
|                 continue | ||||
| 
 | ||||
|             # The simplification doesn't work across qualifiers | ||||
|             if isinstance(child1, QualifiedObservationExpression): | ||||
|                 continue | ||||
| 
 | ||||
|             for j, child2 in enumerate(ast.operands): | ||||
|                 if i == j or j in to_delete: | ||||
|                     continue | ||||
| 
 | ||||
|                 if isinstance( | ||||
|                     child2, ( | ||||
|                         AndObservationExpression, | ||||
|                         FollowedByObservationExpression | ||||
|                     ) | ||||
|                 ): | ||||
|                     # The simple check: is child1 contained in child2? | ||||
|                     if iter_in( | ||||
|                         child1, child2.operands, observation_expression_cmp | ||||
|                     ): | ||||
|                         to_delete.add(j) | ||||
| 
 | ||||
|                     # A more complicated check: does child1 occur in child2 | ||||
|                     # in a "flattened" form? | ||||
|                     elif type(child1) is type(child2): | ||||
|                         if isinstance(child1, AndObservationExpression): | ||||
|                             can_simplify = self.__is_contained_and( | ||||
|                                 child1.operands, child2.operands | ||||
|                             ) | ||||
|                         else:  # child1 and 2 are followedby nodes | ||||
|                             can_simplify = self.__is_contained_followedby( | ||||
|                                 child1.operands, child2.operands | ||||
|                             ) | ||||
| 
 | ||||
|                         if can_simplify: | ||||
|                             to_delete.add(j) | ||||
| 
 | ||||
|         if to_delete: | ||||
|             changed = True | ||||
| 
 | ||||
|             for i in reversed(sorted(to_delete)): | ||||
|                 del ast.operands[i] | ||||
| 
 | ||||
|         return ast, changed | ||||
| 
 | ||||
| 
 | ||||
| class DNFTransformer(ObservationExpressionTransformer): | ||||
|     """ | ||||
|     Transform an observation expression to DNF.  This will distribute AND and | ||||
|     FOLLOWEDBY over OR: | ||||
| 
 | ||||
|         A and (B or C) => (A and B) or (A and C) | ||||
|         A followedby (B or C) => (A followedby B) or (A followedby C) | ||||
|     """ | ||||
| 
 | ||||
|     def __transform(self, ast): | ||||
| 
 | ||||
|         root_type = type(ast)  # will be AST class for AND or FOLLOWEDBY | ||||
|         changed = False | ||||
|         or_children = [] | ||||
|         other_children = [] | ||||
|         for child in ast.operands: | ||||
|             if isinstance(child, OrObservationExpression): | ||||
|                 or_children.append(child.operands) | ||||
|             else: | ||||
|                 other_children.append(child) | ||||
| 
 | ||||
|         if or_children: | ||||
|             distributed_children = [ | ||||
|                 root_type([ | ||||
|                     _dupe_ast(sub_ast) for sub_ast in itertools.chain( | ||||
|                         other_children, prod_seq | ||||
|                     ) | ||||
|                 ]) | ||||
|                 for prod_seq in itertools.product(*or_children) | ||||
|             ] | ||||
| 
 | ||||
|             # Need to recursively continue to distribute AND/FOLLOWEDBY over OR | ||||
|             # in any of our new sub-expressions which need it. | ||||
|             distributed_children = [ | ||||
|                 self.transform(child)[0] for child in distributed_children | ||||
|             ] | ||||
| 
 | ||||
|             result = OrObservationExpression(distributed_children) | ||||
|             changed = True | ||||
| 
 | ||||
|         else: | ||||
|             result = ast | ||||
| 
 | ||||
|         return result, changed | ||||
| 
 | ||||
|     def transform_and(self, ast): | ||||
|         return self.__transform(ast) | ||||
| 
 | ||||
|     def transform_followedby(self, ast): | ||||
|         return self.__transform(ast) | ||||
| 
 | ||||
| 
 | ||||
| class CanonicalizeComparisonExpressionsTransformer( | ||||
|     ObservationExpressionTransformer | ||||
| ): | ||||
|     """ | ||||
|     Canonicalize all comparison expressions. | ||||
|     """ | ||||
|     def __init__(self): | ||||
|         comp_flatten = CFlattenTransformer() | ||||
|         comp_order = COrderDedupeTransformer() | ||||
|         comp_absorb = CAbsorptionTransformer() | ||||
|         simplify = ChainTransformer(comp_flatten, comp_order, comp_absorb) | ||||
|         settle_simplify = SettleTransformer(simplify) | ||||
| 
 | ||||
|         comp_dnf = CDNFTransformer() | ||||
|         self.__comp_canonicalize = ChainTransformer( | ||||
|             settle_simplify, comp_dnf, settle_simplify | ||||
|         ) | ||||
| 
 | ||||
|     def transform_observation(self, ast): | ||||
|         comp_expr = ast.operand | ||||
|         canon_comp_expr, changed = self.__comp_canonicalize.transform(comp_expr) | ||||
|         ast.operand = canon_comp_expr | ||||
| 
 | ||||
|         return ast, changed | ||||
		Loading…
	
		Reference in New Issue
	
	 Michael Chisholm
						Michael Chisholm