diff --git a/stix2/patterns.py b/stix2/patterns.py index f1472cd..ce07637 100644 --- a/stix2/patterns.py +++ b/stix2/patterns.py @@ -366,7 +366,7 @@ class _ComparisonExpression(_PatternExpression): else: self.rhs = make_constant(rhs) self.negated = negated - self.root_type = self.lhs.object_type_name + self.root_types = {self.lhs.object_type_name} def __str__(self): if self.negated: @@ -506,15 +506,17 @@ class _BooleanExpression(_PatternExpression): """ def __init__(self, operator, operands): self.operator = operator - self.operands = [] + self.operands = list(operands) for arg in operands: - if not hasattr(self, "root_type"): - self.root_type = arg.root_type - elif self.root_type and (self.root_type != arg.root_type) and operator == "AND": - raise ValueError("All operands to an 'AND' expression must have the same object type") - elif self.root_type and (self.root_type != arg.root_type): - self.root_type = None - self.operands.append(arg) + if not hasattr(self, "root_types"): + self.root_types = arg.root_types + elif operator == "AND": + self.root_types &= arg.root_types + else: + self.root_types |= arg.root_types + + if not self.root_types: + raise ValueError("All operands to an 'AND' expression must be satisfiable with the same object type") def __str__(self): sub_exprs = [] @@ -613,8 +615,8 @@ class ParentheticalExpression(_PatternExpression): """ def __init__(self, exp): self.expression = exp - if hasattr(exp, "root_type"): - self.root_type = exp.root_type + if hasattr(exp, "root_types"): + self.root_types = exp.root_types def __str__(self): return "(%s)" % self.expression diff --git a/stix2/test/v21/test_pattern_expressions.py b/stix2/test/v21/test_pattern_expressions.py index 58cef3e..4f365d7 100644 --- a/stix2/test/v21/test_pattern_expressions.py +++ b/stix2/test/v21/test_pattern_expressions.py @@ -364,7 +364,7 @@ def test_parsing_or_observable_expression(): assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']" # noqa -def test_invalid_and_observable_expression(): +def test_invalid_and_comparison_expression(): with pytest.raises(ValueError): stix2.AndBooleanExpression([ stix2.EqualityComparisonExpression( @@ -378,6 +378,33 @@ def test_invalid_and_observable_expression(): ]) +@pytest.mark.parametrize( + "pattern, root_types", [ + ("[a:a=1 AND a:b=1]", {"a"}), + ("[a:a=1 AND a:b=1 OR c:d=1]", {"a", "c"}), + ("[a:a=1 AND (a:b=1 OR c:d=1)]", {"a"}), + ("[(a:a=1 OR b:a=1) AND (b:a=1 OR c:c=1)]", {"b"}), + ("[(a:a=1 AND a:b=1) OR (b:a=1 AND b:c=1)]", {"a", "b"}), + ], +) +def test_comparison_expression_root_types(pattern, root_types): + ast = create_pattern_object(pattern) + assert ast.operand.root_types == root_types + + +@pytest.mark.parametrize( + "pattern", [ + "[a:b=1 AND b:c=1]", + "[a:b=1 AND (b:c=1 OR c:d=1)]", + "[(a:b=1 OR b:c=1) AND (c:d=1 OR d:e=1)]", + "[(a:b=1 AND b:c=1) OR (b:c=1 AND c:d=1)]", + ], +) +def test_comparison_expression_root_types_error(pattern): + with pytest.raises(ValueError): + create_pattern_object(pattern) + + def test_hex(): exp_and = stix2.AndBooleanExpression([ stix2.EqualityComparisonExpression(