From 699ccf3f0e5ef31b66b1a301d329b584b4f23dc6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 31 Mar 2020 11:38:12 +0100 Subject: [PATCH] Add replication data handler concept. This stops us having to subclass ReplicationClientHandler and override methods. --- synapse/app/generic_worker.py | 9 ++--- synapse/replication/tcp/client.py | 67 ++++++++++++++++++++++++------- synapse/server.py | 10 ++++- synapse/server.pyi | 4 ++ 4 files changed, 69 insertions(+), 21 deletions(-) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1ee266f7c5..588db40e86 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -64,7 +64,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import ( AccountDataStream, @@ -603,7 +603,7 @@ class GenericWorkerServer(HomeServer): def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - def build_tcp_replication(self): + def build_replication_data_handler(self): return GenericWorkerReplicationHandler(self) def build_presence_handler(self): @@ -613,7 +613,7 @@ class GenericWorkerServer(HomeServer): return GenericWorkerTyping(self) -class GenericWorkerReplicationHandler(ReplicationClientHandler): +class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore()) @@ -644,9 +644,6 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): args.update(self.send_handler.stream_positions()) return args - def get_currently_syncing_users(self): - return self.presence_handler.get_currently_syncing_users() - async def process_and_notify(self, stream_name, token, rows): try: if self.send_handler: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e86d9805f1..18a11eb6b9 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -20,7 +20,6 @@ from typing import Dict, List, Optional from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory - from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.tcp.protocol import ( AbstractReplicationClientHandler, @@ -37,6 +36,10 @@ from .commands import ( UserSyncCommand, ) +MYPY = False +if MYPY: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -84,8 +87,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): By default proxies incoming replication data to the SlaveStore. """ - def __init__(self, store: BaseSlavedStore): - self.store = store + def __init__(self, hs: "HomeServer"): + self.presence_handler = hs.get_presence_handler() + self.data_handler = hs.get_replication_data_handler() # The current connection. None if we are currently (re)connecting self.connection = None @@ -125,7 +129,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): Stream.parse_row. """ logger.debug("Received rdata %s -> %s", stream_name, token) - self.store.process_replication_rows(stream_name, token, rows) + await self.data_handler.on_rdata(stream_name, token, rows) async def on_position(self, stream_name, token): """Called when we get new position data. By default this just pokes @@ -133,7 +137,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): Can be overriden in subclasses to handle more. """ - self.store.process_replication_rows(stream_name, token, []) + await self.data_handler.on_position(stream_name, token) def on_sync(self, data): """When we received a SYNC we wake up any deferreds that were waiting @@ -156,22 +160,15 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): map from stream name to the most recent update we have for that stream (ie, the point we want to start replicating from) """ - args = self.store.stream_positions() - user_account_data = args.pop("user_account_data", None) - room_account_data = args.pop("room_account_data", None) - if user_account_data: - args["account_data"] = user_account_data - elif room_account_data: - args["account_data"] = room_account_data - return args + return self.data_handler.get_streams_to_replicate() def get_currently_syncing_users(self): """Get the list of currently syncing users (if any). This is called when a connection has been established and we need to send the currently syncing users. (Overriden by the synchrotron's only) """ - return [] + return self.presence_handler.get_currently_syncing_users() def send_command(self, cmd): """Send a command to master (when we get establish a connection if we @@ -245,3 +242,45 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): # server. if self.factory: self.factory.resetDelay() + + +class ReplicationDataHandler: + """A replication data handler that calls slave data stores. + """ + + def __init__(self, store: BaseSlavedStore): + self.store = store + + async def on_rdata(self, stream_name: str, token: int, rows: list): + """Called to handle a batch of replication data with a given stream token. + + By default this just pokes the slave store. Can be overridden in subclasses to + handle more. + + Args: + stream_name (str): name of the replication stream for this batch of rows + token (int): stream token for this batch of rows + rows (list): a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + """ + self.store.process_replication_rows(stream_name, token, rows) + + def get_streams_to_replicate(self) -> Dict[str, int]: + """Called when a new connection has been established and we need to + subscribe to streams. + + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) + """ + args = self.store.stream_positions() + user_account_data = args.pop("user_account_data", None) + room_account_data = args.pop("room_account_data", None) + if user_account_data: + args["account_data"] = user_account_data + elif room_account_data: + args["account_data"] = room_account_data + return args + + async def on_position(self, stream_name: str, token: int): + self.store.process_replication_rows(stream_name, token, []) diff --git a/synapse/server.py b/synapse/server.py index cd86475d6b..b828be913c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -87,6 +87,10 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool +from synapse.replication.tcp.client import ( + ReplicationClientHandler, + ReplicationDataHandler, +) from synapse.replication.tcp.resource import ReplicationStreamer from synapse.rest.media.v1.media_repository import ( MediaRepository, @@ -206,6 +210,7 @@ class HomeServer(object): "password_policy_handler", "storage", "replication_streamer", + "replication_data_handler", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -468,7 +473,7 @@ class HomeServer(object): return ReadMarkerHandler(self) def build_tcp_replication(self): - raise NotImplementedError() + return ReplicationClientHandler(self) def build_action_generator(self): return ActionGenerator(self) @@ -562,6 +567,9 @@ class HomeServer(object): def build_replication_streamer(self) -> ReplicationStreamer: return ReplicationStreamer(self) + def build_replication_data_handler(self): + return ReplicationDataHandler(self.get_datastore()) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9d1dfa71e7..5e6298ab13 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -108,6 +108,10 @@ class HomeServer(object): self, ) -> synapse.replication.tcp.client.ReplicationClientHandler: pass + def get_replication_data_handler( + self, + ) -> synapse.replication.tcp.client.ReplicationDataHandler: + pass def get_federation_registry( self, ) -> synapse.federation.federation_server.FederationHandlerRegistry: