make work for basic pattern objects

add_visitor
Richard Piazza 2018-11-19 21:24:33 -05:00
parent d5d65535a3
commit 3384507405
5 changed files with 70 additions and 26 deletions

28
.gitignore vendored
View File

@ -68,3 +68,31 @@ cache.sqlite
# PyCharm
.idea/
### macOS template
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

View File

@ -1,4 +1,3 @@
import stix2
import six
from stix2patterns.grammars.STIXPatternParser import *
from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor
@ -9,6 +8,7 @@ from stix2patterns.validator import STIXPatternErrorListener
import importlib
import inspect
from .patterns import *
def collapse_lists(lists):
result = []
@ -19,6 +19,12 @@ def collapse_lists(lists):
result.append(c)
return result
def quote_if_needed(x):
if x.find("-") != -1:
return "'" + x + "'"
else:
return x
# This class defines a complete generic visitor for a parse tree produced by STIXPatternParser.
@ -26,11 +32,14 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
classes = {}
def __init__(self, module_suffix, module_name):
self.module_suffix = module_suffix
if STIXPatternVisitorForSTIX2.classes == {}:
module = importlib.import_module(module_name)
for k, c in inspect.getmembers(module, inspect.isclass):
STIXPatternVisitorForSTIX2.classes[k] = c
if module_suffix and module_name:
self.module_suffix = module_suffix
if STIXPatternVisitorForSTIX2.classes == {}:
module = importlib.import_module(module_name)
for k, c in inspect.getmembers(module, inspect.isclass):
STIXPatternVisitorForSTIX2.classes[k] = c
else:
self.module_suffix = None
super(STIXPatternVisitor, self).__init__()
def get_class(self, class_name):
@ -55,7 +64,7 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
if len(children) == 1:
return children[0]
else:
return stix2.FollowedByObservationExpression([children[0], children[2]])
return FollowedByObservationExpression([children[0], children[2]])
# Visit a parse tree produced by STIXPatternParser#observationExpressionOr.
def visitObservationExpressionOr(self, ctx):
@ -158,12 +167,12 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
# Visit a parse tree produced by STIXPatternParser#propTestIsSubset.
def visitPropTestIsSubset(self, ctx):
children = self.visitChildren(ctx)
return stix2.IsSubsetComparisonExpression(children[0], children[2])
return IsSubsetComparisonExpression(children[0], children[2])
# Visit a parse tree produced by STIXPatternParser#propTestIsSuperset.
def visitPropTestIsSuperset(self, ctx):
children = self.visitChildren(ctx)
return stix2.IsSupersetComparisonExpression(children[0], children[2])
return IsSupersetComparisonExpression(children[0], children[2])
# Visit a parse tree produced by STIXPatternParser#propTestParen.
def visitPropTestParen(self, ctx):
@ -173,17 +182,17 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
# Visit a parse tree produced by STIXPatternParser#startStopQualifier.
def visitStartStopQualifier(self, ctx):
children = self.visitChildren(ctx)
return stix2.StartStopQualifier(children[1], children[3])
return StartStopQualifier(children[1], children[3])
# Visit a parse tree produced by STIXPatternParser#withinQualifier.
def visitWithinQualifier(self, ctx):
children = self.visitChildren(ctx)
return stix2.WithinQualifier(children[1])
return WithinQualifier(children[1])
# Visit a parse tree produced by STIXPatternParser#repeatedQualifier.
def visitRepeatedQualifier(self, ctx):
children = self.visitChildren(ctx)
return stix2.RepeatQualifier(children[1])
return RepeatQualifier(children[1])
# Visit a parse tree produced by STIXPatternParser#objectPath.
def visitObjectPath(self, ctx):
@ -194,11 +203,11 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
while i < len(flat_list):
current = flat_list[i]
if i == len(flat_list)-1:
property_path.append(current)
property_path.append(quote_if_needed(current))
break
next = flat_list[i+1]
if isinstance(next, TerminalNode):
property_path.append(stix2.ListObjectPathComponent(current.property_name, next.getText()))
property_path.append(ListObjectPathComponent(current.property_name, next.getText()))
i += 2
else:
property_path.append(current)
@ -217,7 +226,7 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
# if step.endswith("_ref"):
# return stix2.ReferenceObjectPathComponent(step)
# else:
return stix2.BasicObjectPathComponent(step)
return BasicObjectPathComponent(step)
# Visit a parse tree produced by STIXPatternParser#indexPathStep.
def visitIndexPathStep(self, ctx):
@ -231,11 +240,11 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
# Visit a parse tree produced by STIXPatternParser#keyPathStep.
def visitKeyPathStep(self, ctx):
children = self.visitChildren(ctx)
if isinstance(children[1], stix2.StringConstant):
if isinstance(children[1], StringConstant):
# special case for hashes
return children[1].value
else:
return stix2.BasicObjectPathComponent(children[1].getText(), is_key=True)
return BasicObjectPathComponent(children[1].getText(), is_key=True)
# Visit a parse tree produced by STIXPatternParser#setLiteral.
def visitSetLiteral(self, ctx):
@ -253,19 +262,19 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
def visitTerminal(self, node):
if node.symbol.type == STIXPatternParser.IntPosLiteral or node.symbol.type == STIXPatternParser.IntNegLiteral:
return stix2.IntegerConstant(node.getText())
return IntegerConstant(node.getText())
elif node.symbol.type == STIXPatternParser.FloatPosLiteral or node.symbol.type == STIXPatternParser.FloatNegLiteral:
return stix2.FloatConstant(node.getText())
return FloatConstant(node.getText())
elif node.symbol.type == STIXPatternParser.HexLiteral:
return stix2.HexConstant(node.getText())
return HexConstant(node.getText())
elif node.symbol.type == STIXPatternParser.BinaryLiteral:
return stix2.BinaryConstant(node.getText())
return BinaryConstant(node.getText())
elif node.symbol.type == STIXPatternParser.StringLiteral:
return stix2.StringConstant(node.getText().strip('\''))
return StringConstant(node.getText().strip('\''))
elif node.symbol.type == STIXPatternParser.BoolLiteral:
return stix2.BooleanConstant(node.getText())
return BooleanConstant(node.getText())
elif node.symbol.type == STIXPatternParser.TimestampLiteral:
return stix2.TimestampConstant(node.getText())
return TimestampConstant(node.getText())
# TODO: timestamp
else:
return node

View File

@ -76,6 +76,8 @@ from .patterns import (
from .utils import new_version, revoke
from .version import __version__
from .STIXPatternVisitor import create_pattern_object
_collect_stix2_mappings()
DEFAULT_VERSION = '2.1' # Default version will always be the latest STIX 2.X version

View File

@ -12,7 +12,6 @@ from .utils import parse_into_datetime
def escape_quotes_and_backslashes(s):
return s.replace(u'\\', u'\\\\').replace(u"'", u"\\'")
class _Constant(object):
pass

View File

@ -4,7 +4,6 @@ import pytest
import stix2
def test_create_comparison_expression():
exp = stix2.EqualityComparisonExpression("file:hashes.'SHA-256'",
@ -377,3 +376,10 @@ def test_make_constant_already_a_constant():
str_const = stix2.StringConstant('Foo')
result = stix2.patterns.make_constant(str_const)
assert result is str_const
def test_parsing_expression():
patt_obj = stix2.create_pattern_object("[file:hashes.'SHA-256' = 'aec070645fe53ee3b3763059376134f058cc337247c978add178b6ccdfb0019f']")
assert str(patt_obj) == "[file:hashes.'SHA-256' = 'aec070645fe53ee3b3763059376134f058cc337247c978add178b6ccdfb0019f']"