Don't use POSITION to detect "finished connecting".

In a Redis world we won't necessarily get one POSITION per stream at the
start of the connection, so we rejig our "streams connecting" logic.
pull/7185/head
Erik Johnston 2020-03-31 10:35:43 +01:00
parent 5b1e760f1a
commit 8f1a87886f
1 changed files with 7 additions and 6 deletions

View File

@ -564,10 +564,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
# Set of streams that we're currently catching up with.
self.streams_connecting = set() # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
@ -589,6 +587,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# We've now finished connecting to so inform the client handler
self.handler.update_connection(self)
self.handler.finished_connecting()
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
@ -623,6 +622,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
return
# We're about to go and catch up with the stream, so mark as connecting
# to stop RDATA being handled at the same time.
self.streams_connecting.add(cmd.stream_name)
# Find where we previously streamed up to.
current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
if current_token is None:
@ -648,8 +651,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
await self.handler.on_position(cmd.stream_name, cmd.token)
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
# Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, [])