more pattern tests

master
Rich Piazza 2020-03-27 11:22:00 -04:00
parent e8035863b8
commit 202111acdf
2 changed files with 164 additions and 34 deletions

View File

@ -2,11 +2,16 @@ import importlib
import inspect import inspect
from stix2patterns.exceptions import ParseException from stix2patterns.exceptions import ParseException
from stix2patterns.grammars.STIXPatternParser import ( from stix2patterns.grammars.STIXPatternParser import TerminalNode
STIXPatternParser, TerminalNode,
) from stix2patterns.v20.grammars.STIXPatternParser import STIXPatternParser as STIXPatternParser20
from stix2patterns.v21.grammars.STIXPatternParser import STIXPatternParser as STIXPatternParser21
from stix2patterns.v20.grammars.STIXPatternVisitor import STIXPatternVisitor as STIXPatternVisitor20
from stix2patterns.v21.grammars.STIXPatternVisitor import STIXPatternVisitor as STIXPatternVisitor21
from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor
from stix2patterns.v20.pattern import Pattern from stix2patterns.v20.pattern import Pattern as Pattern20
from stix2patterns.v21.pattern import Pattern as Pattern21
from .patterns import * from .patterns import *
from .patterns import _BooleanExpression from .patterns import _BooleanExpression
@ -32,23 +37,12 @@ def remove_terminal_nodes(parse_tree_nodes):
return values return values
# This class defines a complete generic visitor for a parse tree produced by STIXPatternParser.
class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
class STIXPatternVisitorForSTIX2():
classes = {} classes = {}
def __init__(self, module_suffix, module_name):
if module_suffix and module_name:
self.module_suffix = module_suffix
if not STIXPatternVisitorForSTIX2.classes:
module = importlib.import_module(module_name)
for k, c in inspect.getmembers(module, inspect.isclass):
STIXPatternVisitorForSTIX2.classes[k] = c
else:
self.module_suffix = None
super(STIXPatternVisitor, self).__init__()
def get_class(self, class_name): def get_class(self, class_name):
if class_name in STIXPatternVisitorForSTIX2.classes: if class_name in STIXPatternVisitorForSTIX2.classes:
return STIXPatternVisitorForSTIX2.classes[class_name] return STIXPatternVisitorForSTIX2.classes[class_name]
@ -147,7 +141,7 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
def visitPropTestEqual(self, ctx): def visitPropTestEqual(self, ctx):
children = self.visitChildren(ctx) children = self.visitChildren(ctx)
operator = children[1].symbol.type operator = children[1].symbol.type
negated = operator != STIXPatternParser.EQ negated = operator != self.parser_class.EQ
return self.instantiate( return self.instantiate(
"EqualityComparisonExpression", children[0], children[3 if len(children) > 3 else 2], "EqualityComparisonExpression", children[0], children[3 if len(children) > 3 else 2],
negated, negated,
@ -157,22 +151,22 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
def visitPropTestOrder(self, ctx): def visitPropTestOrder(self, ctx):
children = self.visitChildren(ctx) children = self.visitChildren(ctx)
operator = children[1].symbol.type operator = children[1].symbol.type
if operator == STIXPatternParser.GT: if operator == self.parser_class.GT:
return self.instantiate( return self.instantiate(
"GreaterThanComparisonExpression", children[0], "GreaterThanComparisonExpression", children[0],
children[3 if len(children) > 3 else 2], False, children[3 if len(children) > 3 else 2], False,
) )
elif operator == STIXPatternParser.LT: elif operator == self.parser_class.LT:
return self.instantiate( return self.instantiate(
"LessThanComparisonExpression", children[0], "LessThanComparisonExpression", children[0],
children[3 if len(children) > 3 else 2], False, children[3 if len(children) > 3 else 2], False,
) )
elif operator == STIXPatternParser.GE: elif operator == self.parser_class.GE:
return self.instantiate( return self.instantiate(
"GreaterThanEqualComparisonExpression", children[0], "GreaterThanEqualComparisonExpression", children[0],
children[3 if len(children) > 3 else 2], False, children[3 if len(children) > 3 else 2], False,
) )
elif operator == STIXPatternParser.LE: elif operator == self.parser_class.LE:
return self.instantiate( return self.instantiate(
"LessThanEqualComparisonExpression", children[0], "LessThanEqualComparisonExpression", children[0],
children[3 if len(children) > 3 else 2], False, children[3 if len(children) > 3 else 2], False,
@ -294,22 +288,22 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
return children[0] return children[0]
def visitTerminal(self, node): def visitTerminal(self, node):
if node.symbol.type == STIXPatternParser.IntPosLiteral or node.symbol.type == STIXPatternParser.IntNegLiteral: if node.symbol.type == self.parser_class.IntPosLiteral or node.symbol.type == self.parser_class.IntNegLiteral:
return IntegerConstant(node.getText()) return IntegerConstant(node.getText())
elif node.symbol.type == STIXPatternParser.FloatPosLiteral or node.symbol.type == STIXPatternParser.FloatNegLiteral: elif node.symbol.type == self.parser_class.FloatPosLiteral or node.symbol.type == self.parser_class.FloatNegLiteral:
return FloatConstant(node.getText()) return FloatConstant(node.getText())
elif node.symbol.type == STIXPatternParser.HexLiteral: elif node.symbol.type == self.parser_class.HexLiteral:
return HexConstant(node.getText(), from_parse_tree=True) return HexConstant(node.getText(), from_parse_tree=True)
elif node.symbol.type == STIXPatternParser.BinaryLiteral: elif node.symbol.type == self.parser_class.BinaryLiteral:
return BinaryConstant(node.getText(), from_parse_tree=True) return BinaryConstant(node.getText(), from_parse_tree=True)
elif node.symbol.type == STIXPatternParser.StringLiteral: elif node.symbol.type == self.parser_class.StringLiteral:
if node.getText()[0] == "'" and node.getText()[-1] == "'": if node.getText()[0] == "'" and node.getText()[-1] == "'":
return StringConstant(node.getText()[1:-1], from_parse_tree=True) return StringConstant(node.getText()[1:-1], from_parse_tree=True)
else: else:
raise ParseException("The pattern does not start and end with a single quote") raise ParseException("The pattern does not start and end with a single quote")
elif node.symbol.type == STIXPatternParser.BoolLiteral: elif node.symbol.type == self.parser_class.BoolLiteral:
return BooleanConstant(node.getText()) return BooleanConstant(node.getText())
elif node.symbol.type == STIXPatternParser.TimestampLiteral: elif node.symbol.type == self.parser_class.TimestampLiteral:
return TimestampConstant(node.getText()) return TimestampConstant(node.getText())
else: else:
return node return node
@ -321,12 +315,44 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
aggregate = [nextResult] aggregate = [nextResult]
return aggregate return aggregate
# This class defines a complete generic visitor for a parse tree produced by STIXPatternParser.
class STIXPatternVisitorForSTIX21(STIXPatternVisitorForSTIX2, STIXPatternVisitor21):
classes = {}
def create_pattern_object(pattern, module_suffix="", module_name=""): def __init__(self, module_suffix, module_name):
if module_suffix and module_name:
self.module_suffix = module_suffix
if not STIXPatternVisitorForSTIX2.classes:
module = importlib.import_module(module_name)
for k, c in inspect.getmembers(module, inspect.isclass):
STIXPatternVisitorForSTIX2.classes[k] = c
else:
self.module_suffix = None
self.parser_class = STIXPatternParser21
super(STIXPatternVisitor21, self).__init__()
class STIXPatternVisitorForSTIX20(STIXPatternVisitor20, STIXPatternVisitorForSTIX2):
classes = {}
def __init__(self, module_suffix, module_name):
if module_suffix and module_name:
self.module_suffix = module_suffix
if not STIXPatternVisitorForSTIX2.classes:
module = importlib.import_module(module_name)
for k, c in inspect.getmembers(module, inspect.isclass):
STIXPatternVisitorForSTIX2.classes[k] = c
else:
self.module_suffix = None
self.parser_class = STIXPatternParser20
super(STIXPatternVisitor20, self).__init__()
def create_pattern_object(pattern, module_suffix="", module_name="", version="2.1"):
""" """
Create a STIX pattern AST from a pattern string. Create a STIX pattern AST from a pattern string.
""" """
pattern_obj = Pattern(pattern) pattern_obj = Pattern21(pattern) if version == "2.1" else Pattern20(pattern)
builder = STIXPatternVisitorForSTIX2(module_suffix, module_name) builder = STIXPatternVisitorForSTIX21(module_suffix, module_name) if version == "2.1" else STIXPatternVisitorForSTIX20(module_suffix, module_name)
return pattern_obj.visit(builder) return pattern_obj.visit(builder)

View File

@ -175,20 +175,34 @@ def test_greater_than():
assert str(exp) == "[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.0]" assert str(exp) == "[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.0]"
def test_parsing_greater_than():
patt_obj = create_pattern_object("[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.478901]")
assert str(patt_obj) == "[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.478901]"
def test_less_than(): def test_less_than():
exp = stix2.LessThanComparisonExpression("file:size", 1024) exp = stix2.LessThanComparisonExpression("file:size", 1024)
assert str(exp) == "file:size < 1024" assert str(exp) == "file:size < 1024"
def test_parsing_less_than():
patt_obj = create_pattern_object("[file:size < 1024]")
assert str(patt_obj) == "[file:size < 1024]"
def test_greater_than_or_equal(): def test_greater_than_or_equal():
exp = stix2.GreaterThanEqualComparisonExpression( exp = stix2.GreaterThanEqualComparisonExpression(
"file:size", "file:size",
1024, 1024,
) )
assert str(exp) == "file:size >= 1024" assert str(exp) == "file:size >= 1024"
def test_parsing_greater_than_or_equal():
patt_obj = create_pattern_object("[file:size >= 1024]")
assert str(patt_obj) == "[file:size >= 1024]"
def test_less_than_or_equal(): def test_less_than_or_equal():
exp = stix2.LessThanEqualComparisonExpression( exp = stix2.LessThanEqualComparisonExpression(
"file:size", "file:size",
@ -197,6 +211,11 @@ def test_less_than_or_equal():
assert str(exp) == "file:size <= 1024" assert str(exp) == "file:size <= 1024"
def test_parsing_less_than_or_equal():
patt_obj = create_pattern_object("[file:size <= 1024]")
assert str(patt_obj) == "[file:size <= 1024]"
def test_not(): def test_not():
exp = stix2.LessThanComparisonExpression( exp = stix2.LessThanComparisonExpression(
"file:size", "file:size",
@ -257,6 +276,67 @@ def test_and_observable_expression():
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1009' AND user-account:account_login = 'Mary']" # noqa assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1009' AND user-account:account_login = 'Mary']" # noqa
def test_parsing_and_observable_expression():
exp = create_pattern_object("[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']")
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']"
def test_or_observable_expression():
exp1 = stix2.AndBooleanExpression([
stix2.EqualityComparisonExpression(
"user-account:account_type",
"unix",
),
stix2.EqualityComparisonExpression(
"user-account:user_id",
stix2.StringConstant("1007"),
),
stix2.EqualityComparisonExpression(
"user-account:account_login",
"Peter",
),
])
exp2 = stix2.AndBooleanExpression([
stix2.EqualityComparisonExpression(
"user-account:account_type",
"unix",
),
stix2.EqualityComparisonExpression(
"user-account:user_id",
stix2.StringConstant("1008"),
),
stix2.EqualityComparisonExpression(
"user-account:account_login",
"Paul",
),
])
exp3 = stix2.AndBooleanExpression([
stix2.EqualityComparisonExpression(
"user-account:account_type",
"unix",
),
stix2.EqualityComparisonExpression(
"user-account:user_id",
stix2.StringConstant("1009"),
),
stix2.EqualityComparisonExpression(
"user-account:account_login",
"Mary",
),
])
exp = stix2.OrObservationExpression([
stix2.ObservationExpression(exp1),
stix2.ObservationExpression(exp2),
stix2.ObservationExpression(exp3),
])
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1009' AND user-account:account_login = 'Mary']" # noqa
def test_parsing_or_observable_expression():
exp = create_pattern_object("[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']") # noqa
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']" # noqa
def test_invalid_and_observable_expression(): def test_invalid_and_observable_expression():
with pytest.raises(ValueError): with pytest.raises(ValueError):
stix2.AndBooleanExpression([ stix2.AndBooleanExpression([
@ -286,6 +366,11 @@ def test_hex():
assert str(exp) == "[file:mime_type = 'image/bmp' AND file:magic_number_hex = h'ffd8']" assert str(exp) == "[file:mime_type = 'image/bmp' AND file:magic_number_hex = h'ffd8']"
def test_parsing_hex():
patt_obj = create_pattern_object("[file:magic_number_hex = h'ffd8']")
assert str(patt_obj) == "[file:magic_number_hex = h'ffd8']"
def test_multiple_qualifiers(): def test_multiple_qualifiers():
exp_and = stix2.AndBooleanExpression([ exp_and = stix2.AndBooleanExpression([
stix2.EqualityComparisonExpression( stix2.EqualityComparisonExpression(
@ -334,6 +419,11 @@ def test_binary():
assert str(exp) == "artifact:payload_bin = b'dGhpcyBpcyBhIHRlc3Q='" assert str(exp) == "artifact:payload_bin = b'dGhpcyBpcyBhIHRlc3Q='"
def test_parsing_binary():
patt_obj = create_pattern_object("[artifact:payload_bin = b'dGhpcyBpcyBhIHRlc3Q=']")
assert str(patt_obj) == "[artifact:payload_bin = b'dGhpcyBpcyBhIHRlc3Q=']"
def test_list(): def test_list():
exp = stix2.InComparisonExpression( exp = stix2.InComparisonExpression(
"process:name", "process:name",
@ -499,7 +589,7 @@ def test_parsing_comparison_expression():
assert str(patt_obj) == "[file:hashes.'SHA-256' = 'aec070645fe53ee3b3763059376134f058cc337247c978add178b6ccdfb0019f']" assert str(patt_obj) == "[file:hashes.'SHA-256' = 'aec070645fe53ee3b3763059376134f058cc337247c978add178b6ccdfb0019f']"
def test_parsing_qualified_expression(): def test_parsing_repeat_and_within_qualified_expression():
patt_obj = create_pattern_object( patt_obj = create_pattern_object(
"[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] REPEATS 5 TIMES WITHIN 1800 SECONDS", "[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] REPEATS 5 TIMES WITHIN 1800 SECONDS",
) )
@ -508,11 +598,25 @@ def test_parsing_qualified_expression():
) == "[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] REPEATS 5 TIMES WITHIN 1800 SECONDS" ) == "[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] REPEATS 5 TIMES WITHIN 1800 SECONDS"
def test_parsing_start_stop_qualified_expression():
patt_obj = create_pattern_object(
"[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] START t'2016-06-01T00:00:00Z' STOP t'2017-03-12T08:30:00Z'",
)
assert str(
patt_obj,
) == "[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] START t'2016-06-01T00:00:00Z' STOP t'2017-03-12T08:30:00Z'"
def test_list_constant(): def test_list_constant():
patt_obj = create_pattern_object("[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]") patt_obj = create_pattern_object("[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]")
assert str(patt_obj) == "[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]" assert str(patt_obj) == "[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]"
def test_parsing_boolean():
patt_obj = create_pattern_object("[network-traffic:is_active = true]")
assert str(patt_obj) == "[network-traffic:is_active = true]"
def test_parsing_multiple_slashes_quotes(): def test_parsing_multiple_slashes_quotes():
patt_obj = create_pattern_object("[ file:name = 'weird_name\\'' ]") patt_obj = create_pattern_object("[ file:name = 'weird_name\\'' ]")
assert str(patt_obj) == "[file:name = 'weird_name\\'']" assert str(patt_obj) == "[file:name = 'weird_name\\'']"