diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 18a11eb6b9..4f72850543 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -20,7 +20,9 @@ from typing import Dict, List, Optional from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory + from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ( AbstractReplicationClientHandler, ClientReplicationStreamProtocol, @@ -56,7 +58,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): self.client_name = client_name - self.handler = handler + self.command_handler = ReplicationCommandHandler(hs, handler) self.server_name = hs.config.server_name self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class @@ -69,7 +71,11 @@ class ReplicationClientFactory(ReconnectingClientFactory): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - self.hs, self.client_name, self.server_name, self._clock, self.handler, + self.hs, + self.client_name, + self.server_name, + self._clock, + self.command_handler, ) def clientConnectionLost(self, connector, reason): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py new file mode 100644 index 0000000000..03ec4f1381 --- /dev/null +++ b/synapse/replication/tcp/handler.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A replication client for use by synapse workers. +""" + +import logging +from typing import Any, Dict, List, Set + +from prometheus_client import Counter + +from synapse.replication.tcp.commands import ( + PositionCommand, + RdataCommand, + RemoteServerUpCommand, + SyncCommand, +) +from synapse.replication.tcp.streams import STREAMS_MAP, Stream + +logger = logging.getLogger(__name__) + + +# number of updates received for each RDATA stream +inbound_rdata_count = Counter( + "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] +) + + +class ReplicationCommandHandler: + """Handles incoming commands from replication. + """ + + def __init__(self, hs, handler): + self.handler = handler + + # Set of streams that we're currently catching up with. + self.streams_connecting = set() # type: Set[str] + + self.streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + + # Map of stream to batched updates. See RdataCommand for info on how + # batching works. + self.pending_batches = {} # type: Dict[str, List[Any]] + + async def on_RDATA(self, cmd: RdataCommand): + stream_name = cmd.stream_name + inbound_rdata_count.labels(stream_name).inc() + + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception: + logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) + raise + + if cmd.token is None or stream_name in self.streams_connecting: + # I.e. this is part of a batch of updates for this stream. Batch + # until we get an update for the stream with a non None token + self.pending_batches.setdefault(stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self.pending_batches.pop(stream_name, []) + rows.append(row) + await self.handler.on_rdata(stream_name, cmd.token, rows) + + async def on_POSITION(self, cmd: PositionCommand): + stream = self.streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # We're about to go and catch up with the stream, so mark as connecting + # to stop RDATA being handled at the same time. + self.streams_connecting.add(cmd.stream_name) + + # Find where we previously streamed up to. + current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + if updates: + await self.handler.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self.handler.on_position(cmd.stream_name, cmd.token) + + self.streams_connecting.discard(cmd.stream_name) + + # Handle any RDATA that came in while we were catching up. + rows = self.pending_batches.pop(cmd.stream_name, []) + if rows: + await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) + + async def on_SYNC(self, cmd: SyncCommand): + self.handler.on_sync(cmd.data) + + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + self.handler.on_remote_server_up(cmd.data) + + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users. (Overriden by the synchrotron's only) + """ + return self.handler.get_currently_syncing_users() + + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + return self.handler.update_connection(connection) + + def finished_connecting(self): + return self.handler.finished_connecting() diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 0df8f52777..7dc2828ad1 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -550,7 +550,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): client_name: str, server_name: str, clock: Clock, - handler: AbstractReplicationClientHandler, + command_handler, ): BaseReplicationStreamProtocol.__init__(self, clock) @@ -558,7 +558,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.client_name = client_name self.server_name = server_name - self.handler = handler + self.handler = command_handler self.streams = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() @@ -589,80 +589,37 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.handler.update_connection(self) self.handler.finished_connecting() + async def handle_command(self, cmd: Command): + """Handle a command we have received over the replication stream. + + By default delegates to on_, which should return an awaitable. + + Args: + cmd: received command + """ + handled = False + + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) + async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - async def on_RDATA(self, cmd): - stream_name = cmd.stream_name - inbound_rdata_count.labels(stream_name).inc() - - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception( - "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row - ) - raise - - if cmd.token is None or stream_name in self.streams_connecting: - # I.e. this is part of a batch of updates for this stream. Batch - # until we get an update for the stream with a non None token - self.pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(stream_name, []) - rows.append(row) - await self.handler.on_rdata(stream_name, cmd.token, rows) - - async def on_POSITION(self, cmd: PositionCommand): - stream = self.streams.get(cmd.stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) - return - - # We're about to go and catch up with the stream, so mark as connecting - # to stop RDATA being handled at the same time. - self.streams_connecting.add(cmd.stream_name) - - # Find where we previously streamed up to. - current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name - ) - return - - # Fetch all updates between then and now. - limited = True - while limited: - updates, current_token, limited = await stream.get_updates_since( - current_token, cmd.token - ) - if updates: - await self.handler.on_rdata( - cmd.stream_name, - current_token, - [stream.parse_row(update[1]) for update in updates], - ) - - # We've now caught up to position sent to us, notify handler. - await self.handler.on_position(cmd.stream_name, cmd.token) - - self.streams_connecting.discard(cmd.stream_name) - - # Handle any RDATA that came in while we were catching up. - rows = self.pending_batches.pop(cmd.stream_name, []) - if rows: - await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) - - async def on_SYNC(self, cmd): - self.handler.on_sync(cmd.data) - - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.handler.on_remote_server_up(cmd.data) - def replicate(self): """Send the subscription request to the server """ @@ -758,8 +715,3 @@ tcp_outbound_commands = LaterGauge( for k, count in iteritems(p.outbound_commands_counter) }, ) - -# number of updates received for each RDATA stream -inbound_rdata_count = Counter( - "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] -)