more pattern tests

master
Rich Piazza 2020-03-27 11:22:00 -04:00
parent e8035863b8
commit 202111acdf
2 changed files with 164 additions and 34 deletions

View File

@ -2,11 +2,16 @@ import importlib
import inspect
from stix2patterns.exceptions import ParseException
from stix2patterns.grammars.STIXPatternParser import (
STIXPatternParser, TerminalNode,
)
from stix2patterns.grammars.STIXPatternParser import TerminalNode
from stix2patterns.v20.grammars.STIXPatternParser import STIXPatternParser as STIXPatternParser20
from stix2patterns.v21.grammars.STIXPatternParser import STIXPatternParser as STIXPatternParser21
from stix2patterns.v20.grammars.STIXPatternVisitor import STIXPatternVisitor as STIXPatternVisitor20
from stix2patterns.v21.grammars.STIXPatternVisitor import STIXPatternVisitor as STIXPatternVisitor21
from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor
from stix2patterns.v20.pattern import Pattern
from stix2patterns.v20.pattern import Pattern as Pattern20
from stix2patterns.v21.pattern import Pattern as Pattern21
from .patterns import *
from .patterns import _BooleanExpression
@ -32,23 +37,12 @@ def remove_terminal_nodes(parse_tree_nodes):
return values
# This class defines a complete generic visitor for a parse tree produced by STIXPatternParser.
class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
class STIXPatternVisitorForSTIX2():
classes = {}
def __init__(self, module_suffix, module_name):
if module_suffix and module_name:
self.module_suffix = module_suffix
if not STIXPatternVisitorForSTIX2.classes:
module = importlib.import_module(module_name)
for k, c in inspect.getmembers(module, inspect.isclass):
STIXPatternVisitorForSTIX2.classes[k] = c
else:
self.module_suffix = None
super(STIXPatternVisitor, self).__init__()
def get_class(self, class_name):
if class_name in STIXPatternVisitorForSTIX2.classes:
return STIXPatternVisitorForSTIX2.classes[class_name]
@ -147,7 +141,7 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
def visitPropTestEqual(self, ctx):
children = self.visitChildren(ctx)
operator = children[1].symbol.type
negated = operator != STIXPatternParser.EQ
negated = operator != self.parser_class.EQ
return self.instantiate(
"EqualityComparisonExpression", children[0], children[3 if len(children) > 3 else 2],
negated,
@ -157,22 +151,22 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
def visitPropTestOrder(self, ctx):
children = self.visitChildren(ctx)
operator = children[1].symbol.type
if operator == STIXPatternParser.GT:
if operator == self.parser_class.GT:
return self.instantiate(
"GreaterThanComparisonExpression", children[0],
children[3 if len(children) > 3 else 2], False,
)
elif operator == STIXPatternParser.LT:
elif operator == self.parser_class.LT:
return self.instantiate(
"LessThanComparisonExpression", children[0],
children[3 if len(children) > 3 else 2], False,
)
elif operator == STIXPatternParser.GE:
elif operator == self.parser_class.GE:
return self.instantiate(
"GreaterThanEqualComparisonExpression", children[0],
children[3 if len(children) > 3 else 2], False,
)
elif operator == STIXPatternParser.LE:
elif operator == self.parser_class.LE:
return self.instantiate(
"LessThanEqualComparisonExpression", children[0],
children[3 if len(children) > 3 else 2], False,
@ -294,22 +288,22 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
return children[0]
def visitTerminal(self, node):
if node.symbol.type == STIXPatternParser.IntPosLiteral or node.symbol.type == STIXPatternParser.IntNegLiteral:
if node.symbol.type == self.parser_class.IntPosLiteral or node.symbol.type == self.parser_class.IntNegLiteral:
return IntegerConstant(node.getText())
elif node.symbol.type == STIXPatternParser.FloatPosLiteral or node.symbol.type == STIXPatternParser.FloatNegLiteral:
elif node.symbol.type == self.parser_class.FloatPosLiteral or node.symbol.type == self.parser_class.FloatNegLiteral:
return FloatConstant(node.getText())
elif node.symbol.type == STIXPatternParser.HexLiteral:
elif node.symbol.type == self.parser_class.HexLiteral:
return HexConstant(node.getText(), from_parse_tree=True)
elif node.symbol.type == STIXPatternParser.BinaryLiteral:
elif node.symbol.type == self.parser_class.BinaryLiteral:
return BinaryConstant(node.getText(), from_parse_tree=True)
elif node.symbol.type == STIXPatternParser.StringLiteral:
elif node.symbol.type == self.parser_class.StringLiteral:
if node.getText()[0] == "'" and node.getText()[-1] == "'":
return StringConstant(node.getText()[1:-1], from_parse_tree=True)
else:
raise ParseException("The pattern does not start and end with a single quote")
elif node.symbol.type == STIXPatternParser.BoolLiteral:
elif node.symbol.type == self.parser_class.BoolLiteral:
return BooleanConstant(node.getText())
elif node.symbol.type == STIXPatternParser.TimestampLiteral:
elif node.symbol.type == self.parser_class.TimestampLiteral:
return TimestampConstant(node.getText())
else:
return node
@ -321,12 +315,44 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
aggregate = [nextResult]
return aggregate
# This class defines a complete generic visitor for a parse tree produced by STIXPatternParser.
class STIXPatternVisitorForSTIX21(STIXPatternVisitorForSTIX2, STIXPatternVisitor21):
classes = {}
def create_pattern_object(pattern, module_suffix="", module_name=""):
def __init__(self, module_suffix, module_name):
if module_suffix and module_name:
self.module_suffix = module_suffix
if not STIXPatternVisitorForSTIX2.classes:
module = importlib.import_module(module_name)
for k, c in inspect.getmembers(module, inspect.isclass):
STIXPatternVisitorForSTIX2.classes[k] = c
else:
self.module_suffix = None
self.parser_class = STIXPatternParser21
super(STIXPatternVisitor21, self).__init__()
class STIXPatternVisitorForSTIX20(STIXPatternVisitor20, STIXPatternVisitorForSTIX2):
classes = {}
def __init__(self, module_suffix, module_name):
if module_suffix and module_name:
self.module_suffix = module_suffix
if not STIXPatternVisitorForSTIX2.classes:
module = importlib.import_module(module_name)
for k, c in inspect.getmembers(module, inspect.isclass):
STIXPatternVisitorForSTIX2.classes[k] = c
else:
self.module_suffix = None
self.parser_class = STIXPatternParser20
super(STIXPatternVisitor20, self).__init__()
def create_pattern_object(pattern, module_suffix="", module_name="", version="2.1"):
"""
Create a STIX pattern AST from a pattern string.
"""
pattern_obj = Pattern(pattern)
builder = STIXPatternVisitorForSTIX2(module_suffix, module_name)
pattern_obj = Pattern21(pattern) if version == "2.1" else Pattern20(pattern)
builder = STIXPatternVisitorForSTIX21(module_suffix, module_name) if version == "2.1" else STIXPatternVisitorForSTIX20(module_suffix, module_name)
return pattern_obj.visit(builder)

View File

@ -175,20 +175,34 @@ def test_greater_than():
assert str(exp) == "[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.0]"
def test_parsing_greater_than():
patt_obj = create_pattern_object("[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.478901]")
assert str(patt_obj) == "[file:extensions.'windows-pebinary-ext'.sections[*].entropy > 7.478901]"
def test_less_than():
exp = stix2.LessThanComparisonExpression("file:size", 1024)
assert str(exp) == "file:size < 1024"
def test_parsing_less_than():
patt_obj = create_pattern_object("[file:size < 1024]")
assert str(patt_obj) == "[file:size < 1024]"
def test_greater_than_or_equal():
exp = stix2.GreaterThanEqualComparisonExpression(
"file:size",
1024,
)
assert str(exp) == "file:size >= 1024"
def test_parsing_greater_than_or_equal():
patt_obj = create_pattern_object("[file:size >= 1024]")
assert str(patt_obj) == "[file:size >= 1024]"
def test_less_than_or_equal():
exp = stix2.LessThanEqualComparisonExpression(
"file:size",
@ -197,6 +211,11 @@ def test_less_than_or_equal():
assert str(exp) == "file:size <= 1024"
def test_parsing_less_than_or_equal():
patt_obj = create_pattern_object("[file:size <= 1024]")
assert str(patt_obj) == "[file:size <= 1024]"
def test_not():
exp = stix2.LessThanComparisonExpression(
"file:size",
@ -257,6 +276,67 @@ def test_and_observable_expression():
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1009' AND user-account:account_login = 'Mary']" # noqa
def test_parsing_and_observable_expression():
exp = create_pattern_object("[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']")
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] AND [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']"
def test_or_observable_expression():
exp1 = stix2.AndBooleanExpression([
stix2.EqualityComparisonExpression(
"user-account:account_type",
"unix",
),
stix2.EqualityComparisonExpression(
"user-account:user_id",
stix2.StringConstant("1007"),
),
stix2.EqualityComparisonExpression(
"user-account:account_login",
"Peter",
),
])
exp2 = stix2.AndBooleanExpression([
stix2.EqualityComparisonExpression(
"user-account:account_type",
"unix",
),
stix2.EqualityComparisonExpression(
"user-account:user_id",
stix2.StringConstant("1008"),
),
stix2.EqualityComparisonExpression(
"user-account:account_login",
"Paul",
),
])
exp3 = stix2.AndBooleanExpression([
stix2.EqualityComparisonExpression(
"user-account:account_type",
"unix",
),
stix2.EqualityComparisonExpression(
"user-account:user_id",
stix2.StringConstant("1009"),
),
stix2.EqualityComparisonExpression(
"user-account:account_login",
"Mary",
),
])
exp = stix2.OrObservationExpression([
stix2.ObservationExpression(exp1),
stix2.ObservationExpression(exp2),
stix2.ObservationExpression(exp3),
])
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1009' AND user-account:account_login = 'Mary']" # noqa
def test_parsing_or_observable_expression():
exp = create_pattern_object("[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']") # noqa
assert str(exp) == "[user-account:account_type = 'unix' AND user-account:user_id = '1007' AND user-account:account_login = 'Peter'] OR [user-account:account_type = 'unix' AND user-account:user_id = '1008' AND user-account:account_login = 'Paul']" # noqa
def test_invalid_and_observable_expression():
with pytest.raises(ValueError):
stix2.AndBooleanExpression([
@ -286,6 +366,11 @@ def test_hex():
assert str(exp) == "[file:mime_type = 'image/bmp' AND file:magic_number_hex = h'ffd8']"
def test_parsing_hex():
patt_obj = create_pattern_object("[file:magic_number_hex = h'ffd8']")
assert str(patt_obj) == "[file:magic_number_hex = h'ffd8']"
def test_multiple_qualifiers():
exp_and = stix2.AndBooleanExpression([
stix2.EqualityComparisonExpression(
@ -334,6 +419,11 @@ def test_binary():
assert str(exp) == "artifact:payload_bin = b'dGhpcyBpcyBhIHRlc3Q='"
def test_parsing_binary():
patt_obj = create_pattern_object("[artifact:payload_bin = b'dGhpcyBpcyBhIHRlc3Q=']")
assert str(patt_obj) == "[artifact:payload_bin = b'dGhpcyBpcyBhIHRlc3Q=']"
def test_list():
exp = stix2.InComparisonExpression(
"process:name",
@ -499,7 +589,7 @@ def test_parsing_comparison_expression():
assert str(patt_obj) == "[file:hashes.'SHA-256' = 'aec070645fe53ee3b3763059376134f058cc337247c978add178b6ccdfb0019f']"
def test_parsing_qualified_expression():
def test_parsing_repeat_and_within_qualified_expression():
patt_obj = create_pattern_object(
"[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] REPEATS 5 TIMES WITHIN 1800 SECONDS",
)
@ -508,11 +598,25 @@ def test_parsing_qualified_expression():
) == "[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] REPEATS 5 TIMES WITHIN 1800 SECONDS"
def test_parsing_start_stop_qualified_expression():
patt_obj = create_pattern_object(
"[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] START t'2016-06-01T00:00:00Z' STOP t'2017-03-12T08:30:00Z'",
)
assert str(
patt_obj,
) == "[network-traffic:dst_ref.type = 'domain-name' AND network-traffic:dst_ref.value = 'example.com'] START t'2016-06-01T00:00:00Z' STOP t'2017-03-12T08:30:00Z'"
def test_list_constant():
patt_obj = create_pattern_object("[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]")
assert str(patt_obj) == "[network-traffic:src_ref.value IN ('10.0.0.0', '10.0.0.1', '10.0.0.2')]"
def test_parsing_boolean():
patt_obj = create_pattern_object("[network-traffic:is_active = true]")
assert str(patt_obj) == "[network-traffic:is_active = true]"
def test_parsing_multiple_slashes_quotes():
patt_obj = create_pattern_object("[ file:name = 'weird_name\\'' ]")
assert str(patt_obj) == "[file:name = 'weird_name\\'']"