cti-python-stix2/stix2/pattern_visitor.py

359 lines
15 KiB
Python

import importlib
import inspect
from stix2patterns.exceptions import ParseException
from stix2patterns.grammars.STIXPatternParser import TerminalNode
from stix2patterns.v20.grammars.STIXPatternParser import STIXPatternParser as STIXPatternParser20
from stix2patterns.v21.grammars.STIXPatternParser import STIXPatternParser as STIXPatternParser21
from stix2patterns.v20.grammars.STIXPatternVisitor import STIXPatternVisitor as STIXPatternVisitor20
from stix2patterns.v21.grammars.STIXPatternVisitor import STIXPatternVisitor as STIXPatternVisitor21
from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor
from stix2patterns.v20.pattern import Pattern as Pattern20
from stix2patterns.v21.pattern import Pattern as Pattern21
from .patterns import *
from .patterns import _BooleanExpression
# flake8: noqa F405
def collapse_lists(lists):
    """Flatten ``lists`` by one level.

    Elements that are ``list`` instances contribute their items; every other
    element (including tuples and strings) is kept as a single entry.
    Nested lists deeper than one level are left untouched.
    """
    flattened = []
    for entry in lists:
        # Wrap scalars so both cases reduce to a single list concatenation.
        flattened += entry if isinstance(entry, list) else [entry]
    return flattened
def remove_terminal_nodes(parse_tree_nodes):
    """Return the entries of ``parse_tree_nodes`` that are not ``TerminalNode`` instances.

    Order is preserved and a new list is always returned.
    """
    return [node for node in parse_tree_nodes if not isinstance(node, TerminalNode)]
