From 1142d729df95b1fd32edaec5dcce6a0437437053 Mon Sep 17 00:00:00 2001 From: Richard Piazza Date: Sun, 11 Nov 2018 10:09:48 -0500 Subject: [PATCH] Add generic vistor --- stix2/STIXPatternVisitor.py | 119 ++++++++++++++++++++++++++++-------- 1 file changed, 92 insertions(+), 27 deletions(-) diff --git a/stix2/STIXPatternVisitor.py b/stix2/STIXPatternVisitor.py index e4bb9d8..e55795b 100644 --- a/stix2/STIXPatternVisitor.py +++ b/stix2/STIXPatternVisitor.py @@ -1,8 +1,14 @@ -# Generated from STIXPattern.g4 by ANTLR 4.7 import stix2 -import stix2patterns from antlr4 import * -from stix2patterns.pattern_grammar.STIXPatternParser import * +import six +from stix2patterns.grammars.STIXPatternParser import * +from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor +from antlr4 import CommonTokenStream, InputStream +from stix2patterns.grammars.STIXPatternLexer import STIXPatternLexer +from stix2patterns.grammars.STIXPatternParser import STIXPatternParser +from stix2patterns.validator import STIXPatternErrorListener +import importlib +import inspect def collapse_lists(lists): @@ -16,7 +22,27 @@ def collapse_lists(lists): # This class defines a complete generic visitor for a parse tree produced by STIXPatternParser. -class STIXPatternVisitor(ParseTreeVisitor): +class STIXPatternVisitorForSTIX2(STIXPatternVisitor): + classes = {} + + def __init__(self, module_suffix, module_name): + self.module_suffix = module_suffix + if STIXPatternVisitorForSTIX2.classes == {}: + module = importlib.import_module(module_name) + for k, c in inspect.getmembers(module, inspect.isclass): + STIXPatternVisitorForSTIX2.classes[k] = c + super(STIXPatternVisitor, self).__init__() + + def get_class(self, class_name): + return STIXPatternVisitorForSTIX2.classes[class_name] + + def instantiate(self, klass_name, *args): + if self.module_suffix: + klass_to_instantiate = self.get_class(klass_name + "For" + self.module_suffix) + else: + # use the classes in python_stix2 + klass_to_instantiate = globals()[klass_name] + return klass_to_instantiate(*args) # Visit a parse tree produced by STIXPatternParser#pattern. def visitPattern(self, ctx): @@ -32,15 +58,13 @@ class STIXPatternVisitor(ParseTreeVisitor): else: return stix2.FollowedByObservationExpression([children[0], children[2]]) - - # Visit a parse tree produced by STIXPatternParser#observationExpressionOr. def visitObservationExpressionOr(self, ctx): children = self.visitChildren(ctx) if len(children) == 1: return children[0] else: - return stix2._OrObservationExpression([children[0], children[2]]) + return self.instantiate("OrObservationExpression", [children[0], children[2]]) # Visit a parse tree produced by STIXPatternParser#observationExpressionAnd. @@ -49,37 +73,37 @@ class STIXPatternVisitor(ParseTreeVisitor): if len(children) == 1: return children[0] else: - return AndObservationExpressionForSlider([children[0], children[2]]) + return self.instantiate("AndObservationExpression", [children[0], children[2]]) # Visit a parse tree produced by STIXPatternParser#observationExpressionRepeated. def visitObservationExpressionRepeated(self, ctx): children = self.visitChildren(ctx) - return QualifiedObservationExpressionForSlider(children[0], children[1]) + return self.instantiate("QualifiedObservationExpression", children[0], children[1]) # Visit a parse tree produced by STIXPatternParser#observationExpressionSimple. def visitObservationExpressionSimple(self, ctx): children = self.visitChildren(ctx) - return ObservationExpressionForSlider(children[1]) + return self.instantiate("ObservationExpression", children[1]) # Visit a parse tree produced by STIXPatternParser#observationExpressionCompound. def visitObservationExpressionCompound(self, ctx): children = self.visitChildren(ctx) - return ObservationExpressionForSlider(children[1]) + return self.instantiate("ObservationExpression", children[1]) # Visit a parse tree produced by STIXPatternParser#observationExpressionWithin. def visitObservationExpressionWithin(self, ctx): children = self.visitChildren(ctx) - return QualifiedObservationExpressionForSlider(children[0], children[1]) + return self.instantiate("QualifiedObservationExpression", children[0], children[1]) # Visit a parse tree produced by STIXPatternParser#observationExpressionStartStop. def visitObservationExpressionStartStop(self, ctx): children = self.visitChildren(ctx) - return QualifiedObservationExpressionForSlider(children[0], children[1]) + return self.instantiate("QualifiedObservationExpression", children[0], children[1]) # Visit a parse tree produced by STIXPatternParser#comparisonExpression. @@ -88,7 +112,7 @@ class STIXPatternVisitor(ParseTreeVisitor): if len(children) == 1: return children[0] else: - return OrBooleanExpressionForSlider([children[0], children[2]]) + return self.instantiate("OrBooleanExpression", [children[0], children[2]]) # Visit a parse tree produced by STIXPatternParser#comparisonExpressionAnd. @@ -98,7 +122,7 @@ class STIXPatternVisitor(ParseTreeVisitor): if len(children) == 1: return children[0] else: - return AndBooleanExpressionForSlider([children[0], children[2]]) + return self.instantiate("AndBooleanExpression", [children[0], children[2]]) # Visit a parse tree produced by STIXPatternParser#propTestEqual. @@ -107,11 +131,11 @@ class STIXPatternVisitor(ParseTreeVisitor): if len(children) == 4: operator = children[2].symbol.type negated = negated=operator == STIXPatternParser.EQ - return EqualityComparisonExpressionForSlider(children[0], children[3], negated=negated) + return self.instantiate("EqualityComparisonExpression", children[0], children[3], negated) else: operator = children[1].symbol.type negated = negated = operator != STIXPatternParser.EQ - return EqualityComparisonExpressionForSlider(children[0], children[2], negated=negated) + return self.instantiate("EqualityComparisonExpression", children[0], children[2], negated) # Visit a parse tree produced by STIXPatternParser#propTestOrder. @@ -119,13 +143,13 @@ class STIXPatternVisitor(ParseTreeVisitor): children = self.visitChildren(ctx) operator = children[1].symbol.type if operator == STIXPatternParser.GT: - return GreaterThanComparisonExpressionForSlider(children[0], children[2]) + return self.instantiate("GreaterThanComparisonExpression", children[0], children[2], False) elif operator == STIXPatternParser.LT: - return LessThanComparisonExpressionForSlider(children[0], children[2]) + return self.instantiate("LessThanComparisonExpression", children[0], children[2], False) elif operator == STIXPatternParser.GE: - return GreaterThanEqualComparisonExpressionForSlider(children[0], children[2]) + return self.instantiate("GreaterThanEqualComparisonExpression", children[0], children[2], False) elif operator == STIXPatternParser.LE: - return LessThanEqualComparisonExpressionForSlider(children[0], children[2]) + return self.instantiate("LessThanEqualComparisonExpression", children[0], children[2], False) # Visit a parse tree produced by STIXPatternParser#propTestSet. @@ -137,13 +161,13 @@ class STIXPatternVisitor(ParseTreeVisitor): # Visit a parse tree produced by STIXPatternParser#propTestLike. def visitPropTestLike(self, ctx): children = self.visitChildren(ctx) - return LikeComparisonExpressionForSlider(children[0], children[2]) + return self.instantiate("LikeComparisonExpression", children[0], children[2], False) # Visit a parse tree produced by STIXPatternParser#propTestRegex. def visitPropTestRegex(self, ctx): children = self.visitChildren(ctx) - return MatchesComparisonExpressionForSlider(children[0], children[2]) + return self.instantiate("MatchesComparisonExpression", children[0], children[2], False) # Visit a parse tree produced by STIXPatternParser#propTestIsSubset. @@ -161,7 +185,7 @@ class STIXPatternVisitor(ParseTreeVisitor): # Visit a parse tree produced by STIXPatternParser#propTestParen. def visitPropTestParen(self, ctx): children = self.visitChildren(ctx) - return ParentheticalExpressionForSlider(children[1]) + return self.instantiate("ParentheticalExpression", children[1]) # Visit a parse tree produced by STIXPatternParser#startStopQualifier. @@ -200,7 +224,7 @@ class STIXPatternVisitor(ParseTreeVisitor): else: property_path.append(current) i += 1 - return ObjectPathForSlider(children[0].getText(), property_path) + return self.instantiate("ObjectPath", children[0].getText(), property_path) # Visit a parse tree produced by STIXPatternParser#objectType. @@ -259,9 +283,9 @@ class STIXPatternVisitor(ParseTreeVisitor): def visitTerminal(self, node): - if node.symbol.type == STIXPatternParser.IntLiteral: + if node.symbol.type == STIXPatternParser.IntPosLiteral or node.symbol.type == STIXPatternParser.IntNegLiteral: return stix2.IntegerConstant(node.getText()) - elif node.symbol.type == STIXPatternParser.FloatLiteral: + elif node.symbol.type == STIXPatternParser.FloatPosLiteral or node.symbol.type == STIXPatternParser.FloatNegLiteral: return stix2.FloatConstant(node.getText()) elif node.symbol.type == STIXPatternParser.HexLiteral: return stix2.HexConstant(node.getText()) @@ -283,3 +307,44 @@ class STIXPatternVisitor(ParseTreeVisitor): elif nextResult: aggregate = [nextResult] return aggregate + + +def create_pattern_object(pattern, module_suffix="", module_name=""): + ''' + Validates a pattern against the STIX Pattern grammar. Error messages are + returned in a list. The test passed if the returned list is empty. + ''' + + start = '' + if isinstance(pattern, six.string_types): + start = pattern[:2] + pattern = InputStream(pattern) + + if not start: + start = pattern.readline()[:2] + pattern.seek(0) + + parseErrListener = STIXPatternErrorListener() + + lexer = STIXPatternLexer(pattern) + # it always adds a console listener by default... remove it. + lexer.removeErrorListeners() + + stream = CommonTokenStream(lexer) + + parser = STIXPatternParser(stream) + parser.buildParseTrees = True + # it always adds a console listener by default... remove it. + parser.removeErrorListeners() + parser.addErrorListener(parseErrListener) + + # To improve error messages, replace "" in the literal + # names with symbolic names. This is a hack, but seemed like + # the simplest workaround. + for i, lit_name in enumerate(parser.literalNames): + if lit_name == u"": + parser.literalNames[i] = parser.symbolicNames[i] + + tree = parser.pattern() + builder = STIXPatternVisitorForSTIX2(module_suffix, module_name) + return builder.visit(tree)