From c8bcece6f6b693b1cfd67113caac8eaede618201 Mon Sep 17 00:00:00 2001 From: Richard Piazza Date: Thu, 6 Jul 2017 10:06:24 -0400 Subject: [PATCH] added tests for expressions fix __str__ methods --- stix2/__init__.py | 4 +++- stix2/pattern_expressions.py | 26 ++++++++++++++------------ stix2/test/test_pattern_expressions.py | 11 +++++++++++ 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/stix2/__init__.py b/stix2/__init__.py index b48742f..1974118 100644 --- a/stix2/__init__.py +++ b/stix2/__init__.py @@ -22,10 +22,12 @@ from .other import (TLP_AMBER, TLP_GREEN, TLP_RED, TLP_WHITE, MarkingDefinition, StatementMarking, TLPMarking) from .pattern_expressions import (AndBooleanExpression, AndObservableExpression, + ComparisonExpression, EqualityComparisonExpression, FollowedByObservableExpression, MatchesComparisonExpression, - ObservableExpression, OrBooleanExpression, + ObservableExpression, + OrBooleanExpression, OrObservableExpression, ParentheticalExpression, QualifiedObservationExpression, diff --git a/stix2/pattern_expressions.py b/stix2/pattern_expressions.py index 37c692c..e952db9 100644 --- a/stix2/pattern_expressions.py +++ b/stix2/pattern_expressions.py @@ -1,5 +1,3 @@ -from six import text_type - class PatternExpression(object): @@ -28,10 +26,10 @@ class ComparisonExpression(PatternExpression): if isinstance(self.rhs, list): final_rhs = [] for r in self.rhs: - final_rhs.append("'" + self.escape_quotes_and_backslashes(text_type(r)) + "'") + final_rhs.append("'" + self.escape_quotes_and_backslashes("%s" % r) + "'") rhs_string = "(" + ", ".join(final_rhs) + ")" else: - rhs_string = "'" + self.escape_quotes_and_backslashes(text_type(self.rhs)) + "'" + rhs_string = "'" + self.escape_quotes_and_backslashes("%s" % self.rhs) + "'" return self.lhs + (" NOT" if self.negated else "") + " " + self.operator + " " + rhs_string @@ -87,12 +85,14 @@ class BooleanExpression(PatternExpression): self.root_type = arg.root_type elif self.root_type and (self.root_type != arg.root_type) and operator == "AND": raise ValueError("This expression cannot have a mixed root type") + elif self.root_type and (self.root_type != arg.root_type): + self.root_type = None self.operands.append(arg) def __str__(self): sub_exprs = [] for o in self.operands: - sub_exprs.append(str(o)) + sub_exprs.append("%s" % o) return (" " + self.operator + " ").join(sub_exprs) @@ -111,7 +111,7 @@ class ObservableExpression(PatternExpression): self.operand = operand def __str__(self): - return "[" + str(self.operand) + "]" + return "[%s]" % self.operand class CompoundObservableExpression(PatternExpression): @@ -122,7 +122,7 @@ class CompoundObservableExpression(PatternExpression): def __str__(self): sub_exprs = [] for o in self.operands: - sub_exprs.append(str(o)) + sub_exprs.append("%s" % o) return (" " + self.operator + " ").join(sub_exprs) @@ -144,9 +144,11 @@ class FollowedByObservableExpression(CompoundObservableExpression): class ParentheticalExpression(PatternExpression): def __init__(self, exp): self.expression = exp + if hasattr(exp, "root_type"): + self.root_type = exp.root_type def __str__(self): - return "(" + str(self.expression) + ")" + return "(%s)" % self.expression class ExpressionQualifier(PatternExpression): @@ -158,7 +160,7 @@ class RepeatQualifier(ExpressionQualifier): self.times_to_repeat = times_to_repeat def __str__(self): - return "REPEATS %s TIMES" % str(self.times_to_repeat) + return "REPEATS %s TIMES" % self.times_to_repeat class WithinQualifier(ExpressionQualifier): @@ -166,7 +168,7 @@ class WithinQualifier(ExpressionQualifier): self.number_of_seconds = number_of_seconds def __str__(self): - return "WITHIN %s SECONDS" % (str(self.number_of_seconds)) + return "WITHIN %s SECONDS" % self.number_of_seconds class StartStopQualifier(ExpressionQualifier): @@ -175,7 +177,7 @@ class StartStopQualifier(ExpressionQualifier): self.stop_time = stop_time def __str__(self): - return "START %s STOP %s" % (str(self.start_time), str(self.stop_time)) + return "START %s STOP %s" % (self.start_time, self.stop_time) class QualifiedObservationExpression(PatternExpression): @@ -184,4 +186,4 @@ class QualifiedObservationExpression(PatternExpression): self.qualifier = qualifier def __str__(self): - return str(self.observation_expression) + " " + str(self.qualifier) + return "%s %s" % (self.observation_expression, self.qualifier) diff --git a/stix2/test/test_pattern_expressions.py b/stix2/test/test_pattern_expressions.py index b86febb..49338df 100644 --- a/stix2/test/test_pattern_expressions.py +++ b/stix2/test/test_pattern_expressions.py @@ -56,3 +56,14 @@ def test_multiple_file_observable_expression(): op2_exp = stix2.ObservableExpression(exp3) exp = stix2.AndObservableExpression([op1_exp, op2_exp]) assert str(exp) == "[file:hashes.'SHA-256' = 'bf07a7fbb825fc0aae7bf4a1177b2b31fcf8a3feeaf7092761e18c859ee52a9c' OR file:hashes.MD5 = 'cead3f77f6cda6ec00f57d76c9a6879f'] AND [file:hashes.'SHA-256' = 'aec070645fe53ee3b3763059376134f058cc337247c978add178b6ccdfb0019f']" # noqa + + +def test_root_types(): + ast = stix2.ObservableExpression( + stix2.AndBooleanExpression( + [stix2.ParentheticalExpression( + stix2.OrBooleanExpression([ + stix2.EqualityComparisonExpression(u"a:b", u"1"), + stix2.EqualityComparisonExpression(u"b:c", u"2")])), + stix2.EqualityComparisonExpression(u"b:d", u"3")])) + assert str(ast) == "[(a:b = '1' OR b:c = '2') AND b:d = '3']"