diff --git a/stix2/pattern_visitor.py b/stix2/pattern_visitor.py index 6ac3e98..51e309f 100644 --- a/stix2/pattern_visitor.py +++ b/stix2/pattern_visitor.py @@ -2,11 +2,16 @@ import importlib import inspect from stix2patterns.exceptions import ParseException -from stix2patterns.grammars.STIXPatternParser import ( - STIXPatternParser, TerminalNode, -) +from stix2patterns.grammars.STIXPatternParser import TerminalNode + +from stix2patterns.v20.grammars.STIXPatternParser import STIXPatternParser as STIXPatternParser20 +from stix2patterns.v21.grammars.STIXPatternParser import STIXPatternParser as STIXPatternParser21 + +from stix2patterns.v20.grammars.STIXPatternVisitor import STIXPatternVisitor as STIXPatternVisitor20 +from stix2patterns.v21.grammars.STIXPatternVisitor import STIXPatternVisitor as STIXPatternVisitor21 from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor -from stix2patterns.v20.pattern import Pattern +from stix2patterns.v20.pattern import Pattern as Pattern20 +from stix2patterns.v21.pattern import Pattern as Pattern21 from .patterns import * from .patterns import _BooleanExpression @@ -32,23 +37,12 @@ def remove_terminal_nodes(parse_tree_nodes): return values -# This class defines a complete generic visitor for a parse tree produced by STIXPatternParser. -class STIXPatternVisitorForSTIX2(STIXPatternVisitor): + +class STIXPatternVisitorForSTIX2(): classes = {} - def __init__(self, module_suffix, module_name): - if module_suffix and module_name: - self.module_suffix = module_suffix - if not STIXPatternVisitorForSTIX2.classes: - module = importlib.import_module(module_name) - for k, c in inspect.getmembers(module, inspect.isclass): - STIXPatternVisitorForSTIX2.classes[k] = c - else: - self.module_suffix = None - super(STIXPatternVisitor, self).__init__() - def get_class(self, class_name): if class_name in STIXPatternVisitorForSTIX2.classes: return STIXPatternVisitorForSTIX2.classes[class_name] @@ -147,7 +141,7 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor): def visitPropTestEqual(self, ctx): children = self.visitChildren(ctx) operator = children[1].symbol.type - negated = operator != STIXPatternParser.EQ + negated = operator != self.parser_class.EQ return self.instantiate( "EqualityComparisonExpression", children[0], children[3 if len(children) > 3 else 2], negated, @@ -157,22 +151,22 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor): def visitPropTestOrder(self, ctx): children = self.visitChildren(ctx) operator = children[1].symbol.type - if operator == STIXPatternParser.GT: + if operator == self.parser_class.GT: return self.instantiate( "GreaterThanComparisonExpression", children[0], children[3 if len(children) > 3 else 2], False, ) - elif operator == STIXPatternParser.LT: + elif operator == self.parser_class.LT: return self.instantiate( "LessThanComparisonExpression", children[0], children[3 if len(children) > 3 else 2], False, ) - elif operator == STIXPatternParser.GE: + elif operator == self.parser_class.GE: return self.instantiate( "GreaterThanEqualComparisonExpression", children[0], children[3 if len(children) > 3 else 2], False, ) - elif operator == STIXPatternParser.LE: + elif operator == self.parser_class.LE: return self.instantiate( "LessThanEqualComparisonExpression", children[0], children[3 if len(children) > 3 else 2], False, @@ -294,22 +288,22 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor): return children[0] def visitTerminal(self, node): - if node.symbol.type == STIXPatternParser.IntPosLiteral or node.symbol.type == STIXPatternParser.IntNegLiteral: + if node.symbol.type == self.parser_class.IntPosLiteral or node.symbol.type == self.parser_class.IntNegLiteral: return IntegerConstant(node.getText()) - elif node.symbol.type == STIXPatternParser.FloatPosLiteral or node.symbol.type == STIXPatternParser.FloatNegLiteral: + elif node.symbol.type == self.parser_class.FloatPosLiteral or node.symbol.type == self.parser_class.FloatNegLiteral: return FloatConstant(node.getText()) - elif node.symbol.type == STIXPatternParser.HexLiteral: + elif node.symbol.type == self.parser_class.HexLiteral: return HexConstant(node.getText(), from_parse_tree=True) - elif node.symbol.type == STIXPatternParser.BinaryLiteral: + elif node.symbol.type == self.parser_class.BinaryLiteral: return BinaryConstant(node.getText(), from_parse_tree=True) - elif node.symbol.type == STIXPatternParser.StringLiteral: + elif node.symbol.type == self.parser_class.StringLiteral: if node.getText()[0] == "'" and node.getText()[-1] == "'": return StringConstant(node.getText()[1:-1], from_parse_tree=True) else: raise ParseException("The pattern does not start and end with a single quote") - elif node.symbol.type == STIXPatternParser.BoolLiteral: + elif node.symbol.type == self.parser_class.BoolLiteral: return BooleanConstant(node.getText()) - elif node.symbol.type == STIXPatternParser.TimestampLiteral: + elif node.symbol.type == self.parser_class.TimestampLiteral: return TimestampConstant(node.getText()) else: return node @@ -321,12 +315,44 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor): aggregate = [nextResult] return aggregate +# This class defines a complete generic visitor for a parse tree produced by STIXPatternParser. +class STIXPatternVisitorForSTIX21(STIXPatternVisitorForSTIX2, STIXPatternVisitor21): + classes = {} -def create_pattern_object(pattern, module_suffix="", module_name=""): + def __init__(self, module_suffix, module_name): + if module_suffix and module_name: + self.module_suffix = module_suffix + if not STIXPatternVisitorForSTIX2.classes: + module = importlib.import_module(module_name) + for k, c in inspect.getmembers(module, inspect.isclass): + STIXPatternVisitorForSTIX2.classes[k] = c + else: + self.module_suffix = None + self.parser_class = STIXPatternParser21 + super(STIXPatternVisitor21, self).__init__() + + +class STIXPatternVisitorForSTIX20(STIXPatternVisitor20, STIXPatternVisitorForSTIX2): + classes = {} + + def __init__(self, module_suffix, module_name): + if module_suffix and module_name: + self.module_suffix = module_suffix + if not STIXPatternVisitorForSTIX2.classes: + module = importlib.import_module(module_name) + for k, c in inspect.getmembers(module, inspect.isclass): + STIXPatternVisitorForSTIX2.classes[k] = c + else: + self.module_suffix = None + self.parser_class = STIXPatternParser20 + super(STIXPatternVisitor20, self).__init__() + + +def create_pattern_object(pattern, module_suffix="", module_name="", version="2.1"): """ Create a STIX pattern AST from a pattern string. """ - pattern_obj = Pattern(pattern) - builder = STIXPatternVisitorForSTIX2(module_suffix, module_name) + pattern_obj = Pattern21(pattern) if version == "2.1" else Pattern20(pattern) + builder = STIXPatternVisitorForSTIX21(module_suffix, module_name) if version == "2.1" else STIXPatternVisitorForSTIX20(module_suffix, module_name) return pattern_obj.visit(builder) diff --git a/stix2/test/v21/test_pattern_expressions.py b/stix2/test/v21/test_pattern_expressions.py index 0c298f8..502d7d3 100644 --- a/stix2/test/v21/test_pattern_expressions.py +++ b/stix2/test/v21/test_pattern_expressions.py @@ -175,20 +175,34 @@ def test_greater_than(): assert str(exp) == "[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.0]" +def test_parsing_greater_than(): + patt_obj = create_pattern_object("[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.478901]") + assert str(patt_obj) == "[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.478901]" + + def test_less_than(): exp = stix2.LessThanComparisonExpression("file:size", 1024) assert str(exp) == "file:size < 1024" +def test_parsing_less_than(): + patt_obj = create_pattern_object("[file:size < 1024]") + assert str(patt_obj) == "[file:size < 1024]" + + def test_greater_than_or_equal(): exp = stix2.GreaterThanEqualComparisonExpression( "file:size", 1024, ) - assert str(exp) == "file:size >= 1024" +def test_parsing_greater_than_or_equal(): + patt_obj = create_pattern_object("[file:size >= 1024]") + assert str(patt_obj) == "[file:size >= 1024]" + + def test_less_than_or_equal(): exp = stix2.LessThanEqualComparisonExpression( "file:size", @@ -197,6 +211,11 @@ def test_less_than_or_equal(): assert str(exp) == "file:size <= 1024" +def test_parsing_less_than_or_equal(): + patt_obj = create_pattern_object("[file:size <= 1024]") + assert str(patt_obj) == "[file:size <= 1024]" + + def test_not(): exp = stix2.LessThanComparisonExpression( "file:size", @@ -257,6 +276,67 @@ def test_and_observable_expression(): assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1009' AND user-account:account_login = 'Mary']" # noqa +def test_parsing_and_observable_expression(): + exp = create_pattern_object("[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']") + assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']" + + +def test_or_observable_expression(): + exp1 = stix2.AndBooleanExpression([ + stix2.EqualityComparisonExpression( + "user-account:account_type", + "unix", + ), + stix2.EqualityComparisonExpression( + "user-account:user_id", + stix2.StringConstant("1007"), + ), + stix2.EqualityComparisonExpression( + "user-account:account_login", + "Peter", + ), + ]) + exp2 = stix2.AndBooleanExpression([ + stix2.EqualityComparisonExpression( + "user-account:account_type", + "unix", + ), + stix2.EqualityComparisonExpression( + "user-account:user_id", + stix2.StringConstant("1008"), + ), + stix2.EqualityComparisonExpression( + "user-account:account_login", + "Paul", + ), + ]) + exp3 = stix2.AndBooleanExpression([ + stix2.EqualityComparisonExpression( + "user-account:account_type", + "unix", + ), + stix2.EqualityComparisonExpression( + "user-account:user_id", + stix2.StringConstant("1009"), + ), + stix2.EqualityComparisonExpression( + "user-account:account_login", + "Mary", + ), + ]) + exp = stix2.OrObservationExpression([ + stix2.ObservationExpression(exp1), + stix2.ObservationExpression(exp2), + stix2.ObservationExpression(exp3), + ]) + assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1009' AND user-account:account_login = 'Mary']" # noqa + + +def test_parsing_or_observable_expression(): + exp = create_pattern_object("[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']") # noqa + assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']" # noqa + + def test_invalid_and_observable_expression(): with pytest.raises(ValueError): stix2.AndBooleanExpression([ @@ -286,6 +366,11 @@ def test_hex(): assert str(exp) == "[file:mime_type = 'image/bmp' AND file:magic_number_hex = h'ffd8']" +def test_parsing_hex(): + patt_obj = create_pattern_object("[file:magic_number_hex = h'ffd8']") + assert str(patt_obj) == "[file:magic_number_hex = h'ffd8']" + + def test_multiple_qualifiers(): exp_and = stix2.AndBooleanExpression([ stix2.EqualityComparisonExpression( @@ -334,6 +419,11 @@ def test_binary(): assert str(exp) == "artifact:payload_bin = b'dGhpcyBpcyBhIHRlc3Q='" +def test_parsing_binary(): + patt_obj = create_pattern_object("[artifact:payload_bin = b'dGhpcyBpcyBhIHRlc3Q=']") + assert str(patt_obj) == "[artifact:payload_bin = b'dGhpcyBpcyBhIHRlc3Q=']" + + def test_list(): exp = stix2.InComparisonExpression( "process:name", @@ -499,7 +589,7 @@ def test_parsing_comparison_expression(): assert str(patt_obj) == "[file:hashes.'SHA-256' = 'aec070645fe53ee3b3763059376134f058cc337247c978add178b6ccdfb0019f']" -def test_parsing_qualified_expression(): +def test_parsing_repeat_and_within_qualified_expression(): patt_obj = create_pattern_object( "[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] REPEATS 5 TIMES WITHIN 1800 SECONDS", ) @@ -508,11 +598,25 @@ def test_parsing_qualified_expression(): ) == "[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] REPEATS 5 TIMES WITHIN 1800 SECONDS" +def test_parsing_start_stop_qualified_expression(): + patt_obj = create_pattern_object( + "[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] START t'2016-06-01T00:00:00Z' STOP t'2017-03-12T08:30:00Z'", + ) + assert str( + patt_obj, + ) == "[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] START t'2016-06-01T00:00:00Z' STOP t'2017-03-12T08:30:00Z'" + + def test_list_constant(): patt_obj = create_pattern_object("[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]") assert str(patt_obj) == "[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]" +def test_parsing_boolean(): + patt_obj = create_pattern_object("[network-traffic:is_active = true]") + assert str(patt_obj) == "[network-traffic:is_active = true]" + + def test_parsing_multiple_slashes_quotes(): patt_obj = create_pattern_object("[ file:name = 'weird_name\\'' ]") assert str(patt_obj) == "[file:name = 'weird_name\\'']"