Add first cut of a pattern equivalence capability
parent
1948b38eec
commit
311fe38cea
|
@ -0,0 +1,72 @@
|
|||
import stix2.pattern_visitor
|
||||
from stix2.equivalence.patterns.transform import (
|
||||
ChainTransformer, SettleTransformer
|
||||
)
|
||||
from stix2.equivalence.patterns.compare.observation import (
|
||||
observation_expression_cmp
|
||||
)
|
||||
from stix2.equivalence.patterns.transform.observation import (
|
||||
CanonicalizeComparisonExpressionsTransformer,
|
||||
AbsorptionTransformer,
|
||||
FlattenTransformer,
|
||||
DNFTransformer,
|
||||
OrderDedupeTransformer
|
||||
)
|
||||
|
||||
|
||||
# Lazy-initialize: cache slot for the pattern canonicalization transformer.
# It starts out as None and is filled in by _get_pattern_canonicalizer() the
# first time that function is called.
_pattern_canonicalizer = None
|
||||
|
||||
|
||||
def _get_pattern_canonicalizer():
    """
    Get a canonicalization transformer for STIX patterns.

    The transformer is assembled on the first call and cached in the
    module-level ``_pattern_canonicalizer`` variable; subsequent calls return
    the cached object.

    :return: The transformer
    """

    # The transformers are either stateless or contain no state which changes
    # with each use.  So we can setup the transformers once and keep reusing
    # them.
    global _pattern_canonicalizer

    # Test the cache against None explicitly rather than by truthiness, so the
    # check does not depend on how a transformer object evaluates as a boolean.
    if _pattern_canonicalizer is None:
        canonicalize_comp_expr = \
            CanonicalizeComparisonExpressionsTransformer()

        obs_expr_flatten = FlattenTransformer()
        obs_expr_order = OrderDedupeTransformer()
        obs_expr_absorb = AbsorptionTransformer()
        # One simplification pass: flatten, then order/dedupe, then absorb.
        obs_simplify = ChainTransformer(
            obs_expr_flatten, obs_expr_order, obs_expr_absorb
        )
        # Repeat the simplification pass until it reports no further change.
        obs_settle_simplify = SettleTransformer(obs_simplify)

        obs_dnf = DNFTransformer()

        # Canonicalize the comparison expressions, simplify, convert to DNF,
        # then simplify once more.
        _pattern_canonicalizer = ChainTransformer(
            canonicalize_comp_expr,
            obs_settle_simplify, obs_dnf, obs_settle_simplify
        )

    return _pattern_canonicalizer
|
||||
|
||||
|
||||
def equivalent_patterns(pattern1, pattern2):
    """
    Check two STIX patterns for semantic equivalence.

    Both patterns are parsed into ASTs, each AST is run through the shared
    canonicalization transformer, and the two canonical forms are compared.

    :param pattern1: The first STIX pattern
    :param pattern2: The second STIX pattern
    :return: True if the patterns are semantically equivalent; False if not
    """
    parse = stix2.pattern_visitor.create_pattern_object
    first_ast = parse(pattern1)
    second_ast = parse(pattern2)

    canonicalizer = _get_pattern_canonicalizer()
    # transform() returns (ast, changed); only the AST matters here.
    first_canon, _ = canonicalizer.transform(first_ast)
    second_canon, _ = canonicalizer.transform(second_ast)

    # The comparator returns 0 exactly when the canonical forms match.
    return observation_expression_cmp(first_canon, second_canon) == 0
|
|
@ -0,0 +1,90 @@
|
|||
"""
|
||||
Some generic comparison utility functions.
|
||||
"""
|
||||
|
||||
def generic_cmp(value1, value2):
    """
    Three-way comparison of two values via the builtin '<' and '>' operators.
    The values are assumed to support those operators.

    :param value1: The first value
    :param value2: The second value
    :return: -1, 0, or 1 depending on whether value1 is less, equal, or greater
        than value2
    """
    if value1 < value2:
        return -1

    if value1 > value2:
        return 1

    # Neither less nor greater: treat as equal.
    return 0
|
||||
|
||||
|
||||
def iter_lex_cmp(seq1, seq2, cmp):
    """
    Lexicographically compare two iterables, element by element, using the
    given comparator function.

    :param seq1: The first iterable
    :param seq2: The second iterable
    :param cmp: a two-arg callable comparator for values iterated over.  It
        must behave analogously to this function, returning <0, 0, or >0 to
        express the ordering of the two values.
    :return: <0 if seq1 < seq2; >0 if seq1 > seq2; 0 if they're equal
    """
    left = iter(seq1)
    right = iter(seq2)

    # Unique marker signalling that an iterator has run out of values.
    exhausted = object()

    while True:
        # Advance both iterators in lockstep, left first.
        item1 = next(left, exhausted)
        item2 = next(right, exhausted)

        left_done = item1 is exhausted
        right_done = item2 is exhausted

        if left_done or right_done:
            if left_done and right_done:
                # same length, all elements equal
                return 0

            # one is a prefix of the other; the shorter one is less
            return -1 if left_done else 1

        # neither is exhausted; the first unequal pair decides the order
        order = cmp(item1, item2)
        if order != 0:
            return order
|
||||
|
||||
|
||||
def iter_in(value, seq, cmp):
    """
    Membership test analogous to Python's "in" operator, except that equality
    is decided by a comparator function.  Checks whether the given value is
    contained in the given iterable.

    :param value: A value
    :param seq: An iterable
    :param cmp: A 2-arg comparator function which must return 0 if the args
        are equal
    :return: True if the value is found in the iterable, False if it is not
    """
    for candidate in seq:
        # Stop at the first element the comparator reports as equal.
        if cmp(value, candidate) == 0:
            return True

    return False
|
|
@ -0,0 +1,351 @@
|
|||
"""
|
||||
Comparison utilities for STIX pattern comparison expressions.
|
||||
"""
|
||||
import base64
|
||||
import functools
|
||||
from stix2.patterns import (
|
||||
_ComparisonExpression, AndBooleanExpression, OrBooleanExpression,
|
||||
ListObjectPathComponent, IntegerConstant, FloatConstant, StringConstant,
|
||||
BooleanConstant, TimestampConstant, HexConstant, BinaryConstant,
|
||||
ListConstant
|
||||
)
|
||||
from stix2.equivalence.patterns.compare import generic_cmp, iter_lex_cmp
|
||||
|
||||
|
||||
# Canonical ordering of comparison operators.  comparison_operator_cmp()
# orders two operators by their positions in this tuple.
_COMPARISON_OP_ORDER = (
    "=", "!=", "<>", "<", "<=", ">", ">=",
    "IN", "LIKE", "MATCHES", "ISSUBSET", "ISSUPERSET"
)


