diff --git a/stix2/pattern_visitor.py b/stix2/pattern_visitor.py index 36d0c84..42f28b7 100644 --- a/stix2/pattern_visitor.py +++ b/stix2/pattern_visitor.py @@ -53,22 +53,32 @@ def same_boolean_operator(current_op, op_token): class STIXPatternVisitorForSTIX2(): - classes = {} + + def __init__(self, parser_class, module_suffix=None, module_name=None): + self.parser_class = parser_class + + if module_suffix and module_name: + self.module_suffix = module_suffix + self.module = importlib.import_module(module_name) + else: + self.module_suffix = self.module = None def get_class(self, class_name): - if class_name in STIXPatternVisitorForSTIX2.classes: - return STIXPatternVisitorForSTIX2.classes[class_name] - else: - return None + klass = None + if self.module: + class_name_suffix = class_name + "For" + self.module_suffix + member = getattr(self.module, class_name_suffix, None) + if member and inspect.isclass(member): + klass = member - def instantiate(self, klass_name, *args): - klass_to_instantiate = None - if self.module_suffix: - klass_to_instantiate = self.get_class(klass_name + "For" + self.module_suffix) - if not klass_to_instantiate: - # use the classes in python_stix2 - klass_to_instantiate = globals()[klass_name] - return klass_to_instantiate(*args) + if not klass: + klass = globals()[class_name] + + return klass + + def instantiate(self, klass_name, *args, **kwargs): + klass_to_instantiate = self.get_class(klass_name) + return klass_to_instantiate(*args, **kwargs) # Visit a parse tree produced by STIXPatternParser#pattern. def visitPattern(self, ctx): @@ -330,12 +340,12 @@ class STIXPatternVisitorForSTIX2(): elif node.symbol.type == self.parser_class.FloatPosLiteral or node.symbol.type == self.parser_class.FloatNegLiteral: return self.instantiate("FloatConstant", node.getText()) elif node.symbol.type == self.parser_class.HexLiteral: - return HexConstant(node.getText(), from_parse_tree=True) + return self.instantiate("HexConstant", node.getText(), from_parse_tree=True) elif node.symbol.type == self.parser_class.BinaryLiteral: - return BinaryConstant(node.getText(), from_parse_tree=True) + return self.instantiate("BinaryConstant", node.getText(), from_parse_tree=True) elif node.symbol.type == self.parser_class.StringLiteral: if node.getText()[0] == "'" and node.getText()[-1] == "'": - return StringConstant(node.getText()[1:-1], from_parse_tree=True) + return self.instantiate("StringConstant", node.getText()[1:-1], from_parse_tree=True) else: raise ParseException("The pattern does not start and end with a single quote") elif node.symbol.type == self.parser_class.BoolLiteral: @@ -356,37 +366,14 @@ class STIXPatternVisitorForSTIX2(): aggregate = [nextResult] return aggregate + # This class defines a complete generic visitor for a parse tree produced by STIXPatternParser. class STIXPatternVisitorForSTIX21(STIXPatternVisitorForSTIX2, STIXPatternVisitor21): - classes = {} - - def __init__(self, module_suffix, module_name): - if module_suffix and module_name: - self.module_suffix = module_suffix - if not STIXPatternVisitorForSTIX2.classes: - module = importlib.import_module(module_name) - for k, c in inspect.getmembers(module, inspect.isclass): - STIXPatternVisitorForSTIX2.classes[k] = c - else: - self.module_suffix = None - self.parser_class = STIXPatternParser21 - super(STIXPatternVisitor21, self).__init__() + pass class STIXPatternVisitorForSTIX20(STIXPatternVisitorForSTIX2, STIXPatternVisitor20): - classes = {} - - def __init__(self, module_suffix, module_name): - if module_suffix and module_name: - self.module_suffix = module_suffix - if not STIXPatternVisitorForSTIX2.classes: - module = importlib.import_module(module_name) - for k, c in inspect.getmembers(module, inspect.isclass): - STIXPatternVisitorForSTIX2.classes[k] = c - else: - self.module_suffix = None - self.parser_class = STIXPatternParser20 - super(STIXPatternVisitor20, self).__init__() + pass def create_pattern_object(pattern, module_suffix="", module_name="", version=DEFAULT_VERSION): @@ -397,10 +384,14 @@ def create_pattern_object(pattern, module_suffix="", module_name="", version=DEF if version == "2.1": pattern_class = Pattern21 visitor_class = STIXPatternVisitorForSTIX21 + parser_class = STIXPatternParser21 else: pattern_class = Pattern20 visitor_class = STIXPatternVisitorForSTIX20 + parser_class = STIXPatternParser20 pattern_obj = pattern_class(pattern) - builder = visitor_class(module_suffix, module_name) + builder = visitor_class( + parser_class, module_suffix, module_name + ) return pattern_obj.visit(builder) diff --git a/stix2/test/v20/pattern_ast_overrides.py b/stix2/test/v20/pattern_ast_overrides.py new file mode 100644 index 0000000..0747c2a --- /dev/null +++ b/stix2/test/v20/pattern_ast_overrides.py @@ -0,0 +1,21 @@ +""" +AST node class overrides for testing the pattern AST builder. +""" +from stix2.patterns import ( + EqualityComparisonExpression, + StartStopQualifier, + StringConstant +) + + +class EqualityComparisonExpressionForTesting(EqualityComparisonExpression): + pass + + +class StringConstantForTesting(StringConstant): + pass + + +class StartStopQualifierForTesting(StartStopQualifier): + pass + diff --git a/stix2/test/v20/test_pattern_expressions.py b/stix2/test/v20/test_pattern_expressions.py index 4d0073a..0af541b 100644 --- a/stix2/test/v20/test_pattern_expressions.py +++ b/stix2/test/v20/test_pattern_expressions.py @@ -6,6 +6,8 @@ import pytz import stix2 from stix2.pattern_visitor import create_pattern_object import stix2.utils +import stix2.patterns +from .pattern_ast_overrides import * def test_create_comparison_expression(): @@ -587,3 +589,58 @@ def test_parsing_illegal_start_stop_qualified_expression(): def test_list_constant(): patt_obj = create_pattern_object("[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]", version="2.0") assert str(patt_obj) == "[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]" + + +def test_ast_class_override_comp_equals(): + patt_ast = create_pattern_object( + "[a:b=1]", "Testing", "stix2.test.v20.pattern_ast_overrides", + version="2.0" + ) + + assert isinstance(patt_ast, stix2.patterns.ObservationExpression) + assert isinstance(patt_ast.operand, EqualityComparisonExpressionForTesting) + assert str(patt_ast) == "[a:b = 1]" + + +def test_ast_class_override_string_constant(): + patt_ast = create_pattern_object( + "[a:'b'[1].'c' < 'foo']", "Testing", + "stix2.test.v20.pattern_ast_overrides", + version="2.0" + ) + + assert isinstance(patt_ast, stix2.patterns.ObservationExpression) + assert isinstance( + patt_ast.operand, stix2.patterns.LessThanComparisonExpression + ) + assert isinstance( + patt_ast.operand.lhs.property_path[0].property_name, + str + ) + assert isinstance( + patt_ast.operand.lhs.property_path[1].property_name, + str + ) + assert isinstance(patt_ast.operand.rhs, StringConstantForTesting) + + assert str(patt_ast) == "[a:'b'[1].c < 'foo']" + + +def test_ast_class_override_startstop_qualifier(): + patt_ast = create_pattern_object( + "[a:b=1] START '1993-01-20T01:33:52.592Z' STOP '2001-08-19T23:50:23.129Z'", + "Testing", "stix2.test.v20.pattern_ast_overrides", version="2.0" + ) + + assert isinstance(patt_ast, stix2.patterns.QualifiedObservationExpression) + assert isinstance( + patt_ast.observation_expression, stix2.patterns.ObservationExpression + ) + assert isinstance( + patt_ast.observation_expression.operand, + EqualityComparisonExpressionForTesting + ) + assert isinstance( + patt_ast.qualifier, StartStopQualifierForTesting + ) + assert str(patt_ast) == "[a:b = 1] START '1993-01-20T01:33:52.592Z' STOP '2001-08-19T23:50:23.129Z'" diff --git a/stix2/test/v21/pattern_ast_overrides.py b/stix2/test/v21/pattern_ast_overrides.py new file mode 100644 index 0000000..0747c2a --- /dev/null +++ b/stix2/test/v21/pattern_ast_overrides.py @@ -0,0 +1,21 @@ +""" +AST node class overrides for testing the pattern AST builder. +""" +from stix2.patterns import ( + EqualityComparisonExpression, + StartStopQualifier, + StringConstant +) + + +class EqualityComparisonExpressionForTesting(EqualityComparisonExpression): + pass + + +class StringConstantForTesting(StringConstant): + pass + + +class StartStopQualifierForTesting(StartStopQualifier): + pass + diff --git a/stix2/test/v21/test_pattern_expressions.py b/stix2/test/v21/test_pattern_expressions.py index d7afe5c..feea22e 100644 --- a/stix2/test/v21/test_pattern_expressions.py +++ b/stix2/test/v21/test_pattern_expressions.py @@ -7,6 +7,8 @@ from stix2patterns.exceptions import ParseException import stix2 from stix2.pattern_visitor import create_pattern_object import stix2.utils +import stix2.patterns +from .pattern_ast_overrides import * def test_create_comparison_expression(): @@ -747,3 +749,58 @@ def test_parsing_multiple_slashes_quotes(): def test_parse_error(): with pytest.raises(ParseException): create_pattern_object("[ file: name = 'weirdname]", version="2.1") + + +def test_ast_class_override_comp_equals(): + patt_ast = create_pattern_object( + "[a:b=1]", "Testing", "stix2.test.v21.pattern_ast_overrides", + version="2.1" + ) + + assert isinstance(patt_ast, stix2.patterns.ObservationExpression) + assert isinstance(patt_ast.operand, EqualityComparisonExpressionForTesting) + assert str(patt_ast) == "[a:b = 1]" + + +def test_ast_class_override_string_constant(): + patt_ast = create_pattern_object( + "[a:'b'[1].'c' < 'foo']", "Testing", + "stix2.test.v21.pattern_ast_overrides", + version="2.1" + ) + + assert isinstance(patt_ast, stix2.patterns.ObservationExpression) + assert isinstance( + patt_ast.operand, stix2.patterns.LessThanComparisonExpression + ) + assert isinstance( + patt_ast.operand.lhs.property_path[0].property_name, + str + ) + assert isinstance( + patt_ast.operand.lhs.property_path[1].property_name, + str + ) + assert isinstance(patt_ast.operand.rhs, StringConstantForTesting) + + assert str(patt_ast) == "[a:'b'[1].c < 'foo']" + + +def test_ast_class_override_startstop_qualifier(): + patt_ast = create_pattern_object( + "[a:b=1] START t'1993-01-20T01:33:52.592Z' STOP t'2001-08-19T23:50:23.129Z'", + "Testing", "stix2.test.v21.pattern_ast_overrides", version="2.1" + ) + + assert isinstance(patt_ast, stix2.patterns.QualifiedObservationExpression) + assert isinstance( + patt_ast.observation_expression, stix2.patterns.ObservationExpression + ) + assert isinstance( + patt_ast.observation_expression.operand, + EqualityComparisonExpressionForTesting + ) + assert isinstance( + patt_ast.qualifier, StartStopQualifierForTesting + ) + assert str(patt_ast) == "[a:b = 1] START t'1993-01-20T01:33:52.592Z' STOP t'2001-08-19T23:50:23.129Z'"