diff --git a/stix2/equivalence/pattern/transform/comparison.py b/stix2/equivalence/pattern/transform/comparison.py index 93a80e4..9da91ef 100644 --- a/stix2/equivalence/pattern/transform/comparison.py +++ b/stix2/equivalence/pattern/transform/comparison.py @@ -13,8 +13,8 @@ from stix2.equivalence.pattern.transform.specials import ( ipv4_addr, ipv6_addr, windows_reg_key, ) from stix2.patterns import ( - AndBooleanExpression, OrBooleanExpression, ParentheticalExpression, - _BooleanExpression, _ComparisonExpression, + AndBooleanExpression, ObjectPath, OrBooleanExpression, + ParentheticalExpression, _BooleanExpression, _ComparisonExpression, ) @@ -22,12 +22,6 @@ def _dupe_ast(ast): """ Create a duplicate of the given AST. - Note: - The comparison expression "leaves", i.e. simple - comparisons are currently not duplicated. I don't think it's necessary - as of this writing; they are never changed. But revisit this if/when - necessary. - Args: ast: The AST to duplicate @@ -45,9 +39,13 @@ def _dupe_ast(ast): ]) elif isinstance(ast, _ComparisonExpression): - # Change this to create a dupe, if we ever need to change simple - # comparison expressions as part of normalization. - result = ast + # Maybe we go as far as duping the ObjectPath object too? + new_object_path = ObjectPath( + ast.lhs.object_type_name, ast.lhs.property_path, + ) + result = _ComparisonExpression( + ast.operator, new_object_path, ast.rhs, ast.negated, + ) else: raise TypeError("Can't duplicate " + type(ast).__name__) @@ -333,17 +331,33 @@ class DNFTransformer(ComparisonExpressionTransformer): other_children.append(child) if or_children: - distributed_children = [ - AndBooleanExpression([ - # Make dupes: distribution implies adding repetition, and - # we should ensure each repetition is independent of the - # others. - _dupe_ast(sub_ast) for sub_ast in itertools.chain( - other_children, prod_seq, - ) - ]) + distributed_and_arg_sets = ( + itertools.chain(other_children, prod_seq) for prod_seq in itertools.product(*or_children) - ] + ) + + # The AST implementation will error if AND boolean comparison + # operands have no common SCO types. We need to handle that here. + # The following will drop AND's with no common SCO types, which is + # harmless (since they're impossible patterns and couldn't match + # anything anyway). It also acts as a nice simplification of the + # pattern. If the original AND node was legal (operands had at + # least one SCO type in common), it is guaranteed that there will + # be at least one legal distributed AND node (distributed_children + # below will not wind up empty). + distributed_children = [] + for and_arg_set in distributed_and_arg_sets: + try: + and_node = AndBooleanExpression( + # Make dupes: distribution implies adding repetition, + # and we should ensure each repetition is independent + # of the others. + _dupe_ast(arg) for arg in and_arg_set + ) + except ValueError: + pass + else: + distributed_children.append(and_node) # Need to recursively continue to distribute AND over OR in # any of our new sub-expressions which need it. This causes diff --git a/stix2/patterns.py b/stix2/patterns.py index a718bf7..c53a83f 100644 --- a/stix2/patterns.py +++ b/stix2/patterns.py @@ -505,7 +505,7 @@ class _BooleanExpression(_PatternExpression): def __init__(self, operator, operands): self.operator = operator self.operands = list(operands) - for arg in operands: + for arg in self.operands: if not hasattr(self, "root_types"): self.root_types = arg.root_types elif operator == "AND": diff --git a/stix2/test/test_pattern_equivalence.py b/stix2/test/test_pattern_equivalence.py index cebb9e7..a33caec 100644 --- a/stix2/test/test_pattern_equivalence.py +++ b/stix2/test/test_pattern_equivalence.py @@ -384,6 +384,15 @@ def test_comp_absorb_equivalent(patt1, patt2): "[a:b=1 AND (a:b=2 AND (a:b=3 OR a:b=4))]", "[(a:b=1 AND a:b=2 AND a:b=3) OR (a:b=1 AND a:b=2 AND a:b=4)]", ), + # Some tests with different SCO types + ( + "[(a:b=1 OR b:c=1) AND (b:d=1 OR c:d=1)]", + "[b:c=1 AND b:d=1]", + ), + ( + "[(a:b=1 OR b:c=1) AND (b:d=1 OR c:d=1)]", + "[(z:y=1 OR b:c=1) AND (b:d=1 OR x:w=1 OR v:u=1)]", + ), ], ) def test_comp_dnf_equivalent(patt1, patt2):