Fix stix2.pattern_visitor.create_pattern_object() so its

documentation at least isn't wrong, and it behaves better.
master
Michael Chisholm 2020-02-17 19:26:21 -05:00
parent 8aca39a0b0
commit cfb7c4c73b
1 changed files with 17 additions and 16 deletions

View File

@ -1,16 +1,15 @@
import importlib import importlib
import inspect import inspect
from antlr4 import CommonTokenStream, InputStream from antlr4 import BailErrorStrategy, CommonTokenStream, InputStream
from antlr4.tree.Trees import Trees import antlr4.error.Errors
import six import six
from stix2patterns.exceptions import ParseException from stix2patterns.exceptions import ParseException, ParserErrorListener
from stix2patterns.grammars.STIXPatternLexer import STIXPatternLexer from stix2patterns.grammars.STIXPatternLexer import STIXPatternLexer
from stix2patterns.grammars.STIXPatternParser import ( from stix2patterns.grammars.STIXPatternParser import (
STIXPatternParser, TerminalNode, STIXPatternParser, TerminalNode,
) )
from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor
from stix2patterns.validator import STIXPatternErrorListener
from .patterns import * from .patterns import *
from .patterns import _BooleanExpression from .patterns import _BooleanExpression
@ -328,20 +327,12 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
def create_pattern_object(pattern, module_suffix="", module_name=""): def create_pattern_object(pattern, module_suffix="", module_name=""):
""" """
Validates a pattern against the STIX Pattern grammar. Error messages are Create a STIX pattern AST from a pattern string.
returned in a list. The test passed if the returned list is empty.
""" """
start = '' pattern = InputStream(pattern)
if isinstance(pattern, six.string_types):
start = pattern[:2]
pattern = InputStream(pattern)
if not start: parseErrListener = ParserErrorListener()
start = pattern.readline()[:2]
pattern.seek(0)
parseErrListener = STIXPatternErrorListener()
lexer = STIXPatternLexer(pattern) lexer = STIXPatternLexer(pattern)
# it always adds a console listener by default... remove it. # it always adds a console listener by default... remove it.
@ -350,6 +341,7 @@ def create_pattern_object(pattern, module_suffix="", module_name=""):
stream = CommonTokenStream(lexer) stream = CommonTokenStream(lexer)
parser = STIXPatternParser(stream) parser = STIXPatternParser(stream)
parser._errHandler = BailErrorStrategy()
parser.buildParseTrees = True parser.buildParseTrees = True
# it always adds a console listener by default... remove it. # it always adds a console listener by default... remove it.
@ -363,6 +355,15 @@ def create_pattern_object(pattern, module_suffix="", module_name=""):
if lit_name == u"<INVALID>": if lit_name == u"<INVALID>":
parser.literalNames[i] = parser.symbolicNames[i] parser.literalNames[i] = parser.symbolicNames[i]
tree = parser.pattern() try:
tree = parser.pattern()
except antlr4.error.Errors.ParseCancellationException as e:
real_exc = e.args[0]
parser._errHandler.reportError(parser, real_exc)
six.raise_from(
ParseException(parseErrListener.error_message),
real_exc,
)
builder = STIXPatternVisitorForSTIX2(module_suffix, module_name) builder = STIXPatternVisitorForSTIX2(module_suffix, module_name)
return builder.visit(tree) return builder.visit(tree)