chg: [internal] Add support for orjson for zmq

pull/9491/head
Jakub Onderka 2024-01-13 12:52:44 +01:00
parent b7d11d3772
commit 7f200f6f17
1 changed files with 47 additions and 32 deletions

View File

@ -4,15 +4,18 @@ from zmq.auth.thread import ThreadAuthenticator
from zmq.utils.monitor import recv_monitor_message
import sys
import redis
import json
import os
import time
import threading
import logging
import typing
import argparse
from pathlib import Path
logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s:%(message)s")
try:
import orjson as json
except ImportError:
import json
def check_pid(pid):
@ -55,10 +58,11 @@ class MispZmq:
socket = None
pidfile = None
r: redis.StrictRedis
redis: redis.StrictRedis
namespace: str
def __init__(self):
def __init__(self, debug=False):
logging.basicConfig(level=logging.DEBUG if debug else logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
self._logger = logging.getLogger()
self.tmp_location = Path(__file__).parent.parent / "tmp"
@ -67,7 +71,7 @@ class MispZmq:
with open(self.pidfile.as_posix()) as f:
pid = f.read()
if check_pid(pid):
raise Exception("mispzmq already running on PID {}".format(pid))
raise Exception(f"mispzmq already running on PID {pid}")
else:
# Cleanup
self.pidfile.unlink()
@ -77,17 +81,18 @@ class MispZmq:
raise Exception("The settings file is missing.")
def _setup(self):
with open((self.tmp_location / "mispzmq_settings.json").as_posix()) as settings_file:
self.settings = json.load(settings_file)
with open((self.tmp_location / "mispzmq_settings.json").as_posix(), 'rb') as settings_file:
self.settings = json.loads(settings_file.read())
self.namespace = self.settings["redis_namespace"]
# Check if TLS is being used with Redis host
redis_host = self.settings["redis_host"]
redis_ssl = redis_host.startswith("tls://")
if redis_host.startswith("tls://"):
redis_host = redis_host[6:]
self.r = redis.StrictRedis(host=redis_host, db=self.settings["redis_database"],
self.redis = redis.StrictRedis(host=redis_host, db=self.settings["redis_database"],
password=self.settings["redis_password"], port=self.settings["redis_port"],
decode_responses=True, ssl=redis_ssl)
ssl=redis_ssl)
self.timestamp_settings = time.time()
self._logger.debug("Connected to Redis {}:{}/{}".format(self.settings["redis_host"], self.settings["redis_port"],
self.settings["redis_database"]))
@ -122,34 +127,38 @@ class MispZmq:
self.socket.disable_monitor()
self.monitor_thread = None
def _handle_command(self, command):
if command == "kill":
def _handle_command(self, command: bytes):
if command == b"kill":
self._logger.info("Kill command received, shutting down.")
self.clean()
sys.exit()
elif command == "reload":
elif command == b"reload":
self._logger.info("Reload command received, reloading settings from file.")
self._setup()
self._setup_zmq()
elif command == "status":
elif command == b"status":
self._logger.info("Status command received, responding with latest stats.")
self.r.delete("{}:status".format(self.namespace))
self.r.lpush("{}:status".format(self.namespace),
self.redis.delete(f"{self.namespace}:status")
self.redis.lpush(f"{self.namespace}:status",
json.dumps({"timestamp": time.time(),
"timestampSettings": self.timestamp_settings,
"publishCount": self.publish_count,
"messageCount": self.message_count}))
else:
self._logger.warning("Received invalid command '{}'.".format(command))
self._logger.warning(f"Received invalid command '{command}'.")
def _create_pid_file(self):
with open(self.pidfile.as_posix(), "w") as f:
f.write(str(os.getpid()))
def _pub_message(self, topic, data):
self.socket.send_string("{} {}".format(topic, data))
def _pub_message(self, topic: bytes, data: typing.Union[str, bytes]):
data_to_send = bytearray()
data_to_send.extend(topic)
data_to_send.extend(b" ")
data_to_send.extend(data.encode("utf-8") if isinstance(data, str) else data)
self.socket.send(bytes(data_to_send))
def clean(self):
if self.monitor_thread:
@ -179,12 +188,14 @@ class MispZmq:
"misp_json_tag", "misp_json_warninglist", "misp_json_workflow"
]
lists = ["{}:command".format(self.namespace)]
lists = [f"{self.namespace}:command"]
for topic in topics:
lists.append("{}:data:{}".format(self.namespace, topic))
lists.append(f"{self.namespace}:data:{topic}")
key_prefix = f"{self.namespace}:".encode("utf-8")
while True:
data = self.r.blpop(lists, timeout=10)
data = self.redis.blpop(lists, timeout=10)
if data is None:
# redis timeout expired
@ -195,26 +206,30 @@ class MispZmq:
"status": status_array[status_entry],
"uptime": current_time - int(self.timestamp_settings)
}
self._pub_message("misp_json_self", json.dumps(status_message))
self._logger.debug("No message received for 10 seconds, sending ZMQ status message.")
self._pub_message(b"misp_json_self", json.dumps(status_message))
self._logger.debug("No message received from Redis for 10 seconds, sending ZMQ status message.")
else:
key, value = data
key = key.replace("{}:".format(self.namespace), "")
if key == "command":
key = key.replace(key_prefix, b"")
if key == b"command":
self._handle_command(value)
elif key.startswith("data:"):
topic = key.split(":")[1]
self._logger.debug("Received data for topic '{}', sending to ZMQ.".format(topic))
elif key.startswith(b"data:"):
topic = key.split(b":", 1)[1]
self._logger.debug("Received data for topic %s, sending to ZMQ.", topic)
self._pub_message(topic, value)
self.message_count += 1
if topic == "misp_json":
if topic == b"misp_json":
self.publish_count += 1
else:
self._logger.warning("Received invalid message '{}'.".format(key))
self._logger.warning("Received invalid message type %s.", key)
if __name__ == "__main__":
mzq = MispZmq()
arg_parser = argparse.ArgumentParser(description="MISP ZeroMQ PUB server")
arg_parser.add_argument("--debug", action="store_true", help="Enable debugging messages")
parsed = arg_parser.parse_args()
mzq = MispZmq(parsed.debug)
try:
mzq.main()
except KeyboardInterrupt: