Complete the addition of AST node class override support which

was begun by frank7y.  Added some unit tests.

Also address a design flaw which prevented the AST builder
function create_pattern_object() from honoring module_suffix
and module_name parameters after the first call.
main
Michael Chisholm 2022-04-28 21:10:12 -04:00
parent 1e4b6caf3f
commit e534e41865
5 changed files with 190 additions and 43 deletions

View File

@ -53,22 +53,32 @@ def same_boolean_operator(current_op, op_token):
class STIXPatternVisitorForSTIX2():
classes = {}
def __init__(self, parser_class, module_suffix=None, module_name=None):
self.parser_class = parser_class
if module_suffix and module_name:
self.module_suffix = module_suffix
self.module = importlib.import_module(module_name)
else:
self.module_suffix = self.module = None
def get_class(self, class_name):
if class_name in STIXPatternVisitorForSTIX2.classes:
return STIXPatternVisitorForSTIX2.classes[class_name]
else:
return None
klass = None
if self.module:
class_name_suffix = class_name + "For" + self.module_suffix
member = getattr(self.module, class_name_suffix, None)
if member and inspect.isclass(member):
klass = member
def instantiate(self, klass_name, *args):
klass_to_instantiate = None
if self.module_suffix:
klass_to_instantiate = self.get_class(klass_name + "For" + self.module_suffix)
if not klass_to_instantiate:
# use the classes in python_stix2
klass_to_instantiate = globals()[klass_name]
return klass_to_instantiate(*args)
if not klass:
klass = globals()[class_name]
return klass
def instantiate(self, klass_name, *args, **kwargs):
klass_to_instantiate = self.get_class(klass_name)
return klass_to_instantiate(*args, **kwargs)
# Visit a parse tree produced by STIXPatternParser#pattern.
def visitPattern(self, ctx):
@ -330,12 +340,12 @@ class STIXPatternVisitorForSTIX2():
elif node.symbol.type == self.parser_class.FloatPosLiteral or node.symbol.type == self.parser_class.FloatNegLiteral:
return self.instantiate("FloatConstant", node.getText())
elif node.symbol.type == self.parser_class.HexLiteral:
return HexConstant(node.getText(), from_parse_tree=True)
return self.instantiate("HexConstant", node.getText(), from_parse_tree=True)
elif node.symbol.type == self.parser_class.BinaryLiteral:
return BinaryConstant(node.getText(), from_parse_tree=True)
return self.instantiate("BinaryConstant", node.getText(), from_parse_tree=True)
elif node.symbol.type == self.parser_class.StringLiteral:
if node.getText()[0] == "'" and node.getText()[-1] == "'":
return StringConstant(node.getText()[1:-1], from_parse_tree=True)
return self.instantiate("StringConstant", node.getText()[1:-1], from_parse_tree=True)
else:
raise ParseException("The pattern does not start and end with a single quote")
elif node.symbol.type == self.parser_class.BoolLiteral:
@ -356,37 +366,14 @@ class STIXPatternVisitorForSTIX2():
aggregate = [nextResult]
return aggregate
# This class defines a complete generic visitor for a parse tree produced by STIXPatternParser.
class STIXPatternVisitorForSTIX21(STIXPatternVisitorForSTIX2, STIXPatternVisitor21):
classes = {}
def __init__(self, module_suffix, module_name):
if module_suffix and module_name:
self.module_suffix = module_suffix
if not STIXPatternVisitorForSTIX2.classes:
module = importlib.import_module(module_name)
for k, c in inspect.getmembers(module, inspect.isclass):
STIXPatternVisitorForSTIX2.classes[k] = c
else:
self.module_suffix = None
self.parser_class = STIXPatternParser21
super(STIXPatternVisitor21, self).__init__()
pass
class STIXPatternVisitorForSTIX20(STIXPatternVisitorForSTIX2, STIXPatternVisitor20):
classes = {}
def __init__(self, module_suffix, module_name):
if module_suffix and module_name:
self.module_suffix = module_suffix
if not STIXPatternVisitorForSTIX2.classes:
module = importlib.import_module(module_name)
for k, c in inspect.getmembers(module, inspect.isclass):
STIXPatternVisitorForSTIX2.classes[k] = c
else:
self.module_suffix = None
self.parser_class = STIXPatternParser20
super(STIXPatternVisitor20, self).__init__()
pass
def create_pattern_object(pattern, module_suffix="", module_name="", version=DEFAULT_VERSION):
@ -397,10 +384,14 @@ def create_pattern_object(pattern, module_suffix="", module_name="", version=DEF
if version == "2.1":
pattern_class = Pattern21
visitor_class = STIXPatternVisitorForSTIX21
parser_class = STIXPatternParser21
else:
pattern_class = Pattern20
visitor_class = STIXPatternVisitorForSTIX20
parser_class = STIXPatternParser20
pattern_obj = pattern_class(pattern)
builder = visitor_class(module_suffix, module_name)
builder = visitor_class(
parser_class, module_suffix, module_name
)
return pattern_obj.visit(builder)

View File

@ -0,0 +1,21 @@
"""
AST node class overrides for testing the pattern AST builder.
"""
from stix2.patterns import (
EqualityComparisonExpression,
StartStopQualifier,
StringConstant
)
class EqualityComparisonExpressionForTesting(EqualityComparisonExpression):
pass
class StringConstantForTesting(StringConstant):
pass
class StartStopQualifierForTesting(StartStopQualifier):
pass

View File

@ -6,6 +6,8 @@ import pytz
import stix2
from stix2.pattern_visitor import create_pattern_object
import stix2.utils
import stix2.patterns
from .pattern_ast_overrides import *
def test_create_comparison_expression():
@ -587,3 +589,58 @@ def test_parsing_illegal_start_stop_qualified_expression():
def test_list_constant():
patt_obj = create_pattern_object("[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]", version="2.0")
assert str(patt_obj) == "[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]"
def test_ast_class_override_comp_equals():
patt_ast = create_pattern_object(
"[a:b=1]", "Testing", "stix2.test.v20.pattern_ast_overrides",
version="2.0"
)
assert isinstance(patt_ast, stix2.patterns.ObservationExpression)
assert isinstance(patt_ast.operand, EqualityComparisonExpressionForTesting)
assert str(patt_ast) == "[a:b = 1]"
def test_ast_class_override_string_constant():
patt_ast = create_pattern_object(
"[a:'b'[1].'c' < 'foo']", "Testing",
"stix2.test.v20.pattern_ast_overrides",
version="2.0"
)
assert isinstance(patt_ast, stix2.patterns.ObservationExpression)
assert isinstance(
patt_ast.operand, stix2.patterns.LessThanComparisonExpression
)
assert isinstance(
patt_ast.operand.lhs.property_path[0].property_name,
str
)
assert isinstance(
patt_ast.operand.lhs.property_path[1].property_name,
str
)
assert isinstance(patt_ast.operand.rhs, StringConstantForTesting)
assert str(patt_ast) == "[a:'b'[1].c < 'foo']"
def test_ast_class_override_startstop_qualifier():
patt_ast = create_pattern_object(
"[a:b=1] START '1993-01-20T01:33:52.592Z' STOP '2001-08-19T23:50:23.129Z'",
"Testing", "stix2.test.v20.pattern_ast_overrides", version="2.0"
)
assert isinstance(patt_ast, stix2.patterns.QualifiedObservationExpression)
assert isinstance(
patt_ast.observation_expression, stix2.patterns.ObservationExpression
)
assert isinstance(
patt_ast.observation_expression.operand,
EqualityComparisonExpressionForTesting
)
assert isinstance(
patt_ast.qualifier, StartStopQualifierForTesting
)
assert str(patt_ast) == "[a:b = 1] START '1993-01-20T01:33:52.592Z' STOP '2001-08-19T23:50:23.129Z'"

View File

@ -0,0 +1,21 @@
"""
AST node class overrides for testing the pattern AST builder.
"""
from stix2.patterns import (
EqualityComparisonExpression,
StartStopQualifier,
StringConstant
)
class EqualityComparisonExpressionForTesting(EqualityComparisonExpression):
pass
class StringConstantForTesting(StringConstant):
pass
class StartStopQualifierForTesting(StartStopQualifier):
pass

View File

@ -7,6 +7,8 @@ from stix2patterns.exceptions import ParseException
import stix2
from stix2.pattern_visitor import create_pattern_object
import stix2.utils
import stix2.patterns
from .pattern_ast_overrides import *
def test_create_comparison_expression():
@ -747,3 +749,58 @@ def test_parsing_multiple_slashes_quotes():
def test_parse_error():
with pytest.raises(ParseException):
create_pattern_object("[ file: name = 'weirdname]", version="2.1")
def test_ast_class_override_comp_equals():
patt_ast = create_pattern_object(
"[a:b=1]", "Testing", "stix2.test.v21.pattern_ast_overrides",
version="2.1"
)
assert isinstance(patt_ast, stix2.patterns.ObservationExpression)
assert isinstance(patt_ast.operand, EqualityComparisonExpressionForTesting)
assert str(patt_ast) == "[a:b = 1]"
def test_ast_class_override_string_constant():
patt_ast = create_pattern_object(
"[a:'b'[1].'c' < 'foo']", "Testing",
"stix2.test.v21.pattern_ast_overrides",
version="2.1"
)
assert isinstance(patt_ast, stix2.patterns.ObservationExpression)
assert isinstance(
patt_ast.operand, stix2.patterns.LessThanComparisonExpression
)
assert isinstance(
patt_ast.operand.lhs.property_path[0].property_name,
str
)
assert isinstance(
patt_ast.operand.lhs.property_path[1].property_name,
str
)
assert isinstance(patt_ast.operand.rhs, StringConstantForTesting)
assert str(patt_ast) == "[a:'b'[1].c < 'foo']"
def test_ast_class_override_startstop_qualifier():
patt_ast = create_pattern_object(
"[a:b=1] START t'1993-01-20T01:33:52.592Z' STOP t'2001-08-19T23:50:23.129Z'",
"Testing", "stix2.test.v21.pattern_ast_overrides", version="2.1"
)
assert isinstance(patt_ast, stix2.patterns.QualifiedObservationExpression)
assert isinstance(
patt_ast.observation_expression, stix2.patterns.ObservationExpression
)
assert isinstance(
patt_ast.observation_expression.operand,
EqualityComparisonExpressionForTesting
)
assert isinstance(
patt_ast.qualifier, StartStopQualifierForTesting
)
assert str(patt_ast) == "[a:b = 1] START t'1993-01-20T01:33:52.592Z' STOP t'2001-08-19T23:50:23.129Z'"