chg: [internal] Avoid double json decoding

pull/619/head
Jakub Onderka 2023-06-27 18:07:38 +02:00
parent 36b916916a
commit f1434bed90
2 changed files with 24 additions and 20 deletions

View File

@ -183,21 +183,28 @@ class QueryModule(tornado.web.RequestHandler):
executor = ThreadPoolExecutor(nb_threads) executor = ThreadPoolExecutor(nb_threads)
@run_on_executor @run_on_executor
def run_request(self, module, jsonpayload): def run_request(self, dict_payload, json_payload):
log.debug('MISP QueryModule request {0}'.format(jsonpayload)) if log.isEnabledFor(logging.DEBUG):
response = mhandlers[module].handler(q=jsonpayload) log.debug('MISP QueryModule request {0}'.format(json_payload))
module = dict_payload['module']
module_handler = mhandlers[module]
if hasattr(module_handler, 'REQUIRE_DICT') and module_handler.REQUIRE_DICT:
response = module_handler.handler(q=dict_payload)
else:
response = module_handler.handler(q=json_payload)
return json.dumps(response) return json.dumps(response)
@tornado.gen.coroutine @tornado.gen.coroutine
def post(self): def post(self):
try: try:
jsonpayload = self.request.body.decode('utf-8') json_payload = self.request.body.decode('utf-8')
dict_payload = json.loads(jsonpayload) dict_payload = json.loads(json_payload)
if dict_payload.get('timeout'): if dict_payload.get('timeout'):
timeout = datetime.timedelta(seconds=int(dict_payload.get('timeout'))) timeout = datetime.timedelta(seconds=int(dict_payload.get('timeout')))
else: else:
timeout = datetime.timedelta(seconds=300) timeout = datetime.timedelta(seconds=300)
response = yield tornado.gen.with_timeout(timeout, self.run_request(dict_payload['module'], jsonpayload)) response = yield tornado.gen.with_timeout(timeout, self.run_request(dict_payload, json_payload))
self.write(response) self.write(response)
except tornado.gen.TimeoutError: except tornado.gen.TimeoutError:
log.warning('Timeout on {} '.format(dict_payload['module'])) log.warning('Timeout on {} '.format(dict_payload['module']))
@ -223,15 +230,15 @@ def main():
global loaded_modules global loaded_modules
signal.signal(signal.SIGINT, handle_signal) signal.signal(signal.SIGINT, handle_signal)
signal.signal(signal.SIGTERM, handle_signal) signal.signal(signal.SIGTERM, handle_signal)
argParser = argparse.ArgumentParser(description='misp-modules server', formatter_class=argparse.RawTextHelpFormatter) arg_parser = argparse.ArgumentParser(description='misp-modules server', formatter_class=argparse.RawTextHelpFormatter)
argParser.add_argument('-t', default=False, action='store_true', help='Test mode') arg_parser.add_argument('-t', default=False, action='store_true', help='Test mode')
argParser.add_argument('-s', default=False, action='store_true', help='Run a system install (package installed via pip)') arg_parser.add_argument('-s', default=False, action='store_true', help='Run a system install (package installed via pip)')
argParser.add_argument('-d', default=False, action='store_true', help='Enable debugging') arg_parser.add_argument('-d', default=False, action='store_true', help='Enable debugging')
argParser.add_argument('-p', default=6666, help='misp-modules TCP port (default 6666)') arg_parser.add_argument('-p', default=6666, help='misp-modules TCP port (default 6666)')
argParser.add_argument('-l', default='localhost', help='misp-modules listen address (default localhost)') arg_parser.add_argument('-l', default='localhost', help='misp-modules listen address (default localhost)')
argParser.add_argument('-m', default=[], action='append', help='Register a custom module') arg_parser.add_argument('-m', default=[], action='append', help='Register a custom module')
argParser.add_argument('--devel', default=False, action='store_true', help='''Start in development mode, enable debug, start only the module(s) listed in -m.\nExample: -m misp_modules.modules.expansion.bgpranking''') arg_parser.add_argument('--devel', default=False, action='store_true', help='''Start in development mode, enable debug, start only the module(s) listed in -m.\nExample: -m misp_modules.modules.expansion.bgpranking''')
args = argParser.parse_args() args = arg_parser.parse_args()
port = args.p port = args.p
listen = args.l listen = args.l
if args.devel: if args.devel:

View File

@ -58,12 +58,9 @@ def connect_to_clamav(connection_string: str) -> clamd.ClamdNetworkSocket:
raise Exception("ClamAV connection string is invalid. It must be unix socket path with 'unix://' prefix or IP:PORT.") raise Exception("ClamAV connection string is invalid. It must be unix socket path with 'unix://' prefix or IP:PORT.")
def handler(q=False): REQUIRE_DICT = True
if q is False:
return False
request = json.loads(q)
def handler(request: dict):
connection_string: str = request["config"].get("connection") connection_string: str = request["config"].get("connection")
if not connection_string: if not connection_string:
return {"error": "No ClamAV connection string provided"} return {"error": "No ClamAV connection string provided"}