Fix stix2.pattern_visitor.create_pattern_object() so its

documentation at least isn't wrong, and it behaves better.
master
Michael Chisholm 2020-02-17 19:26:21 -05:00
parent 8aca39a0b0
commit cfb7c4c73b
1 changed files with 17 additions and 16 deletions

View File

@ -1,16 +1,15 @@
import importlib
import inspect
from antlr4 import CommonTokenStream, InputStream
from antlr4.tree.Trees import Trees
from antlr4 import BailErrorStrategy, CommonTokenStream, InputStream
import antlr4.error.Errors
import six
from stix2patterns.exceptions import ParseException
from stix2patterns.exceptions import ParseException, ParserErrorListener
from stix2patterns.grammars.STIXPatternLexer import STIXPatternLexer
from stix2patterns.grammars.STIXPatternParser import (
STIXPatternParser, TerminalNode,
)
from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor
from stix2patterns.validator import STIXPatternErrorListener
from .patterns import *
from .patterns import _BooleanExpression
@ -328,20 +327,12 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor):
def create_pattern_object(pattern, module_suffix="", module_name=""):
"""
Validates a pattern against the STIX Pattern grammar. Error messages are
returned in a list. The test passed if the returned list is empty.
Create a STIX pattern AST from a pattern string.
"""
start = ''
if isinstance(pattern, six.string_types):
start = pattern[:2]
pattern = InputStream(pattern)
pattern = InputStream(pattern)
if not start:
start = pattern.readline()[:2]
pattern.seek(0)
parseErrListener = STIXPatternErrorListener()
parseErrListener = ParserErrorListener()
lexer = STIXPatternLexer(pattern)
# it always adds a console listener by default... remove it.
@ -350,6 +341,7 @@ def create_pattern_object(pattern, module_suffix="", module_name=""):
stream = CommonTokenStream(lexer)
parser = STIXPatternParser(stream)
parser._errHandler = BailErrorStrategy()
parser.buildParseTrees = True
# it always adds a console listener by default... remove it.
@ -363,6 +355,15 @@ def create_pattern_object(pattern, module_suffix="", module_name=""):
if lit_name == u"<INVALID>":
parser.literalNames[i] = parser.symbolicNames[i]
tree = parser.pattern()
try:
tree = parser.pattern()
except antlr4.error.Errors.ParseCancellationException as e:
real_exc = e.args[0]
parser._errHandler.reportError(parser, real_exc)
six.raise_from(
ParseException(parseErrListener.error_message),
real_exc,
)
builder = STIXPatternVisitorForSTIX2(module_suffix, module_name)
return builder.visit(tree)