diff --git a/.gitignore b/.gitignore index 9e12d7d..5534a28 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,31 @@ cache.sqlite # PyCharm .idea/ +### macOS template +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + diff --git a/.isort.cfg b/.isort.cfg index ffbe786..d644f60 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -2,6 +2,7 @@ skip = workbench.py not_skip = __init__.py known_third_party = + antlr4, dateutil, medallion, pytest, diff --git a/stix2/pattern_visitor.py b/stix2/pattern_visitor.py new file mode 100644 index 0000000..94c8559 --- /dev/null +++ b/stix2/pattern_visitor.py @@ -0,0 +1,349 @@ +import importlib +import inspect + +from antlr4 import CommonTokenStream, InputStream +import six +from stix2patterns.grammars.STIXPatternLexer import STIXPatternLexer +from stix2patterns.grammars.STIXPatternParser import (STIXPatternParser, + TerminalNode) +from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor +from stix2patterns.validator import STIXPatternErrorListener + +from .patterns import * +from .patterns import _BooleanExpression + +# flake8: noqa F405 + + +def collapse_lists(lists): + result = [] + for c in lists: + if isinstance(c, list): + result.extend(c) + else: + result.append(c) + return result + + +def remove_terminal_nodes(parse_tree_nodes): + values = [] + for x in parse_tree_nodes: + if not isinstance(x, TerminalNode): + values.append(x) + return values + + +# This class defines a complete generic visitor for a parse tree produced by STIXPatternParser. + + +class STIXPatternVisitorForSTIX2(STIXPatternVisitor): + classes = {} + + def __init__(self, module_suffix, module_name): + if module_suffix and module_name: + self.module_suffix = module_suffix + if not STIXPatternVisitorForSTIX2.classes: + module = importlib.import_module(module_name) + for k, c in inspect.getmembers(module, inspect.isclass): + STIXPatternVisitorForSTIX2.classes[k] = c + else: + self.module_suffix = None + super(STIXPatternVisitor, self).__init__() + + def get_class(self, class_name): + if class_name in STIXPatternVisitorForSTIX2.classes: + return STIXPatternVisitorForSTIX2.classes[class_name] + else: + return None + + def instantiate(self, klass_name, *args): + klass_to_instantiate = None + if self.module_suffix: + klass_to_instantiate = self.get_class(klass_name + "For" + self.module_suffix) + if not klass_to_instantiate: + # use the classes in python_stix2 + klass_to_instantiate = globals()[klass_name] + return klass_to_instantiate(*args) + + # Visit a parse tree produced by STIXPatternParser#pattern. + def visitPattern(self, ctx): + children = self.visitChildren(ctx) + return children[0] + + # Visit a parse tree produced by STIXPatternParser#observationExpressions. + def visitObservationExpressions(self, ctx): + children = self.visitChildren(ctx) + if len(children) == 1: + return children[0] + else: + return FollowedByObservationExpression([children[0], children[2]]) + + # Visit a parse tree produced by STIXPatternParser#observationExpressionOr. + def visitObservationExpressionOr(self, ctx): + children = self.visitChildren(ctx) + if len(children) == 1: + return children[0] + else: + return self.instantiate("OrObservationExpression", [children[0], children[2]]) + + # Visit a parse tree produced by STIXPatternParser#observationExpressionAnd. + def visitObservationExpressionAnd(self, ctx): + children = self.visitChildren(ctx) + if len(children) == 1: + return children[0] + else: + return self.instantiate("AndObservationExpression", [children[0], children[2]]) + + # Visit a parse tree produced by STIXPatternParser#observationExpressionRepeated. + def visitObservationExpressionRepeated(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("QualifiedObservationExpression", children[0], children[1]) + + # Visit a parse tree produced by STIXPatternParser#observationExpressionSimple. + def visitObservationExpressionSimple(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("ObservationExpression", children[1]) + + # Visit a parse tree produced by STIXPatternParser#observationExpressionCompound. + def visitObservationExpressionCompound(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("ObservationExpression", children[1]) + + # Visit a parse tree produced by STIXPatternParser#observationExpressionWithin. + def visitObservationExpressionWithin(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("QualifiedObservationExpression", children[0], children[1]) + + # Visit a parse tree produced by STIXPatternParser#observationExpressionStartStop. + def visitObservationExpressionStartStop(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("QualifiedObservationExpression", children[0], children[1]) + + # Visit a parse tree produced by STIXPatternParser#comparisonExpression. + def visitComparisonExpression(self, ctx): + children = self.visitChildren(ctx) + if len(children) == 1: + return children[0] + else: + if isinstance(children[0], _BooleanExpression): + children[0].operands.append(children[2]) + return children[0] + else: + return self.instantiate("OrBooleanExpression", [children[0], children[2]]) + + # Visit a parse tree produced by STIXPatternParser#comparisonExpressionAnd. + def visitComparisonExpressionAnd(self, ctx): + # TODO: NOT + children = self.visitChildren(ctx) + if len(children) == 1: + return children[0] + else: + if isinstance(children[0], _BooleanExpression): + children[0].operands.append(children[2]) + return children[0] + else: + return self.instantiate("AndBooleanExpression", [children[0], children[2]]) + + # Visit a parse tree produced by STIXPatternParser#propTestEqual. + def visitPropTestEqual(self, ctx): + children = self.visitChildren(ctx) + operator = children[1].symbol.type + negated = operator != STIXPatternParser.EQ + return self.instantiate("EqualityComparisonExpression", children[0], children[3 if len(children) > 3 else 2], + negated) + + # Visit a parse tree produced by STIXPatternParser#propTestOrder. + def visitPropTestOrder(self, ctx): + children = self.visitChildren(ctx) + operator = children[1].symbol.type + if operator == STIXPatternParser.GT: + return self.instantiate("GreaterThanComparisonExpression", children[0], + children[3 if len(children) > 3 else 2], False) + elif operator == STIXPatternParser.LT: + return self.instantiate("LessThanComparisonExpression", children[0], + children[3 if len(children) > 3 else 2], False) + elif operator == STIXPatternParser.GE: + return self.instantiate("GreaterThanEqualComparisonExpression", children[0], + children[3 if len(children) > 3 else 2], False) + elif operator == STIXPatternParser.LE: + return self.instantiate("LessThanEqualComparisonExpression", children[0], + children[3 if len(children) > 3 else 2], False) + + # Visit a parse tree produced by STIXPatternParser#propTestSet. + def visitPropTestSet(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("InComparisonExpression", children[0], children[3 if len(children) > 3 else 2], False) + + # Visit a parse tree produced by STIXPatternParser#propTestLike. + def visitPropTestLike(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("LikeComparisonExpression", children[0], children[3 if len(children) > 3 else 2], False) + + # Visit a parse tree produced by STIXPatternParser#propTestRegex. + def visitPropTestRegex(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("MatchesComparisonExpression", children[0], children[3 if len(children) > 3 else 2], + False) + + # Visit a parse tree produced by STIXPatternParser#propTestIsSubset. + def visitPropTestIsSubset(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("IsSubsetComparisonExpression", children[0], children[3 if len(children) > 3 else 2]) + + # Visit a parse tree produced by STIXPatternParser#propTestIsSuperset. + def visitPropTestIsSuperset(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("IsSupersetComparisonExpression", children[0], children[3 if len(children) > 3 else 2]) + + # Visit a parse tree produced by STIXPatternParser#propTestParen. + def visitPropTestParen(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("ParentheticalExpression", children[1]) + + # Visit a parse tree produced by STIXPatternParser#startStopQualifier. + def visitStartStopQualifier(self, ctx): + children = self.visitChildren(ctx) + return StartStopQualifier(children[1], children[3]) + + # Visit a parse tree produced by STIXPatternParser#withinQualifier. + def visitWithinQualifier(self, ctx): + children = self.visitChildren(ctx) + return WithinQualifier(children[1]) + + # Visit a parse tree produced by STIXPatternParser#repeatedQualifier. + def visitRepeatedQualifier(self, ctx): + children = self.visitChildren(ctx) + return RepeatQualifier(children[1]) + + # Visit a parse tree produced by STIXPatternParser#objectPath. + def visitObjectPath(self, ctx): + children = self.visitChildren(ctx) + flat_list = collapse_lists(children[2:]) + property_path = [] + i = 0 + while i < len(flat_list): + current = flat_list[i] + if i == len(flat_list)-1: + property_path.append(current) + break + next = flat_list[i+1] + if isinstance(next, TerminalNode): + property_path.append(self.instantiate("ListObjectPathComponent", current.property_name, next.getText())) + i += 2 + else: + property_path.append(current) + i += 1 + return self.instantiate("ObjectPath", children[0].getText(), property_path) + + # Visit a parse tree produced by STIXPatternParser#objectType. + def visitObjectType(self, ctx): + children = self.visitChildren(ctx) + return children[0] + + # Visit a parse tree produced by STIXPatternParser#firstPathComponent. + def visitFirstPathComponent(self, ctx): + children = self.visitChildren(ctx) + step = children[0].getText() + # if step.endswith("_ref"): + # return stix2.ReferenceObjectPathComponent(step) + # else: + return self.instantiate("BasicObjectPathComponent", step, False) + + # Visit a parse tree produced by STIXPatternParser#indexPathStep. + def visitIndexPathStep(self, ctx): + children = self.visitChildren(ctx) + return children[1] + + # Visit a parse tree produced by STIXPatternParser#pathStep. + def visitPathStep(self, ctx): + return collapse_lists(self.visitChildren(ctx)) + + # Visit a parse tree produced by STIXPatternParser#keyPathStep. + def visitKeyPathStep(self, ctx): + children = self.visitChildren(ctx) + if isinstance(children[1], StringConstant): + # special case for hashes + return children[1].value + else: + return self.instantiate("BasicObjectPathComponent", children[1].getText(), True) + + # Visit a parse tree produced by STIXPatternParser#setLiteral. + def visitSetLiteral(self, ctx): + children = self.visitChildren(ctx) + return self.instantiate("ListConstant", remove_terminal_nodes(children)) + + # Visit a parse tree produced by STIXPatternParser#primitiveLiteral. + def visitPrimitiveLiteral(self, ctx): + children = self.visitChildren(ctx) + return children[0] + + # Visit a parse tree produced by STIXPatternParser#orderableLiteral. + def visitOrderableLiteral(self, ctx): + children = self.visitChildren(ctx) + return children[0] + + def visitTerminal(self, node): + if node.symbol.type == STIXPatternParser.IntPosLiteral or node.symbol.type == STIXPatternParser.IntNegLiteral: + return IntegerConstant(node.getText()) + elif node.symbol.type == STIXPatternParser.FloatPosLiteral or node.symbol.type == STIXPatternParser.FloatNegLiteral: + return FloatConstant(node.getText()) + elif node.symbol.type == STIXPatternParser.HexLiteral: + return HexConstant(node.getText(), from_parse_tree=True) + elif node.symbol.type == STIXPatternParser.BinaryLiteral: + return BinaryConstant(node.getText(), from_parse_tree=True) + elif node.symbol.type == STIXPatternParser.StringLiteral: + return StringConstant(node.getText().strip('\''), from_parse_tree=True) + elif node.symbol.type == STIXPatternParser.BoolLiteral: + return BooleanConstant(node.getText()) + elif node.symbol.type == STIXPatternParser.TimestampLiteral: + return TimestampConstant(node.getText()) + else: + return node + + def aggregateResult(self, aggregate, nextResult): + if aggregate: + aggregate.append(nextResult) + elif nextResult: + aggregate = [nextResult] + return aggregate + + +def create_pattern_object(pattern, module_suffix="", module_name=""): + """ + Validates a pattern against the STIX Pattern grammar. Error messages are + returned in a list. The test passed if the returned list is empty. + """ + + start = '' + if isinstance(pattern, six.string_types): + start = pattern[:2] + pattern = InputStream(pattern) + + if not start: + start = pattern.readline()[:2] + pattern.seek(0) + + parseErrListener = STIXPatternErrorListener() + + lexer = STIXPatternLexer(pattern) + # it always adds a console listener by default... remove it. + lexer.removeErrorListeners() + + stream = CommonTokenStream(lexer) + + parser = STIXPatternParser(stream) + parser.buildParseTrees = True + # it always adds a console listener by default... remove it. + parser.removeErrorListeners() + parser.addErrorListener(parseErrListener) + + # To improve error messages, replace "" in the literal + # names with symbolic names. This is a hack, but seemed like + # the simplest workaround. + for i, lit_name in enumerate(parser.literalNames): + if lit_name == u"": + parser.literalNames[i] = parser.symbolicNames[i] + + tree = parser.pattern() + builder = STIXPatternVisitorForSTIX2(module_suffix, module_name) + return builder.visit(tree) diff --git a/stix2/patterns.py b/stix2/patterns.py index 59528bd..9656ff1 100644 --- a/stix2/patterns.py +++ b/stix2/patterns.py @@ -5,6 +5,8 @@ import binascii import datetime import re +import six + from .utils import parse_into_datetime @@ -12,6 +14,14 @@ def escape_quotes_and_backslashes(s): return s.replace(u'\\', u'\\\\').replace(u"'", u"\\'") +def quote_if_needed(x): + if isinstance(x, six.string_types): + if x.find("-") != -1: + if not x.startswith("'"): + return "'" + x + "'" + return x + + class _Constant(object): pass @@ -22,11 +32,13 @@ class StringConstant(_Constant): Args: value (str): string value """ - def __init__(self, value): + + def __init__(self, value, from_parse_tree=False): + self.needs_to_be_quoted = not from_parse_tree self.value = value def __str__(self): - return "'%s'" % escape_quotes_and_backslashes(self.value) + return "'%s'" % (escape_quotes_and_backslashes(self.value) if self.needs_to_be_quoted else self.value) class TimestampConstant(_Constant): @@ -151,7 +163,13 @@ class BinaryConstant(_Constant): Args: value (str): base64 encoded string value """ - def __init__(self, value): + + def __init__(self, value, from_parse_tree=False): + # support with or without a 'b' + if from_parse_tree: + m = re.match("^b'(.+)'$", value) + if m: + value = m.group(1) try: base64.b64decode(value) self.value = value @@ -168,10 +186,16 @@ class HexConstant(_Constant): Args: value (str): hexadecimal value """ - def __init__(self, value): - if not re.match(r'^([a-fA-F0-9]{2})+$', value): - raise ValueError("must contain an even number of hexadecimal characters") - self.value = value + def __init__(self, value, from_parse_tree=False): + # support with or without an 'h' + if not from_parse_tree and re.match('^([a-fA-F0-9]{2})+$', value): + self.value = value + else: + m = re.match("^h'(([a-fA-F0-9]{2})+)'$", value) + if m: + self.value = m.group(1) + else: + raise ValueError("must contain an even number of hexadecimal characters") def __str__(self): return "h'%s'" % self.value @@ -184,10 +208,11 @@ class ListConstant(_Constant): value (list): list of values """ def __init__(self, values): - self.value = values + # handle _Constants or make a _Constant + self.value = [x if isinstance(x, _Constant) else make_constant(x) for x in values] def __str__(self): - return "(" + ", ".join([("%s" % make_constant(x)) for x in self.value]) + ")" + return "(" + ", ".join(["%s" % x for x in self.value]) + ")" def make_constant(value): @@ -228,7 +253,10 @@ class _ObjectPathComponent(object): parse1 = component_name.split("[") return ListObjectPathComponent(parse1[0], parse1[1][:-1]) else: - return BasicObjectPathComponent(component_name) + return BasicObjectPathComponent(component_name, False) + + def __str__(self): + return quote_if_needed(self.property_name) class BasicObjectPathComponent(_ObjectPathComponent): @@ -242,14 +270,11 @@ class BasicObjectPathComponent(_ObjectPathComponent): property_name (str): object property name is_key (bool): is dictionary key, default: False """ - def __init__(self, property_name, is_key=False): + def __init__(self, property_name, is_key): self.property_name = property_name # TODO: set is_key to True if this component is a dictionary key # self.is_key = is_key - def __str__(self): - return self.property_name - class ListObjectPathComponent(_ObjectPathComponent): """List object path component (for an observation or expression) @@ -263,7 +288,7 @@ class ListObjectPathComponent(_ObjectPathComponent): self.index = index def __str__(self): - return "%s[%s]" % (self.property_name, self.index) + return "%s[%s]" % (quote_if_needed(self.property_name), self.index) class ReferenceObjectPathComponent(_ObjectPathComponent): @@ -275,9 +300,6 @@ class ReferenceObjectPathComponent(_ObjectPathComponent): def __init__(self, reference_property_name): self.property_name = reference_property_name - def __str__(self): - return self.property_name - class ObjectPath(object): """Pattern operand object (property) path @@ -295,7 +317,7 @@ class ObjectPath(object): ] def __str__(self): - return "%s:%s" % (self.object_type_name, ".".join(["%s" % x for x in self.property_path])) + return "%s:%s" % (self.object_type_name, ".".join(["%s" % quote_if_needed(x) for x in self.property_path])) def merge(self, other): """Extend the object property with that of the supplied object property path"""