diff --git a/stix2/STIXPatternVisitor.py b/stix2/STIXPatternVisitor.py index 8690ac4..a000ef0 100644 --- a/stix2/STIXPatternVisitor.py +++ b/stix2/STIXPatternVisitor.py @@ -11,14 +11,8 @@ from stix2patterns.validator import STIXPatternErrorListener from antlr4 import CommonTokenStream, InputStream -from .patterns import ( - BasicObjectPathComponent, BinaryConstant, BooleanConstant, FloatConstant, - FollowedByObservationExpression, HexConstant, IntegerConstant, - IsSubsetComparisonExpression, IsSupersetComparisonExpression, - ListObjectPathComponent, RepeatQualifier, StartStopQualifier, - StringConstant, TimestampConstant, WithinQualifier, -) - +# need to import all classes because we need to access them via globals() +from .patterns import * def collapse_lists(lists): result = [] @@ -30,11 +24,7 @@ def collapse_lists(lists): return result -def quote_if_needed(x): - if x.find("-") != -1: - return "'" + x + "'" - else: - return x + # This class defines a complete generic visitor for a parse tree produced by STIXPatternParser. @@ -54,12 +44,16 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor): super(STIXPatternVisitor, self).__init__() def get_class(self, class_name): - return STIXPatternVisitorForSTIX2.classes[class_name] + if class_name in STIXPatternVisitorForSTIX2.classes: + return STIXPatternVisitorForSTIX2.classes[class_name] + else: + return None def instantiate(self, klass_name, *args): + klass_to_instantiate = None if self.module_suffix: klass_to_instantiate = self.get_class(klass_name + "For" + self.module_suffix) - else: + if not klass_to_instantiate: # use the classes in python_stix2 klass_to_instantiate = globals()[klass_name] return klass_to_instantiate(*args) @@ -214,11 +208,11 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor): while i < len(flat_list): current = flat_list[i] if i == len(flat_list)-1: - property_path.append(quote_if_needed(current)) + property_path.append(current) break next = flat_list[i+1] if isinstance(next, TerminalNode): - property_path.append(ListObjectPathComponent(current.property_name, next.getText())) + property_path.append(self.instantiate("ListObjectPathComponent", current.property_name, next.getText())) i += 2 else: property_path.append(current) @@ -237,7 +231,7 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor): # if step.endswith("_ref"): # return stix2.ReferenceObjectPathComponent(step) # else: - return BasicObjectPathComponent(step) + return self.instantiate("BasicObjectPathComponent", step, False) # Visit a parse tree produced by STIXPatternParser#indexPathStep. def visitIndexPathStep(self, ctx): @@ -255,7 +249,7 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor): # special case for hashes return children[1].value else: - return BasicObjectPathComponent(children[1].getText(), is_key=True) + return self.instantiate("BasicObjectPathComponent", children[1].getText(), True) # Visit a parse tree produced by STIXPatternParser#setLiteral. def visitSetLiteral(self, ctx): diff --git a/stix2/patterns.py b/stix2/patterns.py index 8bcaea3..b037fdd 100644 --- a/stix2/patterns.py +++ b/stix2/patterns.py @@ -13,6 +13,14 @@ def escape_quotes_and_backslashes(s): return s.replace(u'\\', u'\\\\').replace(u"'", u"\\'") +def quote_if_needed(x): + if isinstance(x, str): + if x.find("-") != -1: + if not x.startswith("'"): + return "'" + x + "'" + return x + + class _Constant(object): pass @@ -229,7 +237,10 @@ class _ObjectPathComponent(object): parse1 = component_name.split("[") return ListObjectPathComponent(parse1[0], parse1[1][:-1]) else: - return BasicObjectPathComponent(component_name) + return BasicObjectPathComponent(component_name, False) + + def __str__(self): + return quote_if_needed(self.property_name) class BasicObjectPathComponent(_ObjectPathComponent): @@ -243,14 +254,11 @@ class BasicObjectPathComponent(_ObjectPathComponent): property_name (str): object property name is_key (bool): is dictionary key, default: False """ - def __init__(self, property_name, is_key=False): + def __init__(self, property_name, is_key): self.property_name = property_name # TODO: set is_key to True if this component is a dictionary key # self.is_key = is_key - def __str__(self): - return self.property_name - class ListObjectPathComponent(_ObjectPathComponent): """List object path component (for an observation or expression) @@ -264,7 +272,7 @@ class ListObjectPathComponent(_ObjectPathComponent): self.index = index def __str__(self): - return "%s[%s]" % (self.property_name, self.index) + return "%s[%s]" % (quote_if_needed(self.property_name), self.index) class ReferenceObjectPathComponent(_ObjectPathComponent): @@ -276,9 +284,6 @@ class ReferenceObjectPathComponent(_ObjectPathComponent): def __init__(self, reference_property_name): self.property_name = reference_property_name - def __str__(self): - return self.property_name - class ObjectPath(object): """Pattern operand object (property) path @@ -296,7 +301,7 @@ class ObjectPath(object): ] def __str__(self): - return "%s:%s" % (self.object_type_name, ".".join(["%s" % x for x in self.property_path])) + return "%s:%s" % (self.object_type_name, ".".join(["%s" % quote_if_needed(x) for x in self.property_path])) def merge(self, other): """Extend the object property with that of the supplied object property path""" diff --git a/stix2/test/test_pattern_expressions.py b/stix2/test/test_pattern_expressions.py index b08e72f..a6c1550 100644 --- a/stix2/test/test_pattern_expressions.py +++ b/stix2/test/test_pattern_expressions.py @@ -34,7 +34,7 @@ def test_boolean_expression_with_parentheses(): "email-message", [ stix2.ReferenceObjectPathComponent("from_ref"), - stix2.BasicObjectPathComponent("value"), + stix2.BasicObjectPathComponent("value", False), ], ), stix2.StringConstant(".+\\@example\\.com$"), @@ -159,18 +159,18 @@ def test_artifact_payload(): def test_greater_than_python_constant(): - exp1 = stix2.GreaterThanComparisonExpression("file:extensions.windows-pebinary-ext.sections[*].entropy", 7.0) + exp1 = stix2.GreaterThanComparisonExpression("file:extensions.'windows-pebinary-ext'.sections[*].entropy", 7.0) exp = stix2.ObservationExpression(exp1) - assert str(exp) == "[file:extensions.windows-pebinary-ext.sections[*].entropy > 7.0]" + assert str(exp) == "[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.0]" def test_greater_than(): exp1 = stix2.GreaterThanComparisonExpression( - "file:extensions.windows-pebinary-ext.sections[*].entropy", + "file:extensions.'windows-pebinary-ext'.sections[*].entropy", stix2.FloatConstant(7.0), ) exp = stix2.ObservationExpression(exp1) - assert str(exp) == "[file:extensions.windows-pebinary-ext.sections[*].entropy > 7.0]" + assert str(exp) == "[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.0]" def test_less_than():