Merge pull request #456 from chisholm/fix_comparison_expression_root_type
Fix object type tracking for AST comparison expression 'AND'pull/1/head
commit
ddb25c544a
|
@ -366,7 +366,7 @@ class _ComparisonExpression(_PatternExpression):
|
||||||
else:
|
else:
|
||||||
self.rhs = make_constant(rhs)
|
self.rhs = make_constant(rhs)
|
||||||
self.negated = negated
|
self.negated = negated
|
||||||
self.root_type = self.lhs.object_type_name
|
self.root_types = {self.lhs.object_type_name}
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self.negated:
|
if self.negated:
|
||||||
|
@ -506,15 +506,17 @@ class _BooleanExpression(_PatternExpression):
|
||||||
"""
|
"""
|
||||||
def __init__(self, operator, operands):
|
def __init__(self, operator, operands):
|
||||||
self.operator = operator
|
self.operator = operator
|
||||||
self.operands = []
|
self.operands = list(operands)
|
||||||
for arg in operands:
|
for arg in operands:
|
||||||
if not hasattr(self, "root_type"):
|
if not hasattr(self, "root_types"):
|
||||||
self.root_type = arg.root_type
|
self.root_types = arg.root_types
|
||||||
elif self.root_type and (self.root_type != arg.root_type) and operator == "AND":
|
elif operator == "AND":
|
||||||
raise ValueError("All operands to an 'AND' expression must have the same object type")
|
self.root_types &= arg.root_types
|
||||||
elif self.root_type and (self.root_type != arg.root_type):
|
else:
|
||||||
self.root_type = None
|
self.root_types |= arg.root_types
|
||||||
self.operands.append(arg)
|
|
||||||
|
if not self.root_types:
|
||||||
|
raise ValueError("All operands to an 'AND' expression must be satisfiable with the same object type")
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
sub_exprs = []
|
sub_exprs = []
|
||||||
|
@ -613,8 +615,8 @@ class ParentheticalExpression(_PatternExpression):
|
||||||
"""
|
"""
|
||||||
def __init__(self, exp):
|
def __init__(self, exp):
|
||||||
self.expression = exp
|
self.expression = exp
|
||||||
if hasattr(exp, "root_type"):
|
if hasattr(exp, "root_types"):
|
||||||
self.root_type = exp.root_type
|
self.root_types = exp.root_types
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "(%s)" % self.expression
|
return "(%s)" % self.expression
|
||||||
|
|
|
@ -364,7 +364,7 @@ def test_parsing_or_observable_expression():
|
||||||
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']" # noqa
|
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']" # noqa
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_and_observable_expression():
|
def test_invalid_and_comparison_expression():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
stix2.AndBooleanExpression([
|
stix2.AndBooleanExpression([
|
||||||
stix2.EqualityComparisonExpression(
|
stix2.EqualityComparisonExpression(
|
||||||
|
@ -378,6 +378,33 @@ def test_invalid_and_observable_expression():
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"pattern, root_types", [
|
||||||
|
("[a:a=1 AND a:b=1]", {"a"}),
|
||||||
|
("[a:a=1 AND a:b=1 OR c:d=1]", {"a", "c"}),
|
||||||
|
("[a:a=1 AND (a:b=1 OR c:d=1)]", {"a"}),
|
||||||
|
("[(a:a=1 OR b:a=1) AND (b:a=1 OR c:c=1)]", {"b"}),
|
||||||
|
("[(a:a=1 AND a:b=1) OR (b:a=1 AND b:c=1)]", {"a", "b"}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_comparison_expression_root_types(pattern, root_types):
|
||||||
|
ast = create_pattern_object(pattern)
|
||||||
|
assert ast.operand.root_types == root_types
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"pattern", [
|
||||||
|
"[a:b=1 AND b:c=1]",
|
||||||
|
"[a:b=1 AND (b:c=1 OR c:d=1)]",
|
||||||
|
"[(a:b=1 OR b:c=1) AND (c:d=1 OR d:e=1)]",
|
||||||
|
"[(a:b=1 AND b:c=1) OR (b:c=1 AND c:d=1)]",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_comparison_expression_root_types_error(pattern):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
create_pattern_object(pattern)
|
||||||
|
|
||||||
|
|
||||||
def test_hex():
|
def test_hex():
|
||||||
exp_and = stix2.AndBooleanExpression([
|
exp_and = stix2.AndBooleanExpression([
|
||||||
stix2.EqualityComparisonExpression(
|
stix2.EqualityComparisonExpression(
|
||||||
|
|
Loading…
Reference in New Issue