From 7f200f6f17459f0dbb7c09f7d6b7fe448b90be5f Mon Sep 17 00:00:00 2001 From: Jakub Onderka Date: Sat, 13 Jan 2024 12:52:44 +0100 Subject: [PATCH] chg: [internal] Add support for orjson for zmq --- app/files/scripts/mispzmq/mispzmq.py | 79 +++++++++++++++++----------- 1 file changed, 47 insertions(+), 32 deletions(-) diff --git a/app/files/scripts/mispzmq/mispzmq.py b/app/files/scripts/mispzmq/mispzmq.py index 4c67684ce..0bfe6a890 100644 --- a/app/files/scripts/mispzmq/mispzmq.py +++ b/app/files/scripts/mispzmq/mispzmq.py @@ -4,15 +4,18 @@ from zmq.auth.thread import ThreadAuthenticator from zmq.utils.monitor import recv_monitor_message import sys import redis -import json import os import time import threading import logging +import typing +import argparse from pathlib import Path - -logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s:%(message)s") +try: + import orjson as json +except ImportError: + import json def check_pid(pid): @@ -55,10 +58,11 @@ class MispZmq: socket = None pidfile = None - r: redis.StrictRedis + redis: redis.StrictRedis namespace: str - def __init__(self): + def __init__(self, debug=False): + logging.basicConfig(level=logging.DEBUG if debug else logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") self._logger = logging.getLogger() self.tmp_location = Path(__file__).parent.parent / "tmp" @@ -67,7 +71,7 @@ class MispZmq: with open(self.pidfile.as_posix()) as f: pid = f.read() if check_pid(pid): - raise Exception("mispzmq already running on PID {}".format(pid)) + raise Exception(f"mispzmq already running on PID {pid}") else: # Cleanup self.pidfile.unlink() @@ -77,17 +81,18 @@ class MispZmq: raise Exception("The settings file is missing.") def _setup(self): - with open((self.tmp_location / "mispzmq_settings.json").as_posix()) as settings_file: - self.settings = json.load(settings_file) + with open((self.tmp_location / "mispzmq_settings.json").as_posix(), 'rb') as settings_file: + self.settings = json.loads(settings_file.read()) + self.namespace = self.settings["redis_namespace"] # Check if TLS is being used with Redis host redis_host = self.settings["redis_host"] redis_ssl = redis_host.startswith("tls://") if redis_host.startswith("tls://"): redis_host = redis_host[6:] - self.r = redis.StrictRedis(host=redis_host, db=self.settings["redis_database"], + self.redis = redis.StrictRedis(host=redis_host, db=self.settings["redis_database"], password=self.settings["redis_password"], port=self.settings["redis_port"], - decode_responses=True, ssl=redis_ssl) + ssl=redis_ssl) self.timestamp_settings = time.time() self._logger.debug("Connected to Redis {}:{}/{}".format(self.settings["redis_host"], self.settings["redis_port"], self.settings["redis_database"])) @@ -122,34 +127,38 @@ class MispZmq: self.socket.disable_monitor() self.monitor_thread = None - def _handle_command(self, command): - if command == "kill": + def _handle_command(self, command: bytes): + if command == b"kill": self._logger.info("Kill command received, shutting down.") self.clean() sys.exit() - elif command == "reload": + elif command == b"reload": self._logger.info("Reload command received, reloading settings from file.") self._setup() self._setup_zmq() - elif command == "status": + elif command == b"status": self._logger.info("Status command received, responding with latest stats.") - self.r.delete("{}:status".format(self.namespace)) - self.r.lpush("{}:status".format(self.namespace), + self.redis.delete(f"{self.namespace}:status") + self.redis.lpush(f"{self.namespace}:status", json.dumps({"timestamp": time.time(), "timestampSettings": self.timestamp_settings, "publishCount": self.publish_count, "messageCount": self.message_count})) else: - self._logger.warning("Received invalid command '{}'.".format(command)) + self._logger.warning(f"Received invalid command '{command}'.") def _create_pid_file(self): with open(self.pidfile.as_posix(), "w") as f: f.write(str(os.getpid())) - def _pub_message(self, topic, data): - self.socket.send_string("{} {}".format(topic, data)) + def _pub_message(self, topic: bytes, data: typing.Union[str, bytes]): + data_to_send = bytearray() + data_to_send.extend(topic) + data_to_send.extend(b" ") + data_to_send.extend(data.encode("utf-8") if isinstance(data, str) else data) + self.socket.send(bytes(data_to_send)) def clean(self): if self.monitor_thread: @@ -179,12 +188,14 @@ class MispZmq: "misp_json_tag", "misp_json_warninglist", "misp_json_workflow" ] - lists = ["{}:command".format(self.namespace)] + lists = [f"{self.namespace}:command"] for topic in topics: - lists.append("{}:data:{}".format(self.namespace, topic)) + lists.append(f"{self.namespace}:data:{topic}") + + key_prefix = f"{self.namespace}:".encode("utf-8") while True: - data = self.r.blpop(lists, timeout=10) + data = self.redis.blpop(lists, timeout=10) if data is None: # redis timeout expired @@ -195,26 +206,30 @@ class MispZmq: "status": status_array[status_entry], "uptime": current_time - int(self.timestamp_settings) } - self._pub_message("misp_json_self", json.dumps(status_message)) - self._logger.debug("No message received for 10 seconds, sending ZMQ status message.") + self._pub_message(b"misp_json_self", json.dumps(status_message)) + self._logger.debug("No message received from Redis for 10 seconds, sending ZMQ status message.") else: key, value = data - key = key.replace("{}:".format(self.namespace), "") - if key == "command": + key = key.replace(key_prefix, b"") + if key == b"command": self._handle_command(value) - elif key.startswith("data:"): - topic = key.split(":")[1] - self._logger.debug("Received data for topic '{}', sending to ZMQ.".format(topic)) + elif key.startswith(b"data:"): + topic = key.split(b":", 1)[1] + self._logger.debug("Received data for topic %s, sending to ZMQ.", topic) self._pub_message(topic, value) self.message_count += 1 - if topic == "misp_json": + if topic == b"misp_json": self.publish_count += 1 else: - self._logger.warning("Received invalid message '{}'.".format(key)) + self._logger.warning("Received invalid message type %s.", key) if __name__ == "__main__": - mzq = MispZmq() + arg_parser = argparse.ArgumentParser(description="MISP ZeroMQ PUB server") + arg_parser.add_argument("--debug", action="store_true", help="Enable debugging messages") + parsed = arg_parser.parse_args() + + mzq = MispZmq(parsed.debug) try: mzq.main() except KeyboardInterrupt: