From 8f1a87886f13900f96fc0aa8a70c71e3989e8b7e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 31 Mar 2020 10:35:43 +0100 Subject: [PATCH] Don't use POSITION to detect "finished connecting". In a Redis world we won't necessarily get one POSITION per stream at the start of the connection, so we rejig our "streams connecting" logic. --- synapse/replication/tcp/protocol.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 6c2258bae7..0df8f52777 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -564,10 +564,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] - # Set of stream names that have been subscribe to, but haven't yet - # caught up with. This is used to track when the client has been fully - # connected to the remote. - self.streams_connecting = set(STREAMS_MAP) # type: Set[str] + # Set of streams that we're currently catching up with. + self.streams_connecting = set() # type: Set[str] # Map of stream to batched updates. See RdataCommand for info on how # batching works. @@ -589,6 +587,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) + self.handler.finished_connecting() async def on_SERVER(self, cmd): if cmd.data != self.server_name: @@ -623,6 +622,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) return + # We're about to go and catch up with the stream, so mark as connecting + # to stop RDATA being handled at the same time. + self.streams_connecting.add(cmd.stream_name) + # Find where we previously streamed up to. current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) if current_token is None: @@ -648,8 +651,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): await self.handler.on_position(cmd.stream_name, cmd.token) self.streams_connecting.discard(cmd.stream_name) - if not self.streams_connecting: - self.handler.finished_connecting() # Handle any RDATA that came in while we were catching up. rows = self.pending_batches.pop(cmd.stream_name, [])