Fix object type tracking for comparison expressions in the pattern
AST module.pull/1/head
parent
72a032c6e3
commit
7d64764ae3
|
@ -366,7 +366,7 @@ class _ComparisonExpression(_PatternExpression):
|
|||
else:
|
||||
self.rhs = make_constant(rhs)
|
||||
self.negated = negated
|
||||
self.root_type = self.lhs.object_type_name
|
||||
self.root_types = {self.lhs.object_type_name}
|
||||
|
||||
def __str__(self):
|
||||
if self.negated:
|
||||
|
@ -506,15 +506,17 @@ class _BooleanExpression(_PatternExpression):
|
|||
"""
|
||||
def __init__(self, operator, operands):
|
||||
self.operator = operator
|
||||
self.operands = []
|
||||
self.operands = list(operands)
|
||||
for arg in operands:
|
||||
if not hasattr(self, "root_type"):
|
||||
self.root_type = arg.root_type
|
||||
elif self.root_type and (self.root_type != arg.root_type) and operator == "AND":
|
||||
raise ValueError("All operands to an 'AND' expression must have the same object type")
|
||||
elif self.root_type and (self.root_type != arg.root_type):
|
||||
self.root_type = None
|
||||
self.operands.append(arg)
|
||||
if not hasattr(self, "root_types"):
|
||||
self.root_types = arg.root_types
|
||||
elif operator == "AND":
|
||||
self.root_types &= arg.root_types
|
||||
else:
|
||||
self.root_types |= arg.root_types
|
||||
|
||||
if not self.root_types:
|
||||
raise ValueError("All operands to an 'AND' expression must be satisfiable with the same object type")
|
||||
|
||||
def __str__(self):
|
||||
sub_exprs = []
|
||||
|
@ -613,8 +615,8 @@ class ParentheticalExpression(_PatternExpression):
|
|||
"""
|
||||
def __init__(self, exp):
|
||||
self.expression = exp
|
||||
if hasattr(exp, "root_type"):
|
||||
self.root_type = exp.root_type
|
||||
if hasattr(exp, "root_types"):
|
||||
self.root_types = exp.root_types
|
||||
|
||||
def __str__(self):
|
||||
return "(%s)" % self.expression
|
||||
|
|
|
@ -362,7 +362,7 @@ def test_parsing_or_observable_expression():
|
|||
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']" # noqa
|
||||
|
||||
|
||||
def test_invalid_and_observable_expression():
|
||||
def test_invalid_and_comparison_expression():
|
||||
with pytest.raises(ValueError):
|
||||
stix2.AndBooleanExpression([
|
||||
stix2.EqualityComparisonExpression(
|
||||
|
@ -376,6 +376,33 @@ def test_invalid_and_observable_expression():
|
|||
])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pattern, root_types", [
|
||||
("[a:a=1 AND a:b=1]", {"a"}),
|
||||
("[a:a=1 AND a:b=1 OR c:d=1]", {"a", "c"}),
|
||||
("[a:a=1 AND (a:b=1 OR c:d=1)]", {"a"}),
|
||||
("[(a:a=1 OR b:a=1) AND (b:a=1 OR c:c=1)]", {"b"}),
|
||||
("[(a:a=1 AND a:b=1) OR (b:a=1 AND b:c=1)]", {"a", "b"}),
|
||||
],
|
||||
)
|
||||
def test_comparison_expression_root_types(pattern, root_types):
|
||||
ast = create_pattern_object(pattern)
|
||||
assert ast.operand.root_types == root_types
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pattern", [
|
||||
"[a:b=1 AND b:c=1]",
|
||||
"[a:b=1 AND (b:c=1 OR c:d=1)]",
|
||||
"[(a:b=1 OR b:c=1) AND (c:d=1 OR d:e=1)]",
|
||||
"[(a:b=1 AND b:c=1) OR (b:c=1 AND c:d=1)]",
|
||||
],
|
||||
)
|
||||
def test_comparison_expression_root_types_error(pattern):
|
||||
with pytest.raises(ValueError):
|
||||
create_pattern_object(pattern)
|
||||
|
||||
|
||||
def test_hex():
|
||||
exp_and = stix2.AndBooleanExpression([
|
||||
stix2.EqualityComparisonExpression(
|
||||
|
|
Loading…
Reference in New Issue