# Canonical ordering of constant types.  constant_cmp() orders two constants
# of different (non-numeric) types by their positions in this tuple.
_CONSTANT_TYPE_ORDER = (
    # ints/floats come first, but have special handling since the types are
    # treated equally as a generic "number" type.  So they aren't in this list.
    # See constant_cmp().
    StringConstant, BooleanConstant,
    TimestampConstant, HexConstant, BinaryConstant, ListConstant
)
|
||||
|
||||
|
||||
def generic_constant_cmp(const1, const2):
    """
    Default comparator for _Constant instances: compares their "value"
    attributes, which must support the builtin comparison operators.

    :param const1: The first _Constant instance
    :param const2: The second _Constant instance
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    # Unwrap the raw values and defer to the generic comparator.
    raw1 = const1.value
    raw2 = const2.value

    return generic_cmp(raw1, raw2)
|
||||
|
||||
|
||||
def bool_cmp(value1, value2):
    """
    Compare two boolean constants by the truthiness of their wrapped values.

    :param value1: The first BooleanConstant instance
    :param value2: The second BooleanConstant instance
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """

    # unwrap from _Constant instances, reducing each to a plain bool
    truth1 = bool(value1.value)
    truth2 = bool(value2.value)

    if truth1 == truth2:
        return 0

    # The values differ.  Arbitrary but fixed choice: True orders before
    # False.
    return -1 if truth1 else 1
|
||||
|
||||
|
||||
def hex_cmp(value1, value2):
    """
    Compare two STIX "hex" values by the bytes they decode to.  The hex
    string representations themselves are *not* compared.

    :param value1: The first HexConstant
    :param value2: The second HexConstant
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    # Decode in argument order, then compare the resulting byte strings.
    decoded1, decoded2 = (
        bytes.fromhex(constant.value) for constant in (value1, value2)
    )

    return generic_cmp(decoded1, decoded2)
|
||||
|
||||
|
||||
def bin_cmp(value1, value2):
    """
    Compare two STIX "binary" values by the bytes they decode to.  The base64
    string representations themselves are *not* compared.

    :param value1: The first BinaryConstant
    :param value2: The second BinaryConstant
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    # Decode in argument order, then compare the resulting byte strings.
    decoded1, decoded2 = (
        base64.standard_b64decode(constant.value)
        for constant in (value1, value2)
    )

    return generic_cmp(decoded1, decoded2)
|
||||
|
||||
|
||||
def list_cmp(value1, value2):
    """
    Compare two list constants without regard to element order.

    :param value1: The first ListConstant
    :param value2: The second ListConstant
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """

    # Sorting both lists with the same comparator removes any dependence on
    # the original element order.
    sort_key = functools.cmp_to_key(constant_cmp)
    ordered1 = sorted(value1.value, key=sort_key)
    ordered2 = sorted(value2.value, key=sort_key)

    # With both sides sorted, a lexicographical compare decides the order.
    return iter_lex_cmp(ordered1, ordered2, constant_cmp)
|
||||
|
||||
|
||||
# Maps each constant AST class to the comparator that constant_cmp() uses
# when both of its arguments are of that class.
_CONSTANT_COMPARATORS = {
    # We have special handling for ints/floats, so no entries for those AST
    # classes here.  See constant_cmp().
    StringConstant: generic_constant_cmp,
    BooleanConstant: bool_cmp,
    TimestampConstant: generic_constant_cmp,
    HexConstant: hex_cmp,
    BinaryConstant: bin_cmp,
    ListConstant: list_cmp
}
|
||||
|
||||
|
||||
def object_path_component_cmp(comp1, comp2):
    """
    Compare one object path component (a string or int) to another.  This
    defines a total ordering over strings and ints, which is used for
    lexicographical sorting of object paths.

    Two ints, or two strings, compare the usual way; an int always compares
    less than a string.

    :param comp1: An object path component (string or int)
    :param comp2: An object path component (string or int)
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    first_is_int = isinstance(comp1, int)
    both_ints = first_is_int and isinstance(comp2, int)
    both_strs = isinstance(comp1, str) and isinstance(comp2, str)

    # same kind on both sides: use builtin comparison operators
    if both_ints or both_strs:
        return generic_cmp(comp1, comp2)

    # mixed kinds: ints come before strings
    return -1 if first_is_int else 1
|
||||
|
||||
|
||||
def object_path_to_raw_values(path):
    """
    Break the given ObjectPath instance down into plain strings and ints.
    Every property name is produced as-is (whether or not it is a *_ref
    property); a "*" index step is produced as that string; and a numeric
    index step is produced as an integer.

    :param path: An ObjectPath instance
    :return: A generator iterator over the values
    """
    for step in path.property_path:
        has_index = isinstance(step, ListObjectPathComponent)

        # The property name always comes first, indexed or not.
        yield step.property_name

        if not has_index:
            continue

        index = step.index
        if index == "*" or isinstance(index, int):
            yield index
        else:
            # in case the index is a stringified int; convert to an actual
            # int
            yield int(index)
|
||||
|
||||
|
||||
def object_path_cmp(path1, path2):
    """
    Compare two object paths: first by object type name, then
    lexicographically by their path steps.

    :param path1: The first ObjectPath instance
    :param path2: The second ObjectPath instance
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    type_name1 = path1.object_type_name
    type_name2 = path2.object_type_name

    if type_name1 < type_name2:
        return -1

    if type_name1 > type_name2:
        return 1

    # Same object type.  I always thought of key and index path steps as
    # separate.  The AST lumps indices in with the previous key as a single
    # path component.  The following splits the path components into
    # individual comparable values again.  Maybe I should not do this...
    return iter_lex_cmp(
        object_path_to_raw_values(path1),
        object_path_to_raw_values(path2),
        object_path_component_cmp
    )
|
||||
|
||||
|
||||
def comparison_operator_cmp(op1, op2):
    """
    Compare two comparison operators according to their positions in the
    canonical operator ordering.

    :param op1: The first comparison operator (a string)
    :param op2: The second comparison operator (a string)
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    # Look up each operator's rank; the ranks decide the order.
    rank1, rank2 = (
        _COMPARISON_OP_ORDER.index(operator) for operator in (op1, op2)
    )

    return generic_cmp(rank1, rank2)
