diff --git a/stix2/pattern_visitor.py b/stix2/pattern_visitor.py index c4a616b..651cfcf 100644 --- a/stix2/pattern_visitor.py +++ b/stix2/pattern_visitor.py @@ -53,22 +53,32 @@ def same_boolean_operator(current_op, op_token): class STIXPatternVisitorForSTIX2(): - classes = {} + + def __init__(self, parser_class, module_suffix=None, module_name=None): + self.parser_class = parser_class + + if module_suffix and module_name: + self.module_suffix = module_suffix + self.module = importlib.import_module(module_name) + else: + self.module_suffix = self.module = None def get_class(self, class_name): - if class_name in STIXPatternVisitorForSTIX2.classes: - return STIXPatternVisitorForSTIX2.classes[class_name] - else: - return None + klass = None + if self.module: + class_name_suffix = class_name + "For" + self.module_suffix + member = getattr(self.module, class_name_suffix, None) + if member and inspect.isclass(member): + klass = member - def instantiate(self, klass_name, *args): - klass_to_instantiate = None - if self.module_suffix: - klass_to_instantiate = self.get_class(klass_name + "For" + self.module_suffix) - if not klass_to_instantiate: - # use the classes in python_stix2 - klass_to_instantiate = globals()[klass_name] - return klass_to_instantiate(*args) + if not klass: + klass = globals()[class_name] + + return klass + + def instantiate(self, klass_name, *args, **kwargs): + klass_to_instantiate = self.get_class(klass_name) + return klass_to_instantiate(*args, **kwargs) # Visit a parse tree produced by STIXPatternParser#pattern. def visitPattern(self, ctx): @@ -81,7 +91,7 @@ class STIXPatternVisitorForSTIX2(): if len(children) == 1: return children[0] else: - return FollowedByObservationExpression([children[0], children[2]]) + return self.instantiate("FollowedByObservationExpression", [children[0], children[2]]) # Visit a parse tree produced by STIXPatternParser#observationExpressionOr. def visitObservationExpressionOr(self, ctx): @@ -231,17 +241,17 @@ class STIXPatternVisitorForSTIX2(): if not check_for_valid_timetamp_syntax(children[3].value): raise (ValueError("Stop time is not a legal timestamp")) - return StartStopQualifier(children[1], children[3]) + return self.instantiate("StartStopQualifier", children[1], children[3]) # Visit a parse tree produced by STIXPatternParser#withinQualifier. def visitWithinQualifier(self, ctx): children = self.visitChildren(ctx) - return WithinQualifier(children[1]) + return self.instantiate("WithinQualifier", children[1]) # Visit a parse tree produced by STIXPatternParser#repeatedQualifier. def visitRepeatedQualifier(self, ctx): children = self.visitChildren(ctx) - return RepeatQualifier(children[1]) + return self.instantiate("RepeatQualifier", children[1]) # Visit a parse tree produced by STIXPatternParser#objectPath. def visitObjectPath(self, ctx): @@ -326,26 +336,26 @@ class STIXPatternVisitorForSTIX2(): def visitTerminal(self, node): if node.symbol.type == self.parser_class.IntPosLiteral or node.symbol.type == self.parser_class.IntNegLiteral: - return IntegerConstant(node.getText()) + return self.instantiate("IntegerConstant", node.getText()) elif node.symbol.type == self.parser_class.FloatPosLiteral or node.symbol.type == self.parser_class.FloatNegLiteral: - return FloatConstant(node.getText()) + return self.instantiate("FloatConstant", node.getText()) elif node.symbol.type == self.parser_class.HexLiteral: - return HexConstant(node.getText(), from_parse_tree=True) + return self.instantiate("HexConstant", node.getText(), from_parse_tree=True) elif node.symbol.type == self.parser_class.BinaryLiteral: - return BinaryConstant(node.getText(), from_parse_tree=True) + return self.instantiate("BinaryConstant", node.getText(), from_parse_tree=True) elif node.symbol.type == self.parser_class.StringLiteral: if node.getText()[0] == "'" and node.getText()[-1] == "'": - return StringConstant(node.getText()[1:-1], from_parse_tree=True) + return self.instantiate("StringConstant", node.getText()[1:-1], from_parse_tree=True) else: raise ParseException("The pattern does not start and end with a single quote") elif node.symbol.type == self.parser_class.BoolLiteral: - return BooleanConstant(node.getText()) + return self.instantiate("BooleanConstant", node.getText()) elif node.symbol.type == self.parser_class.TimestampLiteral: value = node.getText() # STIX 2.1 uses a special timestamp literal syntax if value.startswith("t"): value = value[2:-1] - return TimestampConstant(value) + return self.instantiate("TimestampConstant", value) else: return node @@ -356,37 +366,14 @@ class STIXPatternVisitorForSTIX2(): aggregate = [nextResult] return aggregate + # This class defines a complete generic visitor for a parse tree produced by STIXPatternParser. class STIXPatternVisitorForSTIX21(STIXPatternVisitorForSTIX2, STIXPatternVisitor21): - classes = {} - - def __init__(self, module_suffix, module_name): - if module_suffix and module_name: - self.module_suffix = module_suffix - if not STIXPatternVisitorForSTIX2.classes: - module = importlib.import_module(module_name) - for k, c in inspect.getmembers(module, inspect.isclass): - STIXPatternVisitorForSTIX2.classes[k] = c - else: - self.module_suffix = None - self.parser_class = STIXPatternParser21 - super(STIXPatternVisitor21, self).__init__() + pass class STIXPatternVisitorForSTIX20(STIXPatternVisitorForSTIX2, STIXPatternVisitor20): - classes = {} - - def __init__(self, module_suffix, module_name): - if module_suffix and module_name: - self.module_suffix = module_suffix - if not STIXPatternVisitorForSTIX2.classes: - module = importlib.import_module(module_name) - for k, c in inspect.getmembers(module, inspect.isclass): - STIXPatternVisitorForSTIX2.classes[k] = c - else: - self.module_suffix = None - self.parser_class = STIXPatternParser20 - super(STIXPatternVisitor20, self).__init__() + pass def create_pattern_object(pattern, module_suffix="", module_name="", version=DEFAULT_VERSION): @@ -397,10 +384,14 @@ def create_pattern_object(pattern, module_suffix="", module_name="", version=DEF if version == "2.1": pattern_class = Pattern21 visitor_class = STIXPatternVisitorForSTIX21 + parser_class = STIXPatternParser21 else: pattern_class = Pattern20 visitor_class = STIXPatternVisitorForSTIX20 + parser_class = STIXPatternParser20 pattern_obj = pattern_class(pattern) - builder = visitor_class(module_suffix, module_name) + builder = visitor_class( + parser_class, module_suffix, module_name, + ) return pattern_obj.visit(builder) diff --git a/stix2/test/v20/pattern_ast_overrides.py b/stix2/test/v20/pattern_ast_overrides.py new file mode 100644 index 0000000..16473ba --- /dev/null +++ b/stix2/test/v20/pattern_ast_overrides.py @@ -0,0 +1,18 @@ +""" +AST node class overrides for testing the pattern AST builder. +""" +from stix2.patterns import ( + EqualityComparisonExpression, StartStopQualifier, StringConstant, +) + + +class EqualityComparisonExpressionForTesting(EqualityComparisonExpression): + pass + + +class StringConstantForTesting(StringConstant): + pass + + +class StartStopQualifierForTesting(StartStopQualifier): + pass diff --git a/stix2/test/v20/test_pattern_expressions.py b/stix2/test/v20/test_pattern_expressions.py index 4d0073a..9579ef4 100644 --- a/stix2/test/v20/test_pattern_expressions.py +++ b/stix2/test/v20/test_pattern_expressions.py @@ -5,8 +5,13 @@ import pytz import stix2 from stix2.pattern_visitor import create_pattern_object +import stix2.patterns import stix2.utils +# flake8 does not approve of star imports. +# flake8: noqa: F405 +from .pattern_ast_overrides import * + def test_create_comparison_expression(): exp = stix2.EqualityComparisonExpression( @@ -587,3 +592,58 @@ def test_parsing_illegal_start_stop_qualified_expression(): def test_list_constant(): patt_obj = create_pattern_object("[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]", version="2.0") assert str(patt_obj) == "[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]" + + +def test_ast_class_override_comp_equals(): + patt_ast = create_pattern_object( + "[a:b=1]", "Testing", "stix2.test.v20.pattern_ast_overrides", + version="2.0", + ) + + assert isinstance(patt_ast, stix2.patterns.ObservationExpression) + assert isinstance(patt_ast.operand, EqualityComparisonExpressionForTesting) + assert str(patt_ast) == "[a:b = 1]" + + +def test_ast_class_override_string_constant(): + patt_ast = create_pattern_object( + "[a:'b'[1].'c' < 'foo']", "Testing", + "stix2.test.v20.pattern_ast_overrides", + version="2.0", + ) + + assert isinstance(patt_ast, stix2.patterns.ObservationExpression) + assert isinstance( + patt_ast.operand, stix2.patterns.LessThanComparisonExpression, + ) + assert isinstance( + patt_ast.operand.lhs.property_path[0].property_name, + str, + ) + assert isinstance( + patt_ast.operand.lhs.property_path[1].property_name, + str, + ) + assert isinstance(patt_ast.operand.rhs, StringConstantForTesting) + + assert str(patt_ast) == "[a:'b'[1].c < 'foo']" + + +def test_ast_class_override_startstop_qualifier(): + patt_ast = create_pattern_object( + "[a:b=1] START '1993-01-20T01:33:52.592Z' STOP '2001-08-19T23:50:23.129Z'", + "Testing", "stix2.test.v20.pattern_ast_overrides", version="2.0", + ) + + assert isinstance(patt_ast, stix2.patterns.QualifiedObservationExpression) + assert isinstance( + patt_ast.observation_expression, stix2.patterns.ObservationExpression, + ) + assert isinstance( + patt_ast.observation_expression.operand, + EqualityComparisonExpressionForTesting, + ) + assert isinstance( + patt_ast.qualifier, StartStopQualifierForTesting, + ) + assert str(patt_ast) == "[a:b = 1] START '1993-01-20T01:33:52.592Z' STOP '2001-08-19T23:50:23.129Z'" diff --git a/stix2/test/v21/pattern_ast_overrides.py b/stix2/test/v21/pattern_ast_overrides.py new file mode 100644 index 0000000..16473ba --- /dev/null +++ b/stix2/test/v21/pattern_ast_overrides.py @@ -0,0 +1,18 @@ +""" +AST node class overrides for testing the pattern AST builder. +""" +from stix2.patterns import ( + EqualityComparisonExpression, StartStopQualifier, StringConstant, +) + + +class EqualityComparisonExpressionForTesting(EqualityComparisonExpression): + pass + + +class StringConstantForTesting(StringConstant): + pass + + +class StartStopQualifierForTesting(StartStopQualifier): + pass diff --git a/stix2/test/v21/test_pattern_expressions.py b/stix2/test/v21/test_pattern_expressions.py index d7afe5c..d171879 100644 --- a/stix2/test/v21/test_pattern_expressions.py +++ b/stix2/test/v21/test_pattern_expressions.py @@ -6,8 +6,13 @@ from stix2patterns.exceptions import ParseException import stix2 from stix2.pattern_visitor import create_pattern_object +import stix2.patterns import stix2.utils +# flake8 does not approve of star imports. +# flake8: noqa: F405 +from .pattern_ast_overrides import * + def test_create_comparison_expression(): exp = stix2.EqualityComparisonExpression( @@ -747,3 +752,58 @@ def test_parsing_multiple_slashes_quotes(): def test_parse_error(): with pytest.raises(ParseException): create_pattern_object("[ file: name = 'weirdname]", version="2.1") + + +def test_ast_class_override_comp_equals(): + patt_ast = create_pattern_object( + "[a:b=1]", "Testing", "stix2.test.v21.pattern_ast_overrides", + version="2.1", + ) + + assert isinstance(patt_ast, stix2.patterns.ObservationExpression) + assert isinstance(patt_ast.operand, EqualityComparisonExpressionForTesting) + assert str(patt_ast) == "[a:b = 1]" + + +def test_ast_class_override_string_constant(): + patt_ast = create_pattern_object( + "[a:'b'[1].'c' < 'foo']", "Testing", + "stix2.test.v21.pattern_ast_overrides", + version="2.1", + ) + + assert isinstance(patt_ast, stix2.patterns.ObservationExpression) + assert isinstance( + patt_ast.operand, stix2.patterns.LessThanComparisonExpression, + ) + assert isinstance( + patt_ast.operand.lhs.property_path[0].property_name, + str, + ) + assert isinstance( + patt_ast.operand.lhs.property_path[1].property_name, + str, + ) + assert isinstance(patt_ast.operand.rhs, StringConstantForTesting) + + assert str(patt_ast) == "[a:'b'[1].c < 'foo']" + + +def test_ast_class_override_startstop_qualifier(): + patt_ast = create_pattern_object( + "[a:b=1] START t'1993-01-20T01:33:52.592Z' STOP t'2001-08-19T23:50:23.129Z'", + "Testing", "stix2.test.v21.pattern_ast_overrides", version="2.1", + ) + + assert isinstance(patt_ast, stix2.patterns.QualifiedObservationExpression) + assert isinstance( + patt_ast.observation_expression, stix2.patterns.ObservationExpression, + ) + assert isinstance( + patt_ast.observation_expression.operand, + EqualityComparisonExpressionForTesting, + ) + assert isinstance( + patt_ast.qualifier, StartStopQualifierForTesting, + ) + assert str(patt_ast) == "[a:b = 1] START t'1993-01-20T01:33:52.592Z' STOP t'2001-08-19T23:50:23.129Z'"