class STIXPatternVisitorForSTIX2():
    """Mixin that builds pattern AST objects while walking a STIX pattern parse tree.

    The class does not define ``visitChildren`` itself; it is meant to be
    combined with a parse-tree visitor that does.  Before any ``visit*`` method
    runs, the concrete subclass must set:

    * ``self.module_suffix`` - suffix used to look up alternative AST classes
      (``None`` or empty to use the classes visible in this module), and
    * ``self.parser_class`` - the parser class whose token-type constants
      (``EQ``, ``GT``, ``StringLiteral`` ...) are compared against.
    """

    # Registry of alternative AST classes keyed by class name.  It is looked
    # up through the base class, so it is shared by every visitor instance.
    classes = {}

    def get_class(self, class_name):
        """Return the registered class named ``class_name``, or ``None``."""
        return STIXPatternVisitorForSTIX2.classes.get(class_name)

    def instantiate(self, klass_name, *args):
        """Create an AST node of type ``klass_name`` from ``args``.

        When a module suffix is configured, a registered class called
        ``<klass_name>For<suffix>`` is preferred.  If there is none, the class
        of that name in this module's globals is used (``KeyError`` if it does
        not exist).
        """
        klass_to_instantiate = None
        if self.module_suffix:
            klass_to_instantiate = self.get_class(klass_name + "For" + self.module_suffix)
        if not klass_to_instantiate:
            # use the classes in python_stix2
            klass_to_instantiate = globals()[klass_name]
        return klass_to_instantiate(*args)

    @staticmethod
    def _split_prop_test(children):
        """Split the visited children of a property test.

        Returns ``(lhs, operator_node, rhs, has_not)``.  A property test has
        three children (path, operator, value); the grammar allows an optional
        NOT keyword between the path and the operator, which makes it four.
        """
        has_not = len(children) > 3
        operator_index = 2 if has_not else 1
        return children[0], children[operator_index], children[operator_index + 1], has_not

    # Visit a parse tree produced by STIXPatternParser#pattern.
    def visitPattern(self, ctx):
        """Return the single top-level expression of the pattern."""
        children = self.visitChildren(ctx)
        return children[0]

    # Visit a parse tree produced by STIXPatternParser#observationExpressions.
    def visitObservationExpressions(self, ctx):
        """Pass a lone operand through, or build a FOLLOWEDBY expression."""
        children = self.visitChildren(ctx)
        if len(children) == 1:
            return children[0]
        # children[1] is the keyword token between the two operands.
        # Goes through instantiate() like the OR/AND variants so that a
        # suffixed replacement class is honoured here as well.
        return self.instantiate("FollowedByObservationExpression", [children[0], children[2]])

    # Visit a parse tree produced by STIXPatternParser#observationExpressionOr.
    def visitObservationExpressionOr(self, ctx):
        """Pass a lone operand through, or build an OR observation expression."""
        children = self.visitChildren(ctx)
        if len(children) == 1:
            return children[0]
        return self.instantiate("OrObservationExpression", [children[0], children[2]])

    # Visit a parse tree produced by STIXPatternParser#observationExpressionAnd.
    def visitObservationExpressionAnd(self, ctx):
        """Pass a lone operand through, or build an AND observation expression."""
        children = self.visitChildren(ctx)
        if len(children) == 1:
            return children[0]
        return self.instantiate("AndObservationExpression", [children[0], children[2]])

    # Visit a parse tree produced by STIXPatternParser#observationExpressionRepeated.
    def visitObservationExpressionRepeated(self, ctx):
        """Attach a qualifier (second child) to an observation expression (first child)."""
        children = self.visitChildren(ctx)
        return self.instantiate("QualifiedObservationExpression", children[0], children[1])

    # Visit a parse tree produced by STIXPatternParser#observationExpressionSimple.
    def visitObservationExpressionSimple(self, ctx):
        """Wrap the middle child (between the delimiter tokens) in an ObservationExpression."""
        children = self.visitChildren(ctx)
        return self.instantiate("ObservationExpression", children[1])

    # Visit a parse tree produced by STIXPatternParser#observationExpressionCompound.
    def visitObservationExpressionCompound(self, ctx):
        """Wrap the middle child (between the delimiter tokens) in an ObservationExpression."""
        children = self.visitChildren(ctx)
        return self.instantiate("ObservationExpression", children[1])

    # Visit a parse tree produced by STIXPatternParser#observationExpressionWithin.
    def visitObservationExpressionWithin(self, ctx):
        """Attach a qualifier (second child) to an observation expression (first child)."""
        children = self.visitChildren(ctx)
        return self.instantiate("QualifiedObservationExpression", children[0], children[1])

    # Visit a parse tree produced by STIXPatternParser#observationExpressionStartStop.
    def visitObservationExpressionStartStop(self, ctx):
        """Attach a qualifier (second child) to an observation expression (first child)."""
        children = self.visitChildren(ctx)
        return self.instantiate("QualifiedObservationExpression", children[0], children[1])

    # Visit a parse tree produced by STIXPatternParser#comparisonExpression.
    def visitComparisonExpression(self, ctx):
        """Pass a lone operand through, or build/extend an OR boolean expression."""
        children = self.visitChildren(ctx)
        if len(children) == 1:
            return children[0]
        left, right = children[0], children[2]
        # Only fold the new operand into an existing OR expression.  The left
        # side may be an AND expression ("a AND b OR c"); appending to it
        # would silently turn the pattern into "a AND b AND c".
        # NOTE(review): relies on boolean expressions exposing their operator
        # text as `.operator`; when the attribute is absent a nested OR is
        # built instead, which is still semantically correct.
        if isinstance(left, _BooleanExpression) and getattr(left, "operator", None) == "OR":
            left.operands.append(right)
            return left
        return self.instantiate("OrBooleanExpression", [left, right])

    # Visit a parse tree produced by STIXPatternParser#comparisonExpressionAnd.
    def visitComparisonExpressionAnd(self, ctx):
        """Pass a lone operand through, or build/extend an AND boolean expression."""
        children = self.visitChildren(ctx)
        if len(children) == 1:
            return children[0]
        left, right = children[0], children[2]
        if isinstance(left, _BooleanExpression):
            # Flatten chained ANDs into a single expression.
            left.operands.append(right)
            return left
        return self.instantiate("AndBooleanExpression", [left, right])

    # Visit a parse tree produced by STIXPatternParser#propTestEqual.
    def visitPropTestEqual(self, ctx):
        """Build an equality comparison for ``=``, ``!=`` and their NOT forms."""
        children = self.visitChildren(ctx)
        lhs, operator_node, rhs, has_not = self._split_prop_test(children)
        is_not_equal = operator_node.symbol.type != self.parser_class.EQ
        # "!=" and "NOT =" are both negated equality; "NOT !=" is a double
        # negation, so the two flags cancel each other out.
        negated = has_not != is_not_equal
        return self.instantiate("EqualityComparisonExpression", lhs, rhs, negated)

    # Visit a parse tree produced by STIXPatternParser#propTestOrder.
    def visitPropTestOrder(self, ctx):
        """Build a ``>``, ``<``, ``>=`` or ``<=`` comparison, honouring NOT.

        Returns ``None`` when the operator token is none of the four.
        """
        children = self.visitChildren(ctx)
        lhs, operator_node, rhs, has_not = self._split_prop_test(children)
        parser = self.parser_class
        class_names = {
            parser.GT: "GreaterThanComparisonExpression",
            parser.LT: "LessThanComparisonExpression",
            parser.GE: "GreaterThanEqualComparisonExpression",
            parser.LE: "LessThanEqualComparisonExpression",
        }
        klass_name = class_names.get(operator_node.symbol.type)
        if klass_name is None:
            return None
        return self.instantiate(klass_name, lhs, rhs, has_not)

    # Visit a parse tree produced by STIXPatternParser#propTestSet.
    def visitPropTestSet(self, ctx):
        """Build an IN comparison, negated when NOT is present."""
        children = self.visitChildren(ctx)
        lhs, _, rhs, has_not = self._split_prop_test(children)
        return self.instantiate("InComparisonExpression", lhs, rhs, has_not)

    # Visit a parse tree produced by STIXPatternParser#propTestLike.
    def visitPropTestLike(self, ctx):
        """Build a LIKE comparison, negated when NOT is present."""
        children = self.visitChildren(ctx)
        lhs, _, rhs, has_not = self._split_prop_test(children)
        return self.instantiate("LikeComparisonExpression", lhs, rhs, has_not)

    # Visit a parse tree produced by STIXPatternParser#propTestRegex.
    def visitPropTestRegex(self, ctx):
        """Build a MATCHES comparison, negated when NOT is present."""
        children = self.visitChildren(ctx)
        lhs, _, rhs, has_not = self._split_prop_test(children)
        return self.instantiate("MatchesComparisonExpression", lhs, rhs, has_not)

    # Visit a parse tree produced by STIXPatternParser#propTestIsSubset.
    def visitPropTestIsSubset(self, ctx):
        """Build an ISSUBSET comparison."""
        children = self.visitChildren(ctx)
        lhs, _, rhs, has_not = self._split_prop_test(children)
        if has_not:
            # NOTE(review): assumes the expression class accepts a third
            # "negated" argument like the other comparison classes - confirm.
            return self.instantiate("IsSubsetComparisonExpression", lhs, rhs, True)
        return self.instantiate("IsSubsetComparisonExpression", lhs, rhs)

    # Visit a parse tree produced by STIXPatternParser#propTestIsSuperset.
    def visitPropTestIsSuperset(self, ctx):
        """Build an ISSUPERSET comparison."""
        children = self.visitChildren(ctx)
        lhs, _, rhs, has_not = self._split_prop_test(children)
        if has_not:
            # NOTE(review): assumes the expression class accepts a third
            # "negated" argument like the other comparison classes - confirm.
            return self.instantiate("IsSupersetComparisonExpression", lhs, rhs, True)
        return self.instantiate("IsSupersetComparisonExpression", lhs, rhs)

    # Visit a parse tree produced by STIXPatternParser#propTestParen.
    def visitPropTestParen(self, ctx):
        """Wrap the expression between the parentheses in a ParentheticalExpression."""
        children = self.visitChildren(ctx)
        return self.instantiate("ParentheticalExpression", children[1])

    # Visit a parse tree produced by STIXPatternParser#startStopQualifier.
    def visitStartStopQualifier(self, ctx):
        """Build a StartStopQualifier from the second and fourth children."""
        children = self.visitChildren(ctx)
        return StartStopQualifier(children[1], children[3])

    # Visit a parse tree produced by STIXPatternParser#withinQualifier.
    def visitWithinQualifier(self, ctx):
        """Build a WithinQualifier from the second child."""
        children = self.visitChildren(ctx)
        return WithinQualifier(children[1])

    # Visit a parse tree produced by STIXPatternParser#repeatedQualifier.
    def visitRepeatedQualifier(self, ctx):
        """Build a RepeatQualifier from the second child."""
        children = self.visitChildren(ctx)
        return RepeatQualifier(children[1])

    # Visit a parse tree produced by STIXPatternParser#objectPath.
    def visitObjectPath(self, ctx):
        """Build an ObjectPath from the object type and its property path.

        ``children[0]`` is the object type and ``children[1]`` the separator
        token; everything after that is the path.  A path step that is
        followed by an index (a raw terminal node, or an ``IntegerConstant``
        produced by ``visitTerminal`` for a numeric index) is merged with that
        index into one ``ListObjectPathComponent``.
        """
        children = self.visitChildren(ctx)
        flat_list = collapse_lists(children[2:])
        property_path = []
        i = 0
        while i < len(flat_list):
            current = flat_list[i]
            if i == len(flat_list) - 1:
                # Last element: nothing can follow it, keep it as is.
                property_path.append(current)
                break
            following = flat_list[i + 1]
            if isinstance(following, TerminalNode):
                index = following.getText()
            elif isinstance(following, IntegerConstant):
                # NOTE(review): assumes IntegerConstant exposes the parsed
                # number as `.value`, as StringConstant does - confirm.
                index = following.value
            else:
                property_path.append(current)
                i += 1
                continue
            # Key steps written as string literals arrive as plain strings
            # (see visitKeyPathStep) and have no property_name attribute.
            name = getattr(current, "property_name", current)
            property_path.append(self.instantiate("ListObjectPathComponent", name, index))
            i += 2
        return self.instantiate("ObjectPath", children[0].getText(), property_path)

    # Visit a parse tree produced by STIXPatternParser#objectType.
    def visitObjectType(self, ctx):
        """Return the object-type node unchanged."""
        children = self.visitChildren(ctx)
        return children[0]

    # Visit a parse tree produced by STIXPatternParser#firstPathComponent.
    def visitFirstPathComponent(self, ctx):
        """Build a BasicObjectPathComponent (second argument ``False``) from the step text."""
        children = self.visitChildren(ctx)
        step = children[0].getText()
        # if step.endswith("_ref"):
        #     return stix2.ReferenceObjectPathComponent(step)
        # else:
        return self.instantiate("BasicObjectPathComponent", step, False)

    # Visit a parse tree produced by STIXPatternParser#indexPathStep.
    def visitIndexPathStep(self, ctx):
        """Return the index found between the bracket tokens."""
        children = self.visitChildren(ctx)
        return children[1]

    # Visit a parse tree produced by STIXPatternParser#pathStep.
    def visitPathStep(self, ctx):
        """Return the visited sub-steps flattened into one list."""
        return collapse_lists(self.visitChildren(ctx))

    # Visit a parse tree produced by STIXPatternParser#keyPathStep.
    def visitKeyPathStep(self, ctx):
        """Return the key following the separator token.

        A key written as a string literal is returned as its plain string
        value; any other key becomes a BasicObjectPathComponent (second
        argument ``True``).
        """
        children = self.visitChildren(ctx)
        key = children[1]
        if isinstance(key, StringConstant):
            # special case for hashes
            return key.value
        return self.instantiate("BasicObjectPathComponent", key.getText(), True)

    # Visit a parse tree produced by STIXPatternParser#setLiteral.
    def visitSetLiteral(self, ctx):
        """Build a ListConstant from the set members, dropping punctuation tokens."""
        children = self.visitChildren(ctx)
        return self.instantiate("ListConstant", remove_terminal_nodes(children))

    # Visit a parse tree produced by STIXPatternParser#primitiveLiteral.
    def visitPrimitiveLiteral(self, ctx):
        """Return the already-converted literal."""
        children = self.visitChildren(ctx)
        return children[0]

    # Visit a parse tree produced by STIXPatternParser#orderableLiteral.
    def visitOrderableLiteral(self, ctx):
        """Return the already-converted literal."""
        children = self.visitChildren(ctx)
        return children[0]

    def visitTerminal(self, node):
        """Convert literal tokens to constant objects; return other tokens unchanged.

        Raises:
            ParseException: if a string literal is not wrapped in single quotes.
        """
        parser = self.parser_class
        token_type = node.symbol.type
        if token_type == parser.IntPosLiteral or token_type == parser.IntNegLiteral:
            return IntegerConstant(node.getText())
        elif token_type == parser.FloatPosLiteral or token_type == parser.FloatNegLiteral:
            return FloatConstant(node.getText())
        elif token_type == parser.HexLiteral:
            return HexConstant(node.getText(), from_parse_tree=True)
        elif token_type == parser.BinaryLiteral:
            return BinaryConstant(node.getText(), from_parse_tree=True)
        elif token_type == parser.StringLiteral:
            text = node.getText()
            if text[0] == "'" and text[-1] == "'":
                # Strip the surrounding quotes before building the constant.
                return StringConstant(text[1:-1], from_parse_tree=True)
            else:
                raise ParseException("The pattern does not start and end with a single quote")
        elif token_type == parser.BoolLiteral:
            return BooleanConstant(node.getText())
        elif token_type == parser.TimestampLiteral:
            return TimestampConstant(node.getText())
        else:
            return node

    def aggregateResult(self, aggregate, nextResult):
        """Collect child results into a list.

        The list is started by the first truthy result; once it is non-empty
        every later result is appended, falsy or not.
        """
        if aggregate:
            aggregate.append(nextResult)
        elif nextResult:
            aggregate = [nextResult]
        return aggregate
