Fix object type tracking for comparison expressions in the pattern

AST module.
pull/1/head
Michael Chisholm 2020-09-14 15:42:36 -04:00
parent 72a032c6e3
commit 7d64764ae3
2 changed files with 41 additions and 12 deletions

View File

@ -366,7 +366,7 @@ class _ComparisonExpression(_PatternExpression):
else:
self.rhs = make_constant(rhs)
self.negated = negated
self.root_type = self.lhs.object_type_name
self.root_types = {self.lhs.object_type_name}
def __str__(self):
if self.negated:
@ -506,15 +506,17 @@ class _BooleanExpression(_PatternExpression):
"""
def __init__(self, operator, operands):
self.operator = operator
self.operands = []
self.operands = list(operands)
for arg in operands:
if not hasattr(self, "root_type"):
self.root_type = arg.root_type
elif self.root_type and (self.root_type != arg.root_type) and operator == "AND":
raise ValueError("All operands to an 'AND' expression must have the same object type")
elif self.root_type and (self.root_type != arg.root_type):
self.root_type = None
self.operands.append(arg)
if not hasattr(self, "root_types"):
self.root_types = arg.root_types
elif operator == "AND":
self.root_types &= arg.root_types
else:
self.root_types |= arg.root_types
if not self.root_types:
raise ValueError("All operands to an 'AND' expression must be satisfiable with the same object type")
def __str__(self):
sub_exprs = []
@ -613,8 +615,8 @@ class ParentheticalExpression(_PatternExpression):
"""
def __init__(self, exp):
self.expression = exp
if hasattr(exp, "root_type"):
self.root_type = exp.root_type
if hasattr(exp, "root_types"):
self.root_types = exp.root_types
def __str__(self):
return "(%s)" % self.expression

View File

@ -362,7 +362,7 @@ def test_parsing_or_observable_expression():
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']" # noqa
def test_invalid_and_observable_expression():
def test_invalid_and_comparison_expression():
with pytest.raises(ValueError):
stix2.AndBooleanExpression([
stix2.EqualityComparisonExpression(
@ -376,6 +376,33 @@ def test_invalid_and_observable_expression():
])
@pytest.mark.parametrize(
"pattern, root_types", [
("[a:a=1 AND a:b=1]", {"a"}),
("[a:a=1 AND a:b=1 OR c:d=1]", {"a", "c"}),
("[a:a=1 AND (a:b=1 OR c:d=1)]", {"a"}),
("[(a:a=1 OR b:a=1) AND (b:a=1 OR c:c=1)]", {"b"}),
("[(a:a=1 AND a:b=1) OR (b:a=1 AND b:c=1)]", {"a", "b"}),
],
)
def test_comparison_expression_root_types(pattern, root_types):
ast = create_pattern_object(pattern)
assert ast.operand.root_types == root_types
@pytest.mark.parametrize(
"pattern", [
"[a:b=1 AND b:c=1]",
"[a:b=1 AND (b:c=1 OR c:d=1)]",
"[(a:b=1 OR b:c=1) AND (c:d=1 OR d:e=1)]",
"[(a:b=1 AND b:c=1) OR (b:c=1 AND c:d=1)]",
],
)
def test_comparison_expression_root_types_error(pattern):
with pytest.raises(ValueError):
create_pattern_object(pattern)
def test_hex():
exp_and = stix2.AndBooleanExpression([
stix2.EqualityComparisonExpression(