Add first cut of a pattern equivalence capability
parent
1948b38eec
commit
311fe38cea
|
@ -0,0 +1,72 @@
|
||||||
|
import stix2.pattern_visitor
|
||||||
|
from stix2.equivalence.patterns.transform import (
|
||||||
|
ChainTransformer, SettleTransformer
|
||||||
|
)
|
||||||
|
from stix2.equivalence.patterns.compare.observation import (
|
||||||
|
observation_expression_cmp
|
||||||
|
)
|
||||||
|
from stix2.equivalence.patterns.transform.observation import (
|
||||||
|
CanonicalizeComparisonExpressionsTransformer,
|
||||||
|
AbsorptionTransformer,
|
||||||
|
FlattenTransformer,
|
||||||
|
DNFTransformer,
|
||||||
|
OrderDedupeTransformer
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Lazy-initialize
|
||||||
|
_pattern_canonicalizer = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pattern_canonicalizer():
|
||||||
|
"""
|
||||||
|
Get a canonicalization transformer for STIX patterns.
|
||||||
|
|
||||||
|
:return: The transformer
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The transformers are either stateless or contain no state which changes
|
||||||
|
# with each use. So we can setup the transformers once and keep reusing
|
||||||
|
# them.
|
||||||
|
global _pattern_canonicalizer
|
||||||
|
|
||||||
|
if not _pattern_canonicalizer:
|
||||||
|
canonicalize_comp_expr = \
|
||||||
|
CanonicalizeComparisonExpressionsTransformer()
|
||||||
|
|
||||||
|
obs_expr_flatten = FlattenTransformer()
|
||||||
|
obs_expr_order = OrderDedupeTransformer()
|
||||||
|
obs_expr_absorb = AbsorptionTransformer()
|
||||||
|
obs_simplify = ChainTransformer(
|
||||||
|
obs_expr_flatten, obs_expr_order, obs_expr_absorb
|
||||||
|
)
|
||||||
|
obs_settle_simplify = SettleTransformer(obs_simplify)
|
||||||
|
|
||||||
|
obs_dnf = DNFTransformer()
|
||||||
|
|
||||||
|
_pattern_canonicalizer = ChainTransformer(
|
||||||
|
canonicalize_comp_expr,
|
||||||
|
obs_settle_simplify, obs_dnf, obs_settle_simplify
|
||||||
|
)
|
||||||
|
|
||||||
|
return _pattern_canonicalizer
|
||||||
|
|
||||||
|
|
||||||
|
def equivalent_patterns(pattern1, pattern2):
|
||||||
|
"""
|
||||||
|
Determine whether two STIX patterns are semantically equivalent.
|
||||||
|
|
||||||
|
:param pattern1: The first STIX pattern
|
||||||
|
:param pattern2: The second STIX pattern
|
||||||
|
:return: True if the patterns are semantically equivalent; False if not
|
||||||
|
"""
|
||||||
|
patt_ast1 = stix2.pattern_visitor.create_pattern_object(pattern1)
|
||||||
|
patt_ast2 = stix2.pattern_visitor.create_pattern_object(pattern2)
|
||||||
|
|
||||||
|
pattern_canonicalizer = _get_pattern_canonicalizer()
|
||||||
|
canon_patt1, _ = pattern_canonicalizer.transform(patt_ast1)
|
||||||
|
canon_patt2, _ = pattern_canonicalizer.transform(patt_ast2)
|
||||||
|
|
||||||
|
result = observation_expression_cmp(canon_patt1, canon_patt2)
|
||||||
|
|
||||||
|
return result == 0
|
|
@ -0,0 +1,90 @@
|
||||||
|
"""
|
||||||
|
Some generic comparison utility functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def generic_cmp(value1, value2):
|
||||||
|
"""
|
||||||
|
Generic comparator of values which uses the builtin '<' and '>' operators.
|
||||||
|
Assumes the values can be compared that way.
|
||||||
|
|
||||||
|
:param value1: The first value
|
||||||
|
:param value2: The second value
|
||||||
|
:return: -1, 0, or 1 depending on whether value1 is less, equal, or greater
|
||||||
|
than value2
|
||||||
|
"""
|
||||||
|
|
||||||
|
return -1 if value1 < value2 else 1 if value1 > value2 else 0
|
||||||
|
|
||||||
|
|
||||||
|
def iter_lex_cmp(seq1, seq2, cmp):
|
||||||
|
"""
|
||||||
|
Generic lexicographical compare function, which works on two iterables and
|
||||||
|
a comparator function.
|
||||||
|
|
||||||
|
:param seq1: The first iterable
|
||||||
|
:param seq2: The second iterable
|
||||||
|
:param cmp: a two-arg callable comparator for values iterated over. It
|
||||||
|
must behave analogously to this function, returning <0, 0, or >0 to
|
||||||
|
express the ordering of the two values.
|
||||||
|
:return: <0 if seq1 < seq2; >0 if seq1 > seq2; 0 if they're equal
|
||||||
|
"""
|
||||||
|
|
||||||
|
it1 = iter(seq1)
|
||||||
|
it2 = iter(seq2)
|
||||||
|
|
||||||
|
it1_exhausted = it2_exhausted = False
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
val1 = next(it1)
|
||||||
|
except StopIteration:
|
||||||
|
it1_exhausted = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
val2 = next(it2)
|
||||||
|
except StopIteration:
|
||||||
|
it2_exhausted = True
|
||||||
|
|
||||||
|
# same length, all elements equal
|
||||||
|
if it1_exhausted and it2_exhausted:
|
||||||
|
result = 0
|
||||||
|
break
|
||||||
|
|
||||||
|
# one is a prefix of the other; the shorter one is less
|
||||||
|
elif it1_exhausted:
|
||||||
|
result = -1
|
||||||
|
break
|
||||||
|
|
||||||
|
elif it2_exhausted:
|
||||||
|
result = 1
|
||||||
|
break
|
||||||
|
|
||||||
|
# neither is exhausted; check values
|
||||||
|
else:
|
||||||
|
val_cmp = cmp(val1, val2)
|
||||||
|
|
||||||
|
if val_cmp != 0:
|
||||||
|
result = val_cmp
|
||||||
|
break
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def iter_in(value, seq, cmp):
|
||||||
|
"""
|
||||||
|
A function behaving like the "in" Python operator, but which works with a
|
||||||
|
a comparator function. This function checks whether the given value is
|
||||||
|
contained in the given iterable.
|
||||||
|
|
||||||
|
:param value: A value
|
||||||
|
:param seq: An iterable
|
||||||
|
:param cmp: A 2-arg comparator function which must return 0 if the args
|
||||||
|
are equal
|
||||||
|
:return: True if the value is found in the iterable, False if it is not
|
||||||
|
"""
|
||||||
|
result = False
|
||||||
|
for seq_val in seq:
|
||||||
|
if cmp(value, seq_val) == 0:
|
||||||
|
result = True
|
||||||
|
break
|
||||||
|
|
||||||
|
return result
|
|
@ -0,0 +1,351 @@
|
||||||
|
"""
|
||||||
|
Comparison utilities for STIX pattern comparison expressions.
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import functools
|
||||||
|
from stix2.patterns import (
|
||||||
|
_ComparisonExpression, AndBooleanExpression, OrBooleanExpression,
|
||||||
|
ListObjectPathComponent, IntegerConstant, FloatConstant, StringConstant,
|
||||||
|
BooleanConstant, TimestampConstant, HexConstant, BinaryConstant,
|
||||||
|
ListConstant
|
||||||
|
)
|
||||||
|
from stix2.equivalence.patterns.compare import generic_cmp, iter_lex_cmp
|
||||||
|
|
||||||
|
|
||||||
|
_COMPARISON_OP_ORDER = (
|
||||||
|
"=", "!=", "<>", "<", "<=", ">", ">=",
|
||||||
|
"IN", "LIKE", "MATCHES", "ISSUBSET", "ISSUPERSET"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_CONSTANT_TYPE_ORDER = (
|
||||||
|
# ints/floats come first, but have special handling since the types are
|
||||||
|
# treated equally as a generic "number" type. So they aren't in this list.
|
||||||
|
# See constant_cmp().
|
||||||
|
StringConstant, BooleanConstant,
|
||||||
|
TimestampConstant, HexConstant, BinaryConstant, ListConstant
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generic_constant_cmp(const1, const2):
|
||||||
|
"""
|
||||||
|
Generic comparator for most _Constant instances. They must have a "value"
|
||||||
|
attribute whose value supports the builtin comparison operators.
|
||||||
|
|
||||||
|
:param const1: The first _Constant instance
|
||||||
|
:param const2: The second _Constant instance
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
return generic_cmp(const1.value, const2.value)
|
||||||
|
|
||||||
|
|
||||||
|
def bool_cmp(value1, value2):
|
||||||
|
"""
|
||||||
|
Compare two boolean constants.
|
||||||
|
|
||||||
|
:param value1: The first BooleanConstant instance
|
||||||
|
:param value2: The second BooleanConstant instance
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
|
||||||
|
# unwrap from _Constant instances
|
||||||
|
value1 = value1.value
|
||||||
|
value2 = value2.value
|
||||||
|
|
||||||
|
if (value1 and value2) or (not value1 and not value2):
|
||||||
|
result = 0
|
||||||
|
|
||||||
|
# Let's say... True < False?
|
||||||
|
elif value1:
|
||||||
|
result = -1
|
||||||
|
|
||||||
|
else:
|
||||||
|
result = 1
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def hex_cmp(value1, value2):
|
||||||
|
"""
|
||||||
|
Compare two STIX "hex" values. This decodes to bytes and compares that.
|
||||||
|
It does *not* do a string compare on the hex representations.
|
||||||
|
|
||||||
|
:param value1: The first HexConstant
|
||||||
|
:param value2: The second HexConstant
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
bytes1 = bytes.fromhex(value1.value)
|
||||||
|
bytes2 = bytes.fromhex(value2.value)
|
||||||
|
|
||||||
|
return generic_cmp(bytes1, bytes2)
|
||||||
|
|
||||||
|
|
||||||
|
def bin_cmp(value1, value2):
|
||||||
|
"""
|
||||||
|
Compare two STIX "binary" values. This decodes to bytes and compares that.
|
||||||
|
It does *not* do a string compare on the base64 representations.
|
||||||
|
|
||||||
|
:param value1: The first BinaryConstant
|
||||||
|
:param value2: The second BinaryConstant
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
bytes1 = base64.standard_b64decode(value1.value)
|
||||||
|
bytes2 = base64.standard_b64decode(value2.value)
|
||||||
|
|
||||||
|
return generic_cmp(bytes1, bytes2)
|
||||||
|
|
||||||
|
|
||||||
|
def list_cmp(value1, value2):
|
||||||
|
"""
|
||||||
|
Compare lists order-insensitively.
|
||||||
|
|
||||||
|
:param value1: The first ListConstant
|
||||||
|
:param value2: The second ListConstant
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Achieve order-independence by sorting the lists first.
|
||||||
|
sorted_value1 = sorted(
|
||||||
|
value1.value, key=functools.cmp_to_key(constant_cmp)
|
||||||
|
)
|
||||||
|
|
||||||
|
sorted_value2 = sorted(
|
||||||
|
value2.value, key=functools.cmp_to_key(constant_cmp)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = iter_lex_cmp(sorted_value1, sorted_value2, constant_cmp)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
_CONSTANT_COMPARATORS = {
|
||||||
|
# We have special handling for ints/floats, so no entries for those AST
|
||||||
|
# classes here. See constant_cmp().
|
||||||
|
StringConstant: generic_constant_cmp,
|
||||||
|
BooleanConstant: bool_cmp,
|
||||||
|
TimestampConstant: generic_constant_cmp,
|
||||||
|
HexConstant: hex_cmp,
|
||||||
|
BinaryConstant: bin_cmp,
|
||||||
|
ListConstant: list_cmp
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def object_path_component_cmp(comp1, comp2):
|
||||||
|
"""
|
||||||
|
Compare a string/int to another string/int; this induces an ordering over
|
||||||
|
all strings and ints. It is used to perform a lexicographical sort on
|
||||||
|
object paths.
|
||||||
|
|
||||||
|
Ints and strings compare as usual to each other; ints compare less than
|
||||||
|
strings.
|
||||||
|
|
||||||
|
:param comp1: An object path component (string or int)
|
||||||
|
:param comp2: An object path component (string or int)
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
|
||||||
|
# both ints or both strings: use builtin comparison operators
|
||||||
|
if (isinstance(comp1, int) and isinstance(comp2, int)) \
|
||||||
|
or (isinstance(comp1, str) and isinstance(comp2, str)):
|
||||||
|
result = generic_cmp(comp1, comp2)
|
||||||
|
|
||||||
|
# one is int, one is string. Let's say ints come before strings.
|
||||||
|
elif isinstance(comp1, int):
|
||||||
|
result = -1
|
||||||
|
|
||||||
|
else:
|
||||||
|
result = 1
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def object_path_to_raw_values(path):
|
||||||
|
"""
|
||||||
|
Converts the given ObjectPath instance to a list of strings and ints.
|
||||||
|
All property names become strings, regardless of whether they're *_ref
|
||||||
|
properties; "*" index steps become that string; and numeric index steps
|
||||||
|
become integers.
|
||||||
|
|
||||||
|
:param path: An ObjectPath instance
|
||||||
|
:return: A generator iterator over the values
|
||||||
|
"""
|
||||||
|
|
||||||
|
for comp in path.property_path:
|
||||||
|
if isinstance(comp, ListObjectPathComponent):
|
||||||
|
yield comp.property_name
|
||||||
|
|
||||||
|
if comp.index == "*" or isinstance(comp.index, int):
|
||||||
|
yield comp.index
|
||||||
|
else:
|
||||||
|
# in case the index is a stringified int; convert to an actual
|
||||||
|
# int
|
||||||
|
yield int(comp.index)
|
||||||
|
|
||||||
|
else:
|
||||||
|
yield comp.property_name
|
||||||
|
|
||||||
|
|
||||||
|
def object_path_cmp(path1, path2):
|
||||||
|
"""
|
||||||
|
Compare two object paths.
|
||||||
|
|
||||||
|
:param path1: The first ObjectPath instance
|
||||||
|
:param path2: The second ObjectPath instance
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
if path1.object_type_name < path2.object_type_name:
|
||||||
|
result = -1
|
||||||
|
|
||||||
|
elif path1.object_type_name > path2.object_type_name:
|
||||||
|
result = 1
|
||||||
|
|
||||||
|
else:
|
||||||
|
# I always thought of key and index path steps as separate. The AST
|
||||||
|
# lumps indices in with the previous key as a single path component.
|
||||||
|
# The following splits the path components into individual comparable
|
||||||
|
# values again. Maybe I should not do this...
|
||||||
|
path_vals1 = object_path_to_raw_values(path1)
|
||||||
|
path_vals2 = object_path_to_raw_values(path2)
|
||||||
|
result = iter_lex_cmp(
|
||||||
|
path_vals1, path_vals2, object_path_component_cmp
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def comparison_operator_cmp(op1, op2):
|
||||||
|
"""
|
||||||
|
Compare two comparison operators.
|
||||||
|
|
||||||
|
:param op1: The first comparison operator (a string)
|
||||||
|
:param op2: The second comparison operator (a string)
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
op1_idx = _COMPARISON_OP_ORDER.index(op1)
|
||||||
|
op2_idx = _COMPARISON_OP_ORDER.index(op2)
|
||||||
|
|
||||||
|
result = generic_cmp(op1_idx, op2_idx)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def constant_cmp(value1, value2):
|
||||||
|
"""
|
||||||
|
Compare two constants.
|
||||||
|
|
||||||
|
:param value1: The first _Constant instance
|
||||||
|
:param value2: The second _Constant instance
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Special handling for ints/floats: treat them generically as numbers,
|
||||||
|
# ordered before all other types.
|
||||||
|
if isinstance(value1, (IntegerConstant, FloatConstant)) \
|
||||||
|
and isinstance(value2, (IntegerConstant, FloatConstant)):
|
||||||
|
result = generic_constant_cmp(value1, value2)
|
||||||
|
|
||||||
|
elif isinstance(value1, (IntegerConstant, FloatConstant)):
|
||||||
|
result = -1
|
||||||
|
|
||||||
|
elif isinstance(value2, (IntegerConstant, FloatConstant)):
|
||||||
|
result = 1
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
type1 = type(value1)
|
||||||
|
type2 = type(value2)
|
||||||
|
|
||||||
|
type1_idx = _CONSTANT_TYPE_ORDER.index(type1)
|
||||||
|
type2_idx = _CONSTANT_TYPE_ORDER.index(type2)
|
||||||
|
|
||||||
|
result = generic_cmp(type1_idx, type2_idx)
|
||||||
|
if result == 0:
|
||||||
|
# Types are the same; must compare values
|
||||||
|
cmp_func = _CONSTANT_COMPARATORS.get(type1)
|
||||||
|
if not cmp_func:
|
||||||
|
raise TypeError("Don't know how to compare " + type1.__name__)
|
||||||
|
|
||||||
|
result = cmp_func(value1, value2)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def simple_comparison_expression_cmp(expr1, expr2):
|
||||||
|
"""
|
||||||
|
Compare "simple" comparison expressions: those which aren't AND/OR
|
||||||
|
combinations, just <path> <op> <value> comparisons.
|
||||||
|
|
||||||
|
:param expr1: first _ComparisonExpression instance
|
||||||
|
:param expr2: second _ComparisonExpression instance
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = object_path_cmp(expr1.lhs, expr2.lhs)
|
||||||
|
|
||||||
|
if result == 0:
|
||||||
|
result = comparison_operator_cmp(expr1.operator, expr2.operator)
|
||||||
|
|
||||||
|
if result == 0:
|
||||||
|
# _ComparisonExpression's have a "negated" attribute. Umm...
|
||||||
|
# non-negated < negated?
|
||||||
|
if not expr1.negated and expr2.negated:
|
||||||
|
result = -1
|
||||||
|
elif expr1.negated and not expr2.negated:
|
||||||
|
result = 1
|
||||||
|
|
||||||
|
if result == 0:
|
||||||
|
result = constant_cmp(expr1.rhs, expr2.rhs)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def comparison_expression_cmp(expr1, expr2):
|
||||||
|
"""
|
||||||
|
Compare two comparison expressions. This is sensitive to the order of the
|
||||||
|
expressions' sub-components. To achieve an order-insensitive comparison,
|
||||||
|
the ASTs must be canonically ordered first.
|
||||||
|
|
||||||
|
:param expr1: The first comparison expression
|
||||||
|
:param expr2: The second comparison expression
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
if isinstance(expr1, _ComparisonExpression) \
|
||||||
|
and isinstance(expr2, _ComparisonExpression):
|
||||||
|
result = simple_comparison_expression_cmp(expr1, expr2)
|
||||||
|
|
||||||
|
# One is simple, one is compound. Let's say... simple ones come first?
|
||||||
|
elif isinstance(expr1, _ComparisonExpression):
|
||||||
|
result = -1
|
||||||
|
|
||||||
|
elif isinstance(expr2, _ComparisonExpression):
|
||||||
|
result = 1
|
||||||
|
|
||||||
|
# Both are compound: AND's before OR's?
|
||||||
|
elif isinstance(expr1, AndBooleanExpression) \
|
||||||
|
and isinstance(expr2, OrBooleanExpression):
|
||||||
|
result = -1
|
||||||
|
|
||||||
|
elif isinstance(expr1, OrBooleanExpression) \
|
||||||
|
and isinstance(expr2, AndBooleanExpression):
|
||||||
|
result = 1
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Both compound, same boolean operator: sort according to contents.
|
||||||
|
# This will order according to recursive invocations of this comparator,
|
||||||
|
# on sub-expressions.
|
||||||
|
result = iter_lex_cmp(
|
||||||
|
expr1.operands, expr2.operands, comparison_expression_cmp
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
|
@ -0,0 +1,124 @@
|
||||||
|
"""
|
||||||
|
Comparison utilities for STIX pattern observation expressions.
|
||||||
|
"""
|
||||||
|
from stix2.equivalence.patterns.compare import generic_cmp, iter_lex_cmp
|
||||||
|
from stix2.equivalence.patterns.compare.comparison import (
|
||||||
|
comparison_expression_cmp, generic_constant_cmp
|
||||||
|
)
|
||||||
|
from stix2.patterns import (
|
||||||
|
ObservationExpression, AndObservationExpression, OrObservationExpression,
|
||||||
|
QualifiedObservationExpression, _CompoundObservationExpression,
|
||||||
|
RepeatQualifier, WithinQualifier, StartStopQualifier,
|
||||||
|
FollowedByObservationExpression
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_OBSERVATION_EXPRESSION_TYPE_ORDER = (
|
||||||
|
ObservationExpression, AndObservationExpression, OrObservationExpression,
|
||||||
|
FollowedByObservationExpression, QualifiedObservationExpression
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_QUALIFIER_TYPE_ORDER = (
|
||||||
|
RepeatQualifier, WithinQualifier, StartStopQualifier
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def repeats_cmp(qual1, qual2):
|
||||||
|
"""
|
||||||
|
Compare REPEATS qualifiers. This orders by repeat count.
|
||||||
|
"""
|
||||||
|
return generic_constant_cmp(qual1.times_to_repeat, qual2.times_to_repeat)
|
||||||
|
|
||||||
|
|
||||||
|
def within_cmp(qual1, qual2):
|
||||||
|
"""
|
||||||
|
Compare WITHIN qualifiers. This orders by number of seconds.
|
||||||
|
"""
|
||||||
|
return generic_constant_cmp(
|
||||||
|
qual1.number_of_seconds, qual2.number_of_seconds
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def startstop_cmp(qual1, qual2):
|
||||||
|
"""
|
||||||
|
Compare START/STOP qualifiers. This lexicographically orders by start time,
|
||||||
|
then stop time.
|
||||||
|
"""
|
||||||
|
return iter_lex_cmp(
|
||||||
|
(qual1.start_time, qual1.stop_time),
|
||||||
|
(qual2.start_time, qual2.stop_time),
|
||||||
|
generic_constant_cmp
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_QUALIFIER_COMPARATORS = {
|
||||||
|
RepeatQualifier: repeats_cmp,
|
||||||
|
WithinQualifier: within_cmp,
|
||||||
|
StartStopQualifier: startstop_cmp
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def observation_expression_cmp(expr1, expr2):
|
||||||
|
"""
|
||||||
|
Compare two observation expression ASTs. This is sensitive to the order of
|
||||||
|
the expressions' sub-components. To achieve an order-insensitive
|
||||||
|
comparison, the ASTs must be canonically ordered first.
|
||||||
|
|
||||||
|
:param expr1: The first observation expression
|
||||||
|
:param expr2: The second observation expression
|
||||||
|
:return: <0, 0, or >0 depending on whether the first arg is less, equal or
|
||||||
|
greater than the second
|
||||||
|
"""
|
||||||
|
type1 = type(expr1)
|
||||||
|
type2 = type(expr2)
|
||||||
|
|
||||||
|
type1_idx = _OBSERVATION_EXPRESSION_TYPE_ORDER.index(type1)
|
||||||
|
type2_idx = _OBSERVATION_EXPRESSION_TYPE_ORDER.index(type2)
|
||||||
|
|
||||||
|
if type1_idx != type2_idx:
|
||||||
|
result = generic_cmp(type1_idx, type2_idx)
|
||||||
|
|
||||||
|
# else, both exprs are of same type.
|
||||||
|
|
||||||
|
# If they're simple, use contained comparison expression order
|
||||||
|
elif type1 is ObservationExpression:
|
||||||
|
result = comparison_expression_cmp(
|
||||||
|
expr1.operand, expr2.operand
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(expr1, _CompoundObservationExpression):
|
||||||
|
# Both compound, and of same type (and/or/followedby): sort according
|
||||||
|
# to contents.
|
||||||
|
result = iter_lex_cmp(
|
||||||
|
expr1.operands, expr2.operands, observation_expression_cmp
|
||||||
|
)
|
||||||
|
|
||||||
|
else: # QualifiedObservationExpression
|
||||||
|
# Both qualified. Check qualifiers first; if they are the same,
|
||||||
|
# use order of the qualified expressions.
|
||||||
|
qual1_type = type(expr1.qualifier)
|
||||||
|
qual2_type = type(expr2.qualifier)
|
||||||
|
|
||||||
|
qual1_type_idx = _QUALIFIER_TYPE_ORDER.index(qual1_type)
|
||||||
|
qual2_type_idx = _QUALIFIER_TYPE_ORDER.index(qual2_type)
|
||||||
|
|
||||||
|
result = generic_cmp(qual1_type_idx, qual2_type_idx)
|
||||||
|
|
||||||
|
if result == 0:
|
||||||
|
# Same qualifier type; compare qualifier details
|
||||||
|
qual_cmp = _QUALIFIER_COMPARATORS.get(qual1_type)
|
||||||
|
if qual_cmp:
|
||||||
|
result = qual_cmp(expr1.qualifier, expr2.qualifier)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"Can't compare qualifier type: " + qual1_type.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
if result == 0:
|
||||||
|
# Same qualifier type and details; use qualified expression order
|
||||||
|
result = observation_expression_cmp(
|
||||||
|
expr1.observation_expression, expr2.observation_expression
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
|
@ -0,0 +1,56 @@
|
||||||
|
"""
|
||||||
|
Generic AST transformation classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Transformer:
|
||||||
|
"""
|
||||||
|
Base class for AST transformers.
|
||||||
|
"""
|
||||||
|
def transform(self, ast):
|
||||||
|
"""
|
||||||
|
Transform the given AST and return the resulting AST.
|
||||||
|
|
||||||
|
:param ast: The AST to transform
|
||||||
|
:return: A 2-tuple: the transformed AST and a boolean indicating whether
|
||||||
|
the transformation actually changed anything. The change detection
|
||||||
|
is useful in situations where a transformation needs to be repeated
|
||||||
|
until the AST stops changing.
|
||||||
|
"""
|
||||||
|
raise NotImplemented("transform")
|
||||||
|
|
||||||
|
|
||||||
|
class ChainTransformer(Transformer):
|
||||||
|
"""
|
||||||
|
A composite transformer which consists of a sequence of sub-transformers.
|
||||||
|
Applying this transformer applies all sub-transformers in sequence, as
|
||||||
|
a group.
|
||||||
|
"""
|
||||||
|
def __init__(self, *transformers):
|
||||||
|
self.__transformers = transformers
|
||||||
|
|
||||||
|
def transform(self, ast):
|
||||||
|
changed = False
|
||||||
|
for transformer in self.__transformers:
|
||||||
|
ast, this_changed = transformer.transform(ast)
|
||||||
|
if this_changed:
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
return ast, changed
|
||||||
|
|
||||||
|
|
||||||
|
class SettleTransformer(Transformer):
|
||||||
|
"""
|
||||||
|
A transformer that repeatedly performs a transformation until that
|
||||||
|
transformation no longer changes the AST. I.e. the AST has "settled".
|
||||||
|
"""
|
||||||
|
def __init__(self, transform):
|
||||||
|
self.__transformer = transform
|
||||||
|
|
||||||
|
def transform(self, ast):
|
||||||
|
changed = False
|
||||||
|
ast, this_changed = self.__transformer.transform(ast)
|
||||||
|
while this_changed:
|
||||||
|
changed = True
|
||||||
|
ast, this_changed = self.__transformer.transform(ast)
|
||||||
|
|
||||||
|
return ast, changed
|
|
@ -0,0 +1,331 @@
|
||||||
|
"""
|
||||||
|
Transformation utilities for STIX pattern comparison expressions.
|
||||||
|
"""
|
||||||
|
import functools
|
||||||
|
import itertools
|
||||||
|
from stix2.equivalence.patterns.transform import Transformer
|
||||||
|
from stix2.patterns import (
|
||||||
|
_BooleanExpression, _ComparisonExpression, AndBooleanExpression,
|
||||||
|
OrBooleanExpression, ParentheticalExpression
|
||||||
|
)
|
||||||
|
from stix2.equivalence.patterns.compare.comparison import (
|
||||||
|
comparison_expression_cmp
|
||||||
|
)
|
||||||
|
from stix2.equivalence.patterns.compare import iter_lex_cmp, iter_in
|
||||||
|
|
||||||
|
|
||||||
|
def _dupe_ast(ast):
|
||||||
|
"""
|
||||||
|
Create a duplicate of the given AST.
|
||||||
|
|
||||||
|
Note: the comparison expression "leaves", i.e. simple <path> <op> <value>
|
||||||
|
comparisons are currently not duplicated. I don't think it's necessary as
|
||||||
|
of this writing; they are never changed. But revisit this if/when
|
||||||
|
necessary.
|
||||||
|
|
||||||
|
:param ast: The AST to duplicate
|
||||||
|
:return: The duplicate AST
|
||||||
|
"""
|
||||||
|
if isinstance(ast, AndBooleanExpression):
|
||||||
|
result = AndBooleanExpression([
|
||||||
|
_dupe_ast(operand) for operand in ast.operands
|
||||||
|
])
|
||||||
|
|
||||||
|
elif isinstance(ast, OrBooleanExpression):
|
||||||
|
result = OrBooleanExpression([
|
||||||
|
_dupe_ast(operand) for operand in ast.operands
|
||||||
|
])
|
||||||
|
|
||||||
|
elif isinstance(ast, _ComparisonExpression):
|
||||||
|
# Change this to create a dupe, if we ever need to change simple
|
||||||
|
# comparison expressions as part of canonicalization.
|
||||||
|
result = ast
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise TypeError("Can't duplicate " + type(ast).__name__)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ComparisonExpressionTransformer(Transformer):
|
||||||
|
"""
|
||||||
|
Transformer base class with special support for transforming comparison
|
||||||
|
expressions. The transform method implemented here performs a bottom-up
|
||||||
|
in-place transformation, with support for some comparison
|
||||||
|
expression-specific callbacks.
|
||||||
|
|
||||||
|
Specifically, subclasses can implement methods:
|
||||||
|
"transform_or" for OR nodes
|
||||||
|
"transform_and" for AND nodes
|
||||||
|
"transform_default" for both types of nodes
|
||||||
|
|
||||||
|
"transform_default" is a fallback, if a type-specific callback is not
|
||||||
|
found. The default implementation does nothing to the AST. The
|
||||||
|
type-specific callbacks are preferred over the default, if both exist.
|
||||||
|
|
||||||
|
In all cases, the callbacks are called with an AST for a subtree rooted at
|
||||||
|
the appropriate node type, where the subtree's children have already been
|
||||||
|
transformed. They must return the same thing as the base transform()
|
||||||
|
method: a 2-tuple with the transformed AST and a boolean for change
|
||||||
|
detection. See doc for the superclass' method.
|
||||||
|
|
||||||
|
This process currently silently drops parenthetical nodes, and "leaf"
|
||||||
|
comparison expression nodes are left unchanged.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform(self, ast):
|
||||||
|
if isinstance(ast, _BooleanExpression):
|
||||||
|
changed = False
|
||||||
|
for i, operand in enumerate(ast.operands):
|
||||||
|
operand_result, this_changed = self.transform(operand)
|
||||||
|
if this_changed:
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
ast.operands[i] = operand_result
|
||||||
|
|
||||||
|
result, this_changed = self.__dispatch_transform(ast)
|
||||||
|
if this_changed:
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
elif isinstance(ast, _ComparisonExpression):
|
||||||
|
# Terminates recursion; we don't change these nodes
|
||||||
|
result = ast
|
||||||
|
changed = False
|
||||||
|
|
||||||
|
elif isinstance(ast, ParentheticalExpression):
|
||||||
|
# Drop these
|
||||||
|
result, changed = self.transform(ast.expression)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise TypeError("Not a comparison expression: " + str(ast))
|
||||||
|
|
||||||
|
return result, changed
|
||||||
|
|
||||||
|
def __dispatch_transform(self, ast):
|
||||||
|
"""
|
||||||
|
Invoke a transformer callback method based on the given ast root node
|
||||||
|
type.
|
||||||
|
|
||||||
|
:param ast: The AST
|
||||||
|
:return: The callback's result
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(ast, AndBooleanExpression):
|
||||||
|
meth = getattr(self, "transform_and", self.transform_default)
|
||||||
|
|
||||||
|
elif isinstance(ast, OrBooleanExpression):
|
||||||
|
meth = getattr(self, "transform_or", self.transform_default)
|
||||||
|
|
||||||
|
else:
|
||||||
|
meth = self.transform_default
|
||||||
|
|
||||||
|
return meth(ast)
|
||||||
|
|
||||||
|
def transform_default(self, ast):
|
||||||
|
"""
|
||||||
|
Override to handle transforming AST nodes which don't have a more
|
||||||
|
specific method implemented.
|
||||||
|
"""
|
||||||
|
return ast, False
|
||||||
|
|
||||||
|
|
||||||
|
class OrderDedupeTransformer(
|
||||||
|
ComparisonExpressionTransformer
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Canonically order the children of all nodes in the AST. Because the
|
||||||
|
deduping algorithm is based on sorted data, this transformation also does
|
||||||
|
deduping.
|
||||||
|
|
||||||
|
E.g.:
|
||||||
|
A and A => A
|
||||||
|
A or A => A
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform_default(self, ast):
|
||||||
|
"""
|
||||||
|
Sort/dedupe children. AND and OR can be treated identically.
|
||||||
|
|
||||||
|
:param ast: The comparison expression AST
|
||||||
|
:return: The same AST node, but with sorted children
|
||||||
|
"""
|
||||||
|
sorted_children = sorted(
|
||||||
|
ast.operands, key=functools.cmp_to_key(comparison_expression_cmp)
|
||||||
|
)
|
||||||
|
|
||||||
|
deduped_children = [
|
||||||
|
# Apparently when using a key function, groupby()'s "keys" are the
|
||||||
|
# key wrappers, not actual sequence values. Obviously we don't
|
||||||
|
# need key wrappers in our ASTs!
|
||||||
|
k.obj for k, _ in itertools.groupby(
|
||||||
|
sorted_children, key=functools.cmp_to_key(
|
||||||
|
comparison_expression_cmp
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
changed = iter_lex_cmp(
|
||||||
|
ast.operands, deduped_children, comparison_expression_cmp
|
||||||
|
) != 0
|
||||||
|
|
||||||
|
ast.operands = deduped_children
|
||||||
|
|
||||||
|
return ast, changed
|
||||||
|
|
||||||
|
|
||||||
|
class FlattenTransformer(ComparisonExpressionTransformer):
|
||||||
|
"""
|
||||||
|
Flatten all nodes of the AST. E.g.:
|
||||||
|
|
||||||
|
A and (B and C) => A and B and C
|
||||||
|
A or (B or C) => A or B or C
|
||||||
|
(A) => A
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform_default(self, ast):
|
||||||
|
"""
|
||||||
|
Flatten children. AND and OR can be treated mostly identically. The
|
||||||
|
little difference is that we can absorb AND children if we're an AND
|
||||||
|
ourselves; and OR for OR.
|
||||||
|
|
||||||
|
:param ast: The comparison expression AST
|
||||||
|
:return: The same AST node, but with flattened children
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(ast, _BooleanExpression) and len(ast.operands) == 1:
|
||||||
|
# Replace an AND/OR with one child, with the child itself.
|
||||||
|
ast = ast.operands[0]
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
flat_operands = []
|
||||||
|
changed = False
|
||||||
|
for operand in ast.operands:
|
||||||
|
if isinstance(operand, _BooleanExpression) \
|
||||||
|
and ast.operator == operand.operator:
|
||||||
|
flat_operands.extend(operand.operands)
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
flat_operands.append(operand)
|
||||||
|
|
||||||
|
ast.operands = flat_operands
|
||||||
|
|
||||||
|
return ast, changed
|
||||||
|
|
||||||
|
|
||||||
|
class AbsorptionTransformer(
|
||||||
|
ComparisonExpressionTransformer
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Applies boolean "absorption" rules for AST simplification. E.g.:
|
||||||
|
|
||||||
|
A and (A or B) = A
|
||||||
|
A or (A and B) = A
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform_default(self, ast):
|
||||||
|
|
||||||
|
changed = False
|
||||||
|
if isinstance(ast, _BooleanExpression):
|
||||||
|
secondary_op = "AND" if ast.operator == "OR" else "OR"
|
||||||
|
|
||||||
|
to_delete = set()
|
||||||
|
|
||||||
|
# Check i (child1) against j to see if we can delete j.
|
||||||
|
for i, child1 in enumerate(ast.operands):
|
||||||
|
if i in to_delete:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for j, child2 in enumerate(ast.operands):
|
||||||
|
if i == j or j in to_delete:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# We're checking if child1 is contained in child2, so
|
||||||
|
# child2 has to be a compound object, not just a simple
|
||||||
|
# comparison expression. We also require the right operator
|
||||||
|
# for child2: "AND" if ast is "OR" and vice versa.
|
||||||
|
if not isinstance(child2, _BooleanExpression) \
|
||||||
|
or child2.operator != secondary_op:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# The simple check: is child1 contained in child2?
|
||||||
|
if iter_in(
|
||||||
|
child1, child2.operands, comparison_expression_cmp
|
||||||
|
):
|
||||||
|
to_delete.add(j)
|
||||||
|
|
||||||
|
# A more complicated check: does child1 occur in child2
|
||||||
|
# in a "flattened" form?
|
||||||
|
elif child1.operator == child2.operator:
|
||||||
|
if all(
|
||||||
|
iter_in(
|
||||||
|
child1_operand, child2.operands,
|
||||||
|
comparison_expression_cmp
|
||||||
|
)
|
||||||
|
for child1_operand in child1.operands
|
||||||
|
):
|
||||||
|
to_delete.add(j)
|
||||||
|
|
||||||
|
if to_delete:
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
for i in reversed(sorted(to_delete)):
|
||||||
|
del ast.operands[i]
|
||||||
|
|
||||||
|
return ast, changed
|
||||||
|
|
||||||
|
|
||||||
|
class DNFTransformer(ComparisonExpressionTransformer):
|
||||||
|
"""
|
||||||
|
Convert a comparison expression AST to DNF. E.g.:
|
||||||
|
|
||||||
|
A and (B or C) => (A and B) or (A and C)
|
||||||
|
"""
|
||||||
|
def transform_and(self, ast):
|
||||||
|
or_children = []
|
||||||
|
other_children = []
|
||||||
|
changed = False
|
||||||
|
|
||||||
|
# Sort AND children into two piles: the ORs and everything else
|
||||||
|
for child in ast.operands:
|
||||||
|
if isinstance(child, _BooleanExpression) and child.operator == "OR":
|
||||||
|
# Need a list of operand lists, so we can compute the
|
||||||
|
# product below.
|
||||||
|
or_children.append(child.operands)
|
||||||
|
else:
|
||||||
|
other_children.append(child)
|
||||||
|
|
||||||
|
if or_children:
|
||||||
|
distributed_children = [
|
||||||
|
AndBooleanExpression([
|
||||||
|
# Make dupes: distribution implies adding repetition, and
|
||||||
|
# we should ensure each repetition is independent of the
|
||||||
|
# others.
|
||||||
|
_dupe_ast(sub_ast) for sub_ast in itertools.chain(
|
||||||
|
other_children, prod_seq
|
||||||
|
)
|
||||||
|
])
|
||||||
|
for prod_seq in itertools.product(*or_children)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Need to recursively continue to distribute AND over OR in
|
||||||
|
# any of our new sub-expressions which need it. This causes
|
||||||
|
# more downward recursion in the midst of this bottom-up transform.
|
||||||
|
# It's not good for performance. I wonder if a top-down
|
||||||
|
# transformation algorithm would make more sense in this phase?
|
||||||
|
# But then we'd be using two different algorithms for the same
|
||||||
|
# thing... Maybe this transform should be completely top-down
|
||||||
|
# (no bottom-up component at all)?
|
||||||
|
distributed_children = [
|
||||||
|
self.transform(child)[0] for child in distributed_children
|
||||||
|
]
|
||||||
|
|
||||||
|
result = OrBooleanExpression(distributed_children)
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
# No AND-over-OR; nothing to do
|
||||||
|
result = ast
|
||||||
|
|
||||||
|
return result, changed
|
|
@ -0,0 +1,486 @@
|
||||||
|
"""
|
||||||
|
Transformation utilities for STIX pattern observation expressions.
|
||||||
|
"""
|
||||||
|
import functools
|
||||||
|
import itertools
|
||||||
|
from stix2.patterns import (
|
||||||
|
ObservationExpression, AndObservationExpression, OrObservationExpression,
|
||||||
|
QualifiedObservationExpression, _CompoundObservationExpression,
|
||||||
|
ParentheticalExpression, FollowedByObservationExpression
|
||||||
|
)
|
||||||
|
from stix2.equivalence.patterns.transform import (
|
||||||
|
ChainTransformer, SettleTransformer, Transformer
|
||||||
|
)
|
||||||
|
from stix2.equivalence.patterns.transform.comparison import (
|
||||||
|
FlattenTransformer as CFlattenTransformer,
|
||||||
|
OrderDedupeTransformer as COrderDedupeTransformer,
|
||||||
|
AbsorptionTransformer as CAbsorptionTransformer,
|
||||||
|
DNFTransformer as CDNFTransformer
|
||||||
|
)
|
||||||
|
from stix2.equivalence.patterns.compare import iter_lex_cmp, iter_in
|
||||||
|
from stix2.equivalence.patterns.compare.observation import observation_expression_cmp
|
||||||
|
|
||||||
|
|
||||||
|
def _dupe_ast(ast):
|
||||||
|
"""
|
||||||
|
Create a duplicate of the given AST. The AST root must be an observation
|
||||||
|
expression of some kind (AND/OR/qualified, etc).
|
||||||
|
|
||||||
|
Note: the observation expression "leaves", i.e. simple square-bracket
|
||||||
|
observation expressions are currently not duplicated. I don't think it's
|
||||||
|
necessary as of this writing. But revisit this if/when necessary.
|
||||||
|
|
||||||
|
:param ast: The AST to duplicate
|
||||||
|
:return: The duplicate AST
|
||||||
|
"""
|
||||||
|
if isinstance(ast, AndObservationExpression):
|
||||||
|
result = AndObservationExpression([
|
||||||
|
_dupe_ast(child) for child in ast.operands
|
||||||
|
])
|
||||||
|
|
||||||
|
elif isinstance(ast, OrObservationExpression):
|
||||||
|
result = OrObservationExpression([
|
||||||
|
_dupe_ast(child) for child in ast.operands
|
||||||
|
])
|
||||||
|
|
||||||
|
elif isinstance(ast, FollowedByObservationExpression):
|
||||||
|
result = FollowedByObservationExpression([
|
||||||
|
_dupe_ast(child) for child in ast.operands
|
||||||
|
])
|
||||||
|
|
||||||
|
elif isinstance(ast, QualifiedObservationExpression):
|
||||||
|
# Don't need to dupe the qualifier object at this point
|
||||||
|
result = QualifiedObservationExpression(
|
||||||
|
_dupe_ast(ast.observation_expression), ast.qualifier
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(ast, ObservationExpression):
|
||||||
|
result = ast
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise TypeError("Can't duplicate " + type(ast).__name__)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ObservationExpressionTransformer(Transformer):
|
||||||
|
"""
|
||||||
|
Transformer base class with special support for transforming observation
|
||||||
|
expressions. The transform method implemented here performs a bottom-up
|
||||||
|
in-place transformation, with support for some observation
|
||||||
|
expression-specific callbacks. It recurses down as far as the "leaf node"
|
||||||
|
observation expressions; it does not go inside of them, to the individual
|
||||||
|
components of a comparison expression.
|
||||||
|
|
||||||
|
Specifically, subclasses can implement methods:
|
||||||
|
"transform_or" for OR nodes
|
||||||
|
"transform_and" for AND nodes
|
||||||
|
"transform_followedby" for FOLLOWEDBY nodes
|
||||||
|
"transform_qualified" for qualified nodes (all qualifier types)
|
||||||
|
"transform_observation" for "leaf" observation expression nodes
|
||||||
|
"transform_default" for all types of nodes
|
||||||
|
|
||||||
|
"transform_default" is a fallback, if a type-specific callback is not
|
||||||
|
found. The default implementation does nothing to the AST. The
|
||||||
|
type-specific callbacks are preferred over the default, if both exist.
|
||||||
|
|
||||||
|
In all cases, the callbacks are called with an AST for a subtree rooted at
|
||||||
|
the appropriate node type, where the AST's children have already been
|
||||||
|
transformed. They must return the same thing as the base transform()
|
||||||
|
method: a 2-tuple with the transformed AST and a boolean for change
|
||||||
|
detection. See doc for the superclass' method.
|
||||||
|
|
||||||
|
This process currently silently drops parenthetical nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Determines how AST node types map to callback method names
|
||||||
|
_DISPATCH_NAME_MAP = {
|
||||||
|
ObservationExpression: "observation",
|
||||||
|
AndObservationExpression: "and",
|
||||||
|
OrObservationExpression: "or",
|
||||||
|
FollowedByObservationExpression: "followedby",
|
||||||
|
QualifiedObservationExpression: "qualified"
|
||||||
|
}
|
||||||
|
|
||||||
|
def transform(self, ast):
|
||||||
|
|
||||||
|
changed = False
|
||||||
|
if isinstance(ast, ObservationExpression):
|
||||||
|
# A "leaf node" for observation expressions. We don't recurse into
|
||||||
|
# these.
|
||||||
|
result, this_changed = self.__dispatch_transform(ast)
|
||||||
|
if this_changed:
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
elif isinstance(ast, _CompoundObservationExpression):
|
||||||
|
for i, operand in enumerate(ast.operands):
|
||||||
|
result, this_changed = self.transform(operand)
|
||||||
|
if this_changed:
|
||||||
|
ast.operands[i] = result
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
result, this_changed = self.__dispatch_transform(ast)
|
||||||
|
if this_changed:
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
elif isinstance(ast, QualifiedObservationExpression):
|
||||||
|
# I don't think we need to process/transform the qualifier by
|
||||||
|
# itself, do we?
|
||||||
|
result, this_changed = self.transform(ast.observation_expression)
|
||||||
|
if this_changed:
|
||||||
|
ast.observation_expression = result
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
result, this_changed = self.__dispatch_transform(ast)
|
||||||
|
if this_changed:
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
elif isinstance(ast, ParentheticalExpression):
|
||||||
|
result, _ = self.transform(ast.expression)
|
||||||
|
# Dropping a node is a change, right?
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise TypeError("Not an observation expression: {}: {}".format(
|
||||||
|
type(ast).__name__, str(ast)
|
||||||
|
))
|
||||||
|
|
||||||
|
return result, changed
|
||||||
|
|
||||||
|
def __dispatch_transform(self, ast):
|
||||||
|
"""
|
||||||
|
Invoke a transformer callback method based on the given ast root node
|
||||||
|
type.
|
||||||
|
|
||||||
|
:param ast: The AST
|
||||||
|
:return: The callback's result
|
||||||
|
"""
|
||||||
|
|
||||||
|
dispatch_name = self._DISPATCH_NAME_MAP.get(type(ast))
|
||||||
|
if dispatch_name:
|
||||||
|
meth_name = "transform_" + dispatch_name
|
||||||
|
meth = getattr(self, meth_name, self.transform_default)
|
||||||
|
else:
|
||||||
|
meth = self.transform_default
|
||||||
|
|
||||||
|
return meth(ast)
|
||||||
|
|
||||||
|
def transform_default(self, ast):
|
||||||
|
return ast, False
|
||||||
|
|
||||||
|
|
||||||
|
class FlattenTransformer(ObservationExpressionTransformer):
|
||||||
|
"""
|
||||||
|
Flatten an observation expression AST. E.g.:
|
||||||
|
|
||||||
|
A and (B and C) => A and B and C
|
||||||
|
A or (B or C) => A or B or C
|
||||||
|
A followedby (B followedby C) => A followedby B followedby C
|
||||||
|
(A) => A
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __transform(self, ast):
|
||||||
|
|
||||||
|
changed = False
|
||||||
|
|
||||||
|
if len(ast.operands) == 1:
|
||||||
|
# Replace an AND/OR/FOLLOWEDBY with one child, with the child
|
||||||
|
# itself.
|
||||||
|
result = ast.operands[0]
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
flat_children = []
|
||||||
|
for operand in ast.operands:
|
||||||
|
if isinstance(operand, _CompoundObservationExpression) \
|
||||||
|
and ast.operator == operand.operator:
|
||||||
|
flat_children.extend(operand.operands)
|
||||||
|
changed = True
|
||||||
|
else:
|
||||||
|
flat_children.append(operand)
|
||||||
|
|
||||||
|
ast.operands = flat_children
|
||||||
|
result = ast
|
||||||
|
|
||||||
|
return result, changed
|
||||||
|
|
||||||
|
def transform_and(self, ast):
|
||||||
|
return self.__transform(ast)
|
||||||
|
|
||||||
|
def transform_or(self, ast):
|
||||||
|
return self.__transform(ast)
|
||||||
|
|
||||||
|
def transform_followedby(self, ast):
|
||||||
|
return self.__transform(ast)
|
||||||
|
|
||||||
|
|
||||||
|
class OrderDedupeTransformer(
|
||||||
|
ObservationExpressionTransformer
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Canonically order AND/OR expressions, and dedupe ORs. E.g.:
|
||||||
|
|
||||||
|
A or A => A
|
||||||
|
B or A => A or B
|
||||||
|
B and A => A and B
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __transform(self, ast):
|
||||||
|
sorted_children = sorted(
|
||||||
|
ast.operands, key=functools.cmp_to_key(observation_expression_cmp)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Deduping only applies to ORs
|
||||||
|
if ast.operator == "OR":
|
||||||
|
deduped_children = [
|
||||||
|
key.obj for key, _ in itertools.groupby(
|
||||||
|
sorted_children, key=functools.cmp_to_key(
|
||||||
|
observation_expression_cmp
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
deduped_children = sorted_children
|
||||||
|
|
||||||
|
changed = iter_lex_cmp(
|
||||||
|
ast.operands, deduped_children, observation_expression_cmp
|
||||||
|
) != 0
|
||||||
|
|
||||||
|
ast.operands = deduped_children
|
||||||
|
|
||||||
|
return ast, changed
|
||||||
|
|
||||||
|
def transform_and(self, ast):
|
||||||
|
return self.__transform(ast)
|
||||||
|
|
||||||
|
def transform_or(self, ast):
|
||||||
|
return self.__transform(ast)
|
||||||
|
|
||||||
|
|
||||||
|
class AbsorptionTransformer(
|
||||||
|
ObservationExpressionTransformer
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Applies boolean "absorption" rules for observation expressions, for AST
|
||||||
|
simplification:
|
||||||
|
|
||||||
|
A or (A and B) = A
|
||||||
|
A or (A followedby B) = A
|
||||||
|
|
||||||
|
Other variants do not hold for observation expressions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __is_contained_and(self, exprs_containee, exprs_container):
|
||||||
|
"""
|
||||||
|
Determine whether the "containee" expressions are contained in the
|
||||||
|
"container" expressions, with AND semantics (order-independent but need
|
||||||
|
distinct bindings). For example (with containee on left and container
|
||||||
|
on right):
|
||||||
|
|
||||||
|
(A and A and B) or (A and B and C)
|
||||||
|
|
||||||
|
In the above, all of the lhs vars have a counterpart in the rhs, but
|
||||||
|
there are two A's on the left and only one on the right. Therefore,
|
||||||
|
the right does not "contain" the left. You would need two A's on the
|
||||||
|
right.
|
||||||
|
|
||||||
|
:param exprs_containee: The expressions we want to check for containment
|
||||||
|
:param exprs_container: The expressions acting as the "container"
|
||||||
|
:return: True if the containee is contained in the container; False if
|
||||||
|
not
|
||||||
|
"""
|
||||||
|
|
||||||
|
# make our own list we are free to manipulate without affecting the
|
||||||
|
# function args.
|
||||||
|
container = list(exprs_container)
|
||||||
|
|
||||||
|
result = True
|
||||||
|
for ee in exprs_containee:
|
||||||
|
for i, er in enumerate(container):
|
||||||
|
if observation_expression_cmp(ee, er) == 0:
|
||||||
|
# Found a match in the container; delete it so we never try
|
||||||
|
# to match a container expr to two different containee
|
||||||
|
# expressions.
|
||||||
|
del container[i]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
result = False
|
||||||
|
break
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __is_contained_followedby(self, exprs_containee, exprs_container):
|
||||||
|
"""
|
||||||
|
Determine whether the "containee" expressions are contained in the
|
||||||
|
"container" expressions, with FOLLOWEDBY semantics (order-sensitive and
|
||||||
|
need distinct bindings). For example (with containee on left and
|
||||||
|
container on right):
|
||||||
|
|
||||||
|
(A followedby B) or (B followedby A)
|
||||||
|
|
||||||
|
In the above, all of the lhs vars have a counterpart in the rhs, but
|
||||||
|
the vars on the right are not in the same order. Therefore, the right
|
||||||
|
does not "contain" the left. The container vars don't have to be
|
||||||
|
contiguous though. E.g. in:
|
||||||
|
|
||||||
|
(A followedby B) or (D followedby A followedby C followedby B)
|
||||||
|
|
||||||
|
in the container (rhs), B follows A, so it "contains" the lhs even
|
||||||
|
though there is other stuff mixed in.
|
||||||
|
|
||||||
|
:param exprs_containee: The expressions we want to check for containment
|
||||||
|
:param exprs_container: The expressions acting as the "container"
|
||||||
|
:return: True if the containee is contained in the container; False if
|
||||||
|
not
|
||||||
|
"""
|
||||||
|
|
||||||
|
ee_iter = iter(exprs_containee)
|
||||||
|
er_iter = iter(exprs_container)
|
||||||
|
|
||||||
|
result = True
|
||||||
|
while True:
|
||||||
|
ee = next(ee_iter, None)
|
||||||
|
if not ee:
|
||||||
|
break
|
||||||
|
|
||||||
|
while True:
|
||||||
|
er = next(er_iter, None)
|
||||||
|
if er:
|
||||||
|
if observation_expression_cmp(ee, er) == 0:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not er:
|
||||||
|
result = False
|
||||||
|
break
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def transform_or(self, ast):
|
||||||
|
changed = False
|
||||||
|
to_delete = set()
|
||||||
|
for i, child1 in enumerate(ast.operands):
|
||||||
|
if i in to_delete:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# The simplification doesn't work across qualifiers
|
||||||
|
if isinstance(child1, QualifiedObservationExpression):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for j, child2 in enumerate(ast.operands):
|
||||||
|
if i == j or j in to_delete:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(
|
||||||
|
child2, (
|
||||||
|
AndObservationExpression,
|
||||||
|
FollowedByObservationExpression
|
||||||
|
)
|
||||||
|
):
|
||||||
|
# The simple check: is child1 contained in child2?
|
||||||
|
if iter_in(
|
||||||
|
child1, child2.operands, observation_expression_cmp
|
||||||
|
):
|
||||||
|
to_delete.add(j)
|
||||||
|
|
||||||
|
# A more complicated check: does child1 occur in child2
|
||||||
|
# in a "flattened" form?
|
||||||
|
elif type(child1) is type(child2):
|
||||||
|
if isinstance(child1, AndObservationExpression):
|
||||||
|
can_simplify = self.__is_contained_and(
|
||||||
|
child1.operands, child2.operands
|
||||||
|
)
|
||||||
|
else: # child1 and 2 are followedby nodes
|
||||||
|
can_simplify = self.__is_contained_followedby(
|
||||||
|
child1.operands, child2.operands
|
||||||
|
)
|
||||||
|
|
||||||
|
if can_simplify:
|
||||||
|
to_delete.add(j)
|
||||||
|
|
||||||
|
if to_delete:
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
for i in reversed(sorted(to_delete)):
|
||||||
|
del ast.operands[i]
|
||||||
|
|
||||||
|
return ast, changed
|
||||||
|
|
||||||
|
|
||||||
|
class DNFTransformer(ObservationExpressionTransformer):
|
||||||
|
"""
|
||||||
|
Transform an observation expression to DNF. This will distribute AND and
|
||||||
|
FOLLOWEDBY over OR:
|
||||||
|
|
||||||
|
A and (B or C) => (A and B) or (A and C)
|
||||||
|
A followedby (B or C) => (A followedby B) or (A followedby C)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __transform(self, ast):
|
||||||
|
|
||||||
|
root_type = type(ast) # will be AST class for AND or FOLLOWEDBY
|
||||||
|
changed = False
|
||||||
|
or_children = []
|
||||||
|
other_children = []
|
||||||
|
for child in ast.operands:
|
||||||
|
if isinstance(child, OrObservationExpression):
|
||||||
|
or_children.append(child.operands)
|
||||||
|
else:
|
||||||
|
other_children.append(child)
|
||||||
|
|
||||||
|
if or_children:
|
||||||
|
distributed_children = [
|
||||||
|
root_type([
|
||||||
|
_dupe_ast(sub_ast) for sub_ast in itertools.chain(
|
||||||
|
other_children, prod_seq
|
||||||
|
)
|
||||||
|
])
|
||||||
|
for prod_seq in itertools.product(*or_children)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Need to recursively continue to distribute AND/FOLLOWEDBY over OR
|
||||||
|
# in any of our new sub-expressions which need it.
|
||||||
|
distributed_children = [
|
||||||
|
self.transform(child)[0] for child in distributed_children
|
||||||
|
]
|
||||||
|
|
||||||
|
result = OrObservationExpression(distributed_children)
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
result = ast
|
||||||
|
|
||||||
|
return result, changed
|
||||||
|
|
||||||
|
def transform_and(self, ast):
|
||||||
|
return self.__transform(ast)
|
||||||
|
|
||||||
|
def transform_followedby(self, ast):
|
||||||
|
return self.__transform(ast)
|
||||||
|
|
||||||
|
|
||||||
|
class CanonicalizeComparisonExpressionsTransformer(
|
||||||
|
ObservationExpressionTransformer
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Canonicalize all comparison expressions.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
comp_flatten = CFlattenTransformer()
|
||||||
|
comp_order = COrderDedupeTransformer()
|
||||||
|
comp_absorb = CAbsorptionTransformer()
|
||||||
|
simplify = ChainTransformer(comp_flatten, comp_order, comp_absorb)
|
||||||
|
settle_simplify = SettleTransformer(simplify)
|
||||||
|
|
||||||
|
comp_dnf = CDNFTransformer()
|
||||||
|
self.__comp_canonicalize = ChainTransformer(
|
||||||
|
settle_simplify, comp_dnf, settle_simplify
|
||||||
|
)
|
||||||
|
|
||||||
|
def transform_observation(self, ast):
|
||||||
|
comp_expr = ast.operand
|
||||||
|
canon_comp_expr, changed = self.__comp_canonicalize.transform(comp_expr)
|
||||||
|
ast.operand = canon_comp_expr
|
||||||
|
|
||||||
|
return ast, changed
|
Loading…
Reference in New Issue