# This class defines a complete generic visitor for a parse tree produced by STIXPatternParser.
class STIXPatternVisitorForSTIX21(STIXPatternVisitorForSTIX2, STIXPatternVisitor21):
    """AST-building visitor bound to the STIX 2.1 pattern parser."""

    classes = {}

    def __init__(self, module_suffix, module_name):
        """Configure the visitor.

        Args:
            module_suffix: suffix of the alternative AST class names; only
                used when ``module_name`` is given as well.
            module_name: importable module whose classes are registered as
                alternative AST classes.
        """
        use_custom_classes = bool(module_suffix and module_name)
        self.module_suffix = module_suffix if use_custom_classes else None
        registry = STIXPatternVisitorForSTIX2.classes
        # The registry is filled only while it is still empty, so the first
        # module that populates it stays in effect for later instances.
        if use_custom_classes and not registry:
            module = importlib.import_module(module_name)
            registry.update(inspect.getmembers(module, inspect.isclass))
        self.parser_class = STIXPatternParser21
        super(STIXPatternVisitor21, self).__init__()
# The AST-building mixin must be listed before the generated visitor: Python
# resolves methods through the bases left to right, and with the generated
# visitor first its own methods (and those of its base classes) would be found
# before the mixin's visit*/visitTerminal/aggregateResult overrides.  This now
# matches the base order of the STIX 2.1 visitor.
class STIXPatternVisitorForSTIX20(STIXPatternVisitorForSTIX2, STIXPatternVisitor20):
    """AST-building visitor bound to the STIX 2.0 pattern parser."""

    classes = {}

    def __init__(self, module_suffix, module_name):
        """Configure the visitor.

        Args:
            module_suffix: suffix of the alternative AST class names; only
                used when ``module_name`` is given as well.
            module_name: importable module whose classes are registered as
                alternative AST classes.
        """
        if module_suffix and module_name:
            self.module_suffix = module_suffix
            # The registry is filled only while it is still empty, so the
            # first module that populates it stays in effect afterwards.
            if not STIXPatternVisitorForSTIX2.classes:
                module = importlib.import_module(module_name)
                for k, c in inspect.getmembers(module, inspect.isclass):
                    STIXPatternVisitorForSTIX2.classes[k] = c
        else:
            self.module_suffix = None
        self.parser_class = STIXPatternParser20
        # Start the initializer chain after the generated visitor.
        super(STIXPatternVisitor20, self).__init__()
def create_pattern_object(pattern, module_suffix="", module_name="", version="2.1"):
    """
    Create a STIX pattern AST from a pattern string.

    :param pattern: the pattern text to parse
    :param module_suffix: suffix of alternative AST class names, forwarded to
        the visitor
    :param module_name: module providing those alternative classes, forwarded
        to the visitor
    :param version: "2.1" selects the STIX 2.1 parser and visitor; any other
        value selects the STIX 2.0 ones
    :return: whatever the pattern object's ``visit`` returns for the visitor
    """
    if version == "2.1":
        pattern_obj = Pattern21(pattern)
        builder = STIXPatternVisitorForSTIX21(module_suffix, module_name)
    else:
        pattern_obj = Pattern20(pattern)
        builder = STIXPatternVisitorForSTIX20(module_suffix, module_name)
    return pattern_obj.visit(builder)