|
||||
|
||||
|
||||
def constant_cmp(value1, value2):
    """
    Compare two constants.  Numbers (ints and floats alike) order before all
    other constant types; other types order by their position in the
    canonical type ordering, and same-typed constants by a type-specific
    comparator.

    :param value1: The first _Constant instance
    :param value2: The second _Constant instance
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """

    # Special handling for ints/floats: treat them generically as numbers,
    # ordered before all other types.
    number_types = (IntegerConstant, FloatConstant)
    first_is_number = isinstance(value1, number_types)
    second_is_number = isinstance(value2, number_types)

    if first_is_number and second_is_number:
        return generic_constant_cmp(value1, value2)

    if first_is_number:
        return -1

    if second_is_number:
        return 1

    # Neither is a number: order by type first.
    kind1 = type(value1)
    kind2 = type(value2)

    rank1 = _CONSTANT_TYPE_ORDER.index(kind1)
    rank2 = _CONSTANT_TYPE_ORDER.index(kind2)

    order = generic_cmp(rank1, rank2)
    if order != 0:
        return order

    # Types are the same; must compare values
    comparator = _CONSTANT_COMPARATORS.get(kind1)
    if not comparator:
        raise TypeError("Don't know how to compare " + kind1.__name__)

    return comparator(value1, value2)
|
||||
|
||||
|
||||
def simple_comparison_expression_cmp(expr1, expr2):
    """
    Compare "simple" comparison expressions, i.e. plain
    <path> <op> <value> comparisons rather than AND/OR combinations.  The
    order is decided by object path, then operator, then negation, then
    constant value.

    :param expr1: first _ComparisonExpression instance
    :param expr2: second _ComparisonExpression instance
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    order = object_path_cmp(expr1.lhs, expr2.lhs)
    if order != 0:
        return order

    order = comparison_operator_cmp(expr1.operator, expr2.operator)
    if order != 0:
        return order

    # _ComparisonExpression's have a "negated" attribute.  Umm...
    # non-negated < negated?
    if not expr1.negated and expr2.negated:
        return -1

    if expr1.negated and not expr2.negated:
        return 1

    # Everything else ties; the right-hand constants decide.
    return constant_cmp(expr1.rhs, expr2.rhs)
|
||||
|
||||
|
||||
def comparison_expression_cmp(expr1, expr2):
    """
    Compare two comparison expressions.  The result depends on the order of
    the expressions' sub-components; for an order-insensitive comparison, the
    ASTs must be canonically ordered beforehand.

    :param expr1: The first comparison expression
    :param expr2: The second comparison expression
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    first_is_simple = isinstance(expr1, _ComparisonExpression)
    second_is_simple = isinstance(expr2, _ComparisonExpression)

    if first_is_simple and second_is_simple:
        return simple_comparison_expression_cmp(expr1, expr2)

    # One is simple, one is compound.  Let's say... simple ones come first?
    if first_is_simple:
        return -1

    if second_is_simple:
        return 1

    # Both are compound: AND's before OR's?
    if isinstance(expr1, AndBooleanExpression) \
            and isinstance(expr2, OrBooleanExpression):
        return -1

    if isinstance(expr1, OrBooleanExpression) \
            and isinstance(expr2, AndBooleanExpression):
        return 1

    # Both compound, same boolean operator: sort according to contents.
    # This will order according to recursive invocations of this comparator,
    # on sub-expressions.
    return iter_lex_cmp(
        expr1.operands, expr2.operands, comparison_expression_cmp
    )
|
|
@ -0,0 +1,124 @@
|
|||
"""
|
||||
Comparison utilities for STIX pattern observation expressions.
|
||||
"""
|
||||
from stix2.equivalence.patterns.compare import generic_cmp, iter_lex_cmp
|
||||
from stix2.equivalence.patterns.compare.comparison import (
|
||||
comparison_expression_cmp, generic_constant_cmp
|
||||
)
|
||||
from stix2.patterns import (
|
||||
ObservationExpression, AndObservationExpression, OrObservationExpression,
|
||||
QualifiedObservationExpression, _CompoundObservationExpression,
|
||||
RepeatQualifier, WithinQualifier, StartStopQualifier,
|
||||
FollowedByObservationExpression
|
||||
)
|
||||
|
||||
|
||||
# Canonical ordering of observation expression AST classes.
# observation_expression_cmp() orders expressions of different classes by
# their positions in this tuple.
_OBSERVATION_EXPRESSION_TYPE_ORDER = (
    ObservationExpression, AndObservationExpression, OrObservationExpression,
    FollowedByObservationExpression, QualifiedObservationExpression
)


# Canonical ordering of qualifier classes.  observation_expression_cmp()
# orders qualified expressions with different qualifier classes by their
# positions in this tuple.
_QUALIFIER_TYPE_ORDER = (
    RepeatQualifier, WithinQualifier, StartStopQualifier
)
|
||||
|
||||
|
||||
def repeats_cmp(qual1, qual2):
    """
    Compare REPEATS qualifiers, ordering them by their repeat counts.

    :param qual1: The first RepeatQualifier
    :param qual2: The second RepeatQualifier
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    count1 = qual1.times_to_repeat
    count2 = qual2.times_to_repeat

    return generic_constant_cmp(count1, count2)
|
||||
|
||||
|
||||
def within_cmp(qual1, qual2):
    """
    Compare WITHIN qualifiers, ordering them by their number of seconds.

    :param qual1: The first WithinQualifier
    :param qual2: The second WithinQualifier
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    seconds1 = qual1.number_of_seconds
    seconds2 = qual2.number_of_seconds

    return generic_constant_cmp(seconds1, seconds2)
|
||||
|
||||
|
||||
def startstop_cmp(qual1, qual2):
    """
    Compare START/STOP qualifiers.  The ordering is lexicographical: start
    time first, then stop time.

    :param qual1: The first StartStopQualifier
    :param qual2: The second StartStopQualifier
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    # Treat each qualifier as a (start, stop) pair and compare the pairs.
    interval1 = (qual1.start_time, qual1.stop_time)
    interval2 = (qual2.start_time, qual2.stop_time)

    return iter_lex_cmp(interval1, interval2, generic_constant_cmp)
|
||||
|
||||
|
||||
# Maps each qualifier class to the comparator that
# observation_expression_cmp() uses when both qualifiers are of that class.
_QUALIFIER_COMPARATORS = {
    RepeatQualifier: repeats_cmp,
    WithinQualifier: within_cmp,
    StartStopQualifier: startstop_cmp
}
|
||||
|
||||
|
||||
def observation_expression_cmp(expr1, expr2):
    """
    Compare two observation expression ASTs.  The result depends on the order
    of the expressions' sub-components; for an order-insensitive comparison,
    the ASTs must be canonically ordered beforehand.

    :param expr1: The first observation expression
    :param expr2: The second observation expression
    :return: <0, 0, or >0 depending on whether the first arg is less, equal or
        greater than the second
    """
    kind1 = type(expr1)
    kind2 = type(expr2)

    rank1 = _OBSERVATION_EXPRESSION_TYPE_ORDER.index(kind1)
    rank2 = _OBSERVATION_EXPRESSION_TYPE_ORDER.index(kind2)

    # Different expression types: the type ordering decides.
    if rank1 != rank2:
        return generic_cmp(rank1, rank2)

    # From here on, both exprs are of same type.

    # If they're simple, use contained comparison expression order
    if kind1 is ObservationExpression:
        return comparison_expression_cmp(expr1.operand, expr2.operand)

    # Both compound, and of same type (and/or/followedby): sort according
    # to contents.
    if isinstance(expr1, _CompoundObservationExpression):
        return iter_lex_cmp(
            expr1.operands, expr2.operands, observation_expression_cmp
        )

    # Remaining case: QualifiedObservationExpression.  Check qualifiers
    # first; if they are the same, use order of the qualified expressions.
    qualifier1 = expr1.qualifier
    qualifier2 = expr2.qualifier
    qualifier_kind1 = type(qualifier1)
    qualifier_kind2 = type(qualifier2)

    qualifier_rank1 = _QUALIFIER_TYPE_ORDER.index(qualifier_kind1)
    qualifier_rank2 = _QUALIFIER_TYPE_ORDER.index(qualifier_kind2)

    order = generic_cmp(qualifier_rank1, qualifier_rank2)
    if order != 0:
        return order

    # Same qualifier type; compare qualifier details
    qualifier_comparator = _QUALIFIER_COMPARATORS.get(qualifier_kind1)
    if not qualifier_comparator:
        raise TypeError(
            "Can't compare qualifier type: " + qualifier_kind1.__name__
        )

    order = qualifier_comparator(qualifier1, qualifier2)
    if order != 0:
        return order

    # Same qualifier type and details; use qualified expression order
    return observation_expression_cmp(
        expr1.observation_expression, expr2.observation_expression
    )
|
|
@ -0,0 +1,56 @@
|
|||
"""
|
||||
Generic AST transformation classes.
|
||||
"""
|
||||
|
||||
class Transformer:
    """
    Base class for AST transformers.  Subclasses must override transform().
    """
    def transform(self, ast):
        """
        Transform the given AST and return the resulting AST.

        :param ast: The AST to transform
        :return: A 2-tuple: the transformed AST and a boolean indicating whether
            the transformation actually changed anything.  The change detection
            is useful in situations where a transformation needs to be repeated
            until the AST stops changing.
        :raises NotImplementedError: Always, in this base class
        """
        # NotImplementedError is the exception class for unimplemented
        # methods; NotImplemented is a non-callable constant, so calling it
        # would raise a confusing TypeError instead.
        raise NotImplementedError("transform")
|
||||
|
||||
|
||||
class ChainTransformer(Transformer):
    """
    A composite transformer built from a sequence of sub-transformers.
    Applying it runs every sub-transformer, in the order given, each one
    working on the output of the previous one.
    """
    def __init__(self, *transformers):
        self.__transformers = transformers

    def transform(self, ast):
        """
        Run the AST through every sub-transformer in sequence.

        :param ast: The AST to transform
        :return: A 2-tuple: the final AST, and True if any sub-transformer
            reported a change (False otherwise)
        """
        any_changed = False

        for step in self.__transformers:
            ast, step_changed = step.transform(ast)
            # Remember a change reported by any step in the chain.
            any_changed = any_changed or bool(step_changed)

        return ast, any_changed
|
||||
|
||||
|
||||
class SettleTransformer(Transformer):
    """
    A transformer which keeps re-applying a wrapped transformation until that
    transformation reports no further change to the AST, i.e. until the AST
    has "settled".
    """
    def __init__(self, transform):
        self.__transformer = transform

    def transform(self, ast):
        """
        Apply the wrapped transformer repeatedly until it reports no change.

        :param ast: The AST to transform
        :return: A 2-tuple: the settled AST, and True if at least one
            application reported a change (False otherwise)
        """
        changing_rounds = 0

        while True:
            ast, round_changed = self.__transformer.transform(ast)
            if not round_changed:
                # Settled: the latest application changed nothing.
                break

            changing_rounds += 1

        return ast, changing_rounds > 0
|
|
@ -0,0 +1,331 @@
|
|||
"""
|
||||
Transformation utilities for STIX pattern comparison expressions.
|
||||
"""
|
||||
import functools
|
||||
import itertools
|
||||
from stix2.equivalence.patterns.transform import Transformer
|
||||
from stix2.patterns import (
|
||||
_BooleanExpression, _ComparisonExpression, AndBooleanExpression,
|
||||
OrBooleanExpression, ParentheticalExpression
|
||||
)
|
||||
from stix2.equivalence.patterns.compare.comparison import (
|
||||
comparison_expression_cmp
|
||||
)
|
||||
from stix2.equivalence.patterns.compare import iter_lex_cmp, iter_in
|
||||
|
||||
|
||||
def _dupe_ast(ast):
    """
    Make a copy of the given comparison expression AST.

    Note: only the AND/OR nodes are copied.  The comparison expression
    "leaves", i.e. simple <path> <op> <value> comparisons, are shared with the
    original.  I don't think copying them is necessary as of this writing;
    they are never changed.  But revisit this if/when necessary.

    :param ast: The AST to duplicate
    :return: The duplicate AST
    :raises TypeError: If a node is neither an AND/OR expression nor a simple
        comparison expression
    """
    # AND/OR nodes: build a fresh node of the same kind around recursively
    # duplicated operands.
    for compound_type in (AndBooleanExpression, OrBooleanExpression):
        if isinstance(ast, compound_type):
            return compound_type([
                _dupe_ast(child) for child in ast.operands
            ])

    if isinstance(ast, _ComparisonExpression):
        # Change this to create a dupe, if we ever need to change simple
        # comparison expressions as part of canonicalization.
        return ast

    raise TypeError("Can't duplicate " + type(ast).__name__)
|
||||
|
||||
|
||||
class ComparisonExpressionTransformer(Transformer):
    """
    Base class for transformers of comparison expressions.  Its transform()
    method walks the AST bottom-up, modifying it in place, and hands each
    AND/OR node to a callback method once that node's children are done.

    Subclasses can implement these callbacks:
        "transform_or" for OR nodes
        "transform_and" for AND nodes
        "transform_default" for both types of nodes

    A type-specific callback, if present, is preferred.  "transform_default"
    is the fallback when there is none, and its implementation here leaves
    the AST alone.

    Every callback receives the AST of a subtree rooted at a node of the
    appropriate type, whose children have already been transformed.  It must
    return what the base transform() method returns: a 2-tuple of the
    transformed AST and a change-detection boolean.  See doc for the
    superclass' method.

    This process currently silently drops parenthetical nodes, and "leaf"
    comparison expression nodes are left unchanged.
    """

    def transform(self, ast):
        if isinstance(ast, _BooleanExpression):
            children_changed = False

            # Bottom-up: transform the children (in place) before the node.
            for position, child in enumerate(ast.operands):
                new_child, child_changed = self.transform(child)
                ast.operands[position] = new_child
                if child_changed:
                    children_changed = True

            new_ast, node_changed = self.__dispatch_transform(ast)

            return new_ast, bool(children_changed or node_changed)

        if isinstance(ast, _ComparisonExpression):
            # Terminates recursion; we don't change these nodes
            return ast, False

        if isinstance(ast, ParentheticalExpression):
            # Drop these: continue with the wrapped expression
            return self.transform(ast.expression)

        raise TypeError("Not a comparison expression: " + str(ast))

    def __dispatch_transform(self, ast):
        """
        Call the transformer callback method which corresponds to the type of
        the given ast's root node.

        :param ast: The AST
        :return: The callback's result
        """
        if isinstance(ast, AndBooleanExpression):
            callback_name = "transform_and"
        elif isinstance(ast, OrBooleanExpression):
            callback_name = "transform_or"
        else:
            callback_name = None

        fallback = self.transform_default
        if callback_name is None:
            callback = fallback
        else:
            # Use the type-specific callback only if the subclass has one.
            callback = getattr(self, callback_name, fallback)

        return callback(ast)

    def transform_default(self, ast):
        """
        Override to handle transforming AST nodes which don't have a more
        specific method implemented.  This implementation returns the AST
        as-is and reports no change.
        """
        return ast, False
|
||||
|
||||
|
||||
class OrderDedupeTransformer(
    ComparisonExpressionTransformer
):
    """
    Canonically order the children of all nodes in the AST.  Because the
    deduping algorithm is based on sorted data, this transformation also does
    deduping.

    E.g.:
        A and A => A
        A or A => A
    """

    def transform_default(self, ast):
        """
        Sort/dedupe children.  AND and OR can be treated identically.

        :param ast: The comparison expression AST
        :return: A 2-tuple: the same AST node, but with sorted and deduped
            children; and a boolean which is True if the children changed
        """
        # Build the key function once; sorting and grouping both use it.
        sort_key = functools.cmp_to_key(comparison_expression_cmp)

        sorted_children = sorted(ast.operands, key=sort_key)

        # groupby() collects runs of children which compare equal.  Keep the
        # first member of each run.  Taking it from the group iterator avoids
        # reaching into the key wrapper objects, whose attributes are not part
        # of the documented cmp_to_key() interface.
        deduped_children = [
            next(equal_run)
            for _, equal_run in itertools.groupby(
                sorted_children, key=sort_key
            )
        ]

        # A change occurred if the new child sequence differs from the old
        # one in order or in length.
        changed = iter_lex_cmp(
            ast.operands, deduped_children, comparison_expression_cmp
        ) != 0

        ast.operands = deduped_children

        return ast, changed
|
||||
|
||||
|
||||
class FlattenTransformer(ComparisonExpressionTransformer):
    """
    Flatten all nodes of the AST.  E.g.:

        A and (B and C) => A and B and C
        A or (B or C) => A or B or C
        (A) => A
    """

    def transform_default(self, ast):
        """
        Flatten children.  AND and OR can be treated mostly identically.  The
        little difference is that an AND node can only absorb AND children,
        and an OR node only OR children.

        :param ast: The comparison expression AST
        :return: A 2-tuple: the flattened AST (the node's only child, if it
            had just one; else the same node with flattened children), and a
            boolean which is True if anything changed
        """
        if isinstance(ast, _BooleanExpression) and len(ast.operands) == 1:
            # Replace an AND/OR with one child, with the child itself.
            return ast.operands[0], True

        merged_operands = []
        changed = False

        for child in ast.operands:
            same_operator_child = (
                isinstance(child, _BooleanExpression)
                and ast.operator == child.operator
            )

            if same_operator_child:
                # Hoist the grandchildren up into this node.
                merged_operands.extend(child.operands)
                changed = True
            else:
                merged_operands.append(child)

        ast.operands = merged_operands

        return ast, changed
|
||||
|
||||
|
||||
class AbsorptionTransformer(
    ComparisonExpressionTransformer
):
    """
    Simplifies the AST by applying the boolean "absorption" rules.  E.g.:

        A and (A or B) = A
        A or (A and B) = A
    """

    def transform_default(self, ast):
        """
        Delete children of the given node which are absorbed by one of their
        siblings.

        :param ast: The comparison expression AST
        :return: A 2-tuple: the same AST node, minus any absorbed children;
            and a boolean which is True if any child was deleted
        """
        if not isinstance(ast, _BooleanExpression):
            return ast, False

        # An absorbable child must use the "other" boolean operator.
        absorbable_op = "AND" if ast.operator == "OR" else "OR"

        doomed = set()

        # Check each absorber against each candidate to see if we can delete
        # the candidate.
        for absorber_idx, absorber in enumerate(ast.operands):
            if absorber_idx in doomed:
                continue

            for candidate_idx, candidate in enumerate(ast.operands):
                if absorber_idx == candidate_idx or candidate_idx in doomed:
                    continue

                # We're checking if the absorber is contained in the
                # candidate, so the candidate has to be a compound object,
                # not just a simple comparison expression.  We also require
                # the right operator for it: "AND" if ast is "OR" and vice
                # versa.
                if not isinstance(candidate, _BooleanExpression) \
                        or candidate.operator != absorbable_op:
                    continue

                # The simple check: is the absorber contained in the
                # candidate?
                if iter_in(
                    absorber, candidate.operands, comparison_expression_cmp
                ):
                    doomed.add(candidate_idx)

                # A more complicated check: does the absorber occur in the
                # candidate in a "flattened" form?
                elif absorber.operator == candidate.operator:
                    if all(
                        iter_in(
                            absorber_operand, candidate.operands,
                            comparison_expression_cmp
                        )
                        for absorber_operand in absorber.operands
                    ):
                        doomed.add(candidate_idx)

        changed = bool(doomed)

        # Delete from the highest index down, so that the remaining indices
        # stay valid.
        for idx in sorted(doomed, reverse=True):
            del ast.operands[idx]

        return ast, changed
|
||||
|
||||
|
||||
class DNFTransformer(ComparisonExpressionTransformer):
    """
    Rewrite a comparison expression AST into disjunctive normal form by
    distributing AND over OR.  E.g.:

        A and (B or C) => (A and B) or (A and C)
    """

    def transform_and(self, ast):
        """
        Distribute an AND node over any OR children it has.

        :param ast: The AND node
        :return: 2-tuple: the resulting AST (a new OR node if distribution
            happened, otherwise the given node), and a change flag
        """
        # Partition the children: operand lists of the ORs (ready to be fed
        # to itertools.product), and all remaining children.
        or_operand_lists = []
        plain_children = []
        for child in ast.operands:
            if isinstance(child, _BooleanExpression) \
                    and child.operator == "OR":
                or_operand_lists.append(child.operands)
            else:
                plain_children.append(child)

        if not or_operand_lists:
            # No AND-over-OR here; leave the node alone.
            return ast, False

        # One new AND per way of picking a single operand from every OR.
        # Distribution repeats sub-expressions, so each occurrence is
        # duplicated to keep the repetitions independent of one another.
        conjunctions = []
        for picks in itertools.product(*or_operand_lists):
            terms = [
                _dupe_ast(term)
                for term in itertools.chain(plain_children, picks)
            ]
            conjunctions.append(AndBooleanExpression(terms))

        # The new ANDs may themselves still contain ORs needing distribution,
        # so run the transform over each of them.  This is downward recursion
        # from within a bottom-up traversal, which is not ideal for
        # performance; a purely top-down approach might be worth considering.
        disjuncts = [
            self.transform(conjunction)[0] for conjunction in conjunctions
        ]

        return OrBooleanExpression(disjuncts), True
|
|
@ -0,0 +1,486 @@
|
|||
"""
|
||||
Transformation utilities for STIX pattern observation expressions.
|
||||
"""
|
||||
import functools
|
||||
import itertools
|
||||
from stix2.patterns import (
|
||||
ObservationExpression, AndObservationExpression, OrObservationExpression,
|
||||
QualifiedObservationExpression, _CompoundObservationExpression,
|
||||
ParentheticalExpression, FollowedByObservationExpression
|
||||
)
|
||||
from stix2.equivalence.patterns.transform import (
|
||||
ChainTransformer, SettleTransformer, Transformer
|
||||
)
|
||||
from stix2.equivalence.patterns.transform.comparison import (
|
||||
FlattenTransformer as CFlattenTransformer,
|
||||
OrderDedupeTransformer as COrderDedupeTransformer,
|
||||
AbsorptionTransformer as CAbsorptionTransformer,
|
||||
DNFTransformer as CDNFTransformer
|
||||
)
|
||||
from stix2.equivalence.patterns.compare import iter_lex_cmp, iter_in
|
||||
from stix2.equivalence.patterns.compare.observation import observation_expression_cmp
|
||||
|
||||
|
||||
def _dupe_ast(ast):
    """
    Make a copy of an observation expression AST.  The root must be some kind
    of observation expression (AND/OR/FOLLOWEDBY/qualified/simple).

    Note: "leaf" observation expressions (the simple square-bracket kind) are
    returned as-is rather than copied, and qualifier objects are shared
    between the original and the copy.  Revisit if that ever becomes a
    problem.

    :param ast: The AST to duplicate
    :return: The duplicate AST
    :raises TypeError: If ast is not a supported observation expression type
    """
    # The compound node types are all rebuilt the same way: a new node of the
    # same kind wrapping duplicated operands.
    compound_types = (
        AndObservationExpression,
        OrObservationExpression,
        FollowedByObservationExpression,
    )
    for compound_type in compound_types:
        if isinstance(ast, compound_type):
            return compound_type([
                _dupe_ast(operand) for operand in ast.operands
            ])

    if isinstance(ast, QualifiedObservationExpression):
        # The qualifier object itself is not duplicated.
        return QualifiedObservationExpression(
            _dupe_ast(ast.observation_expression), ast.qualifier
        )

    if isinstance(ast, ObservationExpression):
        # Leaf node: shared, not copied.
        return ast

    raise TypeError("Can't duplicate " + type(ast).__name__)
|
||||
|
||||
|
||||
class ObservationExpressionTransformer(Transformer):
    """
    Base class for transformers which operate on observation expressions.
    transform() walks the AST bottom-up, modifying it in place, and invokes
    node-type-specific callbacks along the way.  The walk stops at "leaf"
    observation expressions: it never descends into the comparison
    expression they hold.

    Subclasses may define any of these callbacks:

        "transform_or" for OR nodes
        "transform_and" for AND nodes
        "transform_followedby" for FOLLOWEDBY nodes
        "transform_qualified" for qualified nodes (all qualifier types)
        "transform_observation" for "leaf" observation expression nodes
        "transform_default" for all types of nodes

    A type-specific callback wins over "transform_default" when both exist;
    "transform_default" is only the fallback, and its implementation here
    leaves the AST untouched.

    Every callback receives the subtree rooted at a node of the matching
    type, after that node's children have been transformed.  It must return
    what transform() returns: a 2-tuple of the transformed AST and a boolean
    indicating whether anything changed.

    Parenthetical nodes are silently dropped by this process.
    """

    # Maps AST node types to the suffix of their callback method name
    _DISPATCH_NAME_MAP = {
        ObservationExpression: "observation",
        AndObservationExpression: "and",
        OrObservationExpression: "or",
        FollowedByObservationExpression: "followedby",
        QualifiedObservationExpression: "qualified"
    }

    def transform(self, ast):
        """
        Transform the given observation expression AST, bottom-up.

        :param ast: The AST to transform
        :return: 2-tuple: the transformed AST, and True if anything changed
            (False otherwise)
        :raises TypeError: If ast is not an observation expression node
        """
        if isinstance(ast, ObservationExpression):
            # Leaf node: nothing beneath it for us to recurse into.
            new_ast, node_changed = self.__dispatch_transform(ast)
            return new_ast, bool(node_changed)

        if isinstance(ast, _CompoundObservationExpression):
            subtree_changed = False
            for idx, operand in enumerate(ast.operands):
                new_operand, operand_changed = self.transform(operand)
                if operand_changed:
                    ast.operands[idx] = new_operand
                    subtree_changed = True

            new_ast, node_changed = self.__dispatch_transform(ast)
            return new_ast, subtree_changed or bool(node_changed)

        if isinstance(ast, QualifiedObservationExpression):
            # Only the qualified expression is transformed; the qualifier
            # itself is left alone.
            subtree_changed = False
            new_inner, inner_changed = self.transform(
                ast.observation_expression
            )
            if inner_changed:
                ast.observation_expression = new_inner
                subtree_changed = True

            new_ast, node_changed = self.__dispatch_transform(ast)
            return new_ast, subtree_changed or bool(node_changed)

        if isinstance(ast, ParentheticalExpression):
            # Replace the parenthetical with its (transformed) content.
            # Dropping the node always counts as a change.
            new_ast, _ = self.transform(ast.expression)
            return new_ast, True

        raise TypeError("Not an observation expression: {}: {}".format(
            type(ast).__name__, str(ast)
        ))

    def __dispatch_transform(self, ast):
        """
        Call the callback method appropriate for the type of the AST's root
        node, falling back to transform_default.

        :param ast: The AST
        :return: The callback's result
        """
        handler = self.transform_default

        suffix = self._DISPATCH_NAME_MAP.get(type(ast))
        if suffix:
            handler = getattr(self, "transform_" + suffix, handler)

        return handler(ast)

    def transform_default(self, ast):
        """
        Fallback callback: returns the AST unmodified, with a False change
        flag.
        """
        return ast, False
|
||||
|
||||
|
||||
class FlattenTransformer(ObservationExpressionTransformer):
    """
    Remove needless nesting from an observation expression AST.  E.g.:

        A and (B and C) => A and B and C
        A or (B or C) => A or B or C
        A followedby (B followedby C) => A followedby B followedby C
        (A) => A
    """

    def __transform(self, ast):
        """
        Flatten one AND/OR/FOLLOWEDBY node.

        :param ast: The compound node
        :return: 2-tuple: the resulting AST and a change flag
        """
        # A compound node with a lone child is replaced by that child.
        if len(ast.operands) == 1:
            return ast.operands[0], True

        was_flattened = False
        merged = []
        for operand in ast.operands:
            same_kind = (
                isinstance(operand, _CompoundObservationExpression)
                and ast.operator == operand.operator
            )
            if same_kind:
                # Pull a same-operator child's operands up into this node.
                merged.extend(operand.operands)
                was_flattened = True
            else:
                merged.append(operand)

        ast.operands = merged

        return ast, was_flattened

    def transform_and(self, ast):
        return self.__transform(ast)

    def transform_or(self, ast):
        return self.__transform(ast)

    def transform_followedby(self, ast):
        return self.__transform(ast)
|
||||
|
||||
|
||||
class OrderDedupeTransformer(
    ObservationExpressionTransformer
):
    """
    Put the operands of AND/OR expressions into a canonical order, and
    remove duplicate operands from ORs.  E.g.:

        A or A => A
        B or A => A or B
        B and A => A and B
    """

    def __transform(self, ast):
        """
        Sort (and for OR, dedupe) the operands of one AND/OR node.

        :param ast: The AND or OR node (modified in place)
        :return: 2-tuple: the AST and a change flag
        """
        ordered = list(ast.operands)
        ordered.sort(key=functools.cmp_to_key(observation_expression_cmp))

        if ast.operator == "OR":
            # Only ORs are deduped.  Since the operands are sorted, equal
            # ones are adjacent: keep the first operand of each run of
            # equals.
            new_operands = []
            for operand in ordered:
                is_repeat = bool(new_operands) and observation_expression_cmp(
                    new_operands[-1], operand
                ) == 0
                if not is_repeat:
                    new_operands.append(operand)
        else:
            new_operands = ordered

        # Report a change only if the operand sequence actually differs.
        changed = iter_lex_cmp(
            ast.operands, new_operands, observation_expression_cmp
        ) != 0

        ast.operands = new_operands

        return ast, changed

    def transform_and(self, ast):
        return self.__transform(ast)

    def transform_or(self, ast):
        return self.__transform(ast)
|
||||
|
||||
|
||||
class AbsorptionTransformer(
    ObservationExpressionTransformer
):
    """
    Applies boolean "absorption" rules for observation expressions, for AST
    simplification:

        A or (A and B) = A
        A or (A followedby B) = A

    Other variants do not hold for observation expressions.
    """

    def __is_contained_and(self, exprs_containee, exprs_container):
        """
        Determine whether the "containee" expressions are contained in the
        "container" expressions, with AND semantics (order-independent but need
        distinct bindings).  For example (with containee on left and container
        on right):

            (A and A and B) or (A and B and C)

        In the above, all of the lhs vars have a counterpart in the rhs, but
        there are two A's on the left and only one on the right.  Therefore,
        the right does not "contain" the left.  You would need two A's on the
        right.

        :param exprs_containee: The expressions we want to check for containment
        :param exprs_container: The expressions acting as the "container"
        :return: True if the containee is contained in the container; False if
            not
        """
        # Work on a private copy, so matched expressions can be removed
        # without affecting the caller's sequence.
        unmatched = list(exprs_container)

        for ee in exprs_containee:
            for i, er in enumerate(unmatched):
                if observation_expression_cmp(ee, er) == 0:
                    # Consume the match so one container expression is never
                    # paired with two different containee expressions.
                    del unmatched[i]
                    break
            else:
                # No counterpart left in the container for this expression.
                return False

        return True

    def __is_contained_followedby(self, exprs_containee, exprs_container):
        """
        Determine whether the "containee" expressions are contained in the
        "container" expressions, with FOLLOWEDBY semantics (order-sensitive and
        need distinct bindings).  For example (with containee on left and
        container on right):

            (A followedby B) or (B followedby A)

        In the above, all of the lhs vars have a counterpart in the rhs, but
        the vars on the right are not in the same order.  Therefore, the right
        does not "contain" the left.  The container vars don't have to be
        contiguous though.  E.g. in:

            (A followedby B) or (D followedby A followedby C followedby B)

        in the container (rhs), B follows A, so it "contains" the lhs even
        though there is other stuff mixed in.

        :param exprs_containee: The expressions we want to check for containment
        :param exprs_container: The expressions acting as the "container"
        :return: True if the containee is contained in the container; False if
            not
        """
        # Ordered-subsequence test.  A single iterator over the container is
        # shared by all of the searches below, so each search resumes where
        # the previous match left off: that enforces both the ordering and
        # the distinct-bindings requirements.  End of input is detected by
        # the iterators themselves, never by the truthiness of an expression
        # object.
        er_iter = iter(exprs_container)

        return all(
            any(observation_expression_cmp(ee, er) == 0 for er in er_iter)
            for ee in exprs_containee
        )

    def transform_or(self, ast):
        """
        Remove operands of an OR node which are absorbed by a sibling operand.

        :param ast: The OR node (modified in place)
        :return: 2-tuple: the AST, and whether any operand was removed
        """
        to_delete = set()

        # Check i (child1) against j (child2) to see if we can delete j.
        for i, child1 in enumerate(ast.operands):
            if i in to_delete:
                continue

            # The simplification doesn't work across qualifiers
            if isinstance(child1, QualifiedObservationExpression):
                continue

            for j, child2 in enumerate(ast.operands):
                if i == j or j in to_delete:
                    continue

                # Only AND and FOLLOWEDBY nodes can be absorbed.
                if not isinstance(
                    child2, (
                        AndObservationExpression,
                        FollowedByObservationExpression
                    )
                ):
                    continue

                # The simple check: is child1 contained in child2?
                if iter_in(
                    child1, child2.operands, observation_expression_cmp
                ):
                    to_delete.add(j)

                # A more complicated check: does child1 occur in child2
                # in a "flattened" form?
                elif type(child1) is type(child2):
                    if isinstance(child1, AndObservationExpression):
                        can_simplify = self.__is_contained_and(
                            child1.operands, child2.operands
                        )
                    else:  # child1 and 2 are followedby nodes
                        can_simplify = self.__is_contained_followedby(
                            child1.operands, child2.operands
                        )

                    if can_simplify:
                        to_delete.add(j)

        # Delete from the back so the remaining indices stay valid.
        for i in sorted(to_delete, reverse=True):
            del ast.operands[i]

        return ast, bool(to_delete)
|
||||
|
||||
|
||||
class DNFTransformer(ObservationExpressionTransformer):
    """
    Transform an observation expression to DNF.  This will distribute AND and
    FOLLOWEDBY over OR:

        A and (B or C) => (A and B) or (A and C)
        A followedby (B or C) => (A followedby B) or (A followedby C)
    """

    def __transform(self, ast):
        """
        Distribute one AND or FOLLOWEDBY node over any OR children it has.

        :param ast: The AND or FOLLOWEDBY node
        :return: 2-tuple: the resulting AST (a new OR node if distribution
            happened, otherwise the given node), and a change flag
        """
        has_or_child = any(
            isinstance(child, OrObservationExpression)
            for child in ast.operands
        )
        if not has_or_child:
            # Nothing to distribute over.
            return ast, False

        root_type = type(ast)  # will be AST class for AND or FOLLOWEDBY

        # For each operand position, the alternatives which may appear there:
        # every operand of an OR child, or just the child itself otherwise.
        # Keeping one entry per original position is essential: FOLLOWEDBY is
        # order-sensitive, so the distributed terms must keep the original
        # operand order, e.g.
        #
        #     (A or B) followedby C => (A followedby C) or (B followedby C)
        #
        # AND doesn't care about order, so the same code serves both.
        alternatives = [
            child.operands if isinstance(child, OrObservationExpression)
            else (child,)
            for child in ast.operands
        ]

        # Distribution repeats sub-expressions; duplicate each occurrence so
        # the repetitions are independent of each other.
        distributed_children = [
            root_type([_dupe_ast(sub_ast) for sub_ast in prod_seq])
            for prod_seq in itertools.product(*alternatives)
        ]

        # Need to recursively continue to distribute AND/FOLLOWEDBY over OR
        # in any of our new sub-expressions which need it.
        distributed_children = [
            self.transform(child)[0] for child in distributed_children
        ]

        return OrObservationExpression(distributed_children), True

    def transform_and(self, ast):
        return self.__transform(ast)

    def transform_followedby(self, ast):
        return self.__transform(ast)
|
||||
|
||||
|
||||
class CanonicalizeComparisonExpressionsTransformer(
    ObservationExpressionTransformer
):
    """
    Canonicalize the comparison expression inside every "leaf" observation
    expression.
    """

    def __init__(self):
        # Simplification pass: flatten, order/dedupe, absorb; wrapped in a
        # SettleTransformer.
        simplify = ChainTransformer(
            CFlattenTransformer(),
            COrderDedupeTransformer(),
            CAbsorptionTransformer()
        )
        settle_simplify = SettleTransformer(simplify)

        # Full canonicalization: simplify, convert to DNF, then simplify
        # again.
        self.__comp_canonicalize = ChainTransformer(
            settle_simplify, CDNFTransformer(), settle_simplify
        )

    def transform_observation(self, ast):
        """
        Replace the observation expression's comparison expression with its
        canonical form.

        :param ast: The observation expression node (modified in place)
        :return: 2-tuple: the AST, and the change flag reported by the
            comparison expression canonicalizer
        """
        canonical, changed = self.__comp_canonicalize.transform(ast.operand)
        ast.operand = canonical

        return ast, changed
|
Loading…
Reference in New Issue