diff --git a/stix2/pattern_visitor.py b/stix2/pattern_visitor.py index b2d7a53..e4601ba 100644 --- a/stix2/pattern_visitor.py +++ b/stix2/pattern_visitor.py @@ -1,16 +1,15 @@ import importlib import inspect -from antlr4 import CommonTokenStream, InputStream -from antlr4.tree.Trees import Trees +from antlr4 import BailErrorStrategy, CommonTokenStream, InputStream +import antlr4.error.Errors import six -from stix2patterns.exceptions import ParseException +from stix2patterns.exceptions import ParseException, ParserErrorListener from stix2patterns.grammars.STIXPatternLexer import STIXPatternLexer from stix2patterns.grammars.STIXPatternParser import ( STIXPatternParser, TerminalNode, ) from stix2patterns.grammars.STIXPatternVisitor import STIXPatternVisitor -from stix2patterns.validator import STIXPatternErrorListener from .patterns import * from .patterns import _BooleanExpression @@ -328,20 +327,12 @@ class STIXPatternVisitorForSTIX2(STIXPatternVisitor): def create_pattern_object(pattern, module_suffix="", module_name=""): """ - Validates a pattern against the STIX Pattern grammar. Error messages are - returned in a list. The test passed if the returned list is empty. + Create a STIX pattern AST from a pattern string. """ - start = '' - if isinstance(pattern, six.string_types): - start = pattern[:2] - pattern = InputStream(pattern) + pattern = InputStream(pattern) - if not start: - start = pattern.readline()[:2] - pattern.seek(0) - - parseErrListener = STIXPatternErrorListener() + parseErrListener = ParserErrorListener() lexer = STIXPatternLexer(pattern) # it always adds a console listener by default... remove it. @@ -350,6 +341,7 @@ def create_pattern_object(pattern, module_suffix="", module_name=""): stream = CommonTokenStream(lexer) parser = STIXPatternParser(stream) + parser._errHandler = BailErrorStrategy() parser.buildParseTrees = True # it always adds a console listener by default... remove it. @@ -363,6 +355,15 @@ def create_pattern_object(pattern, module_suffix="", module_name=""): if lit_name == u"": parser.literalNames[i] = parser.symbolicNames[i] - tree = parser.pattern() + try: + tree = parser.pattern() + except antlr4.error.Errors.ParseCancellationException as e: + real_exc = e.args[0] + parser._errHandler.reportError(parser, real_exc) + six.raise_from( + ParseException(parseErrListener.error_message), + real_exc, + ) + builder = STIXPatternVisitorForSTIX2(module_suffix, module_name) return builder.visit(tree)