Only accept RDATA commands if we've caught up with stream.

pull/7185/head
Erik Johnston 2020-04-01 17:54:17 +01:00
parent 1ebfa39a73
commit ca9778cedf
1 changed files with 9 additions and 6 deletions

View File

@ -52,8 +52,8 @@ class ReplicationCommandHandler:
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
# Set of streams that we're currently catching up with.
self._streams_connecting = set() # type: Set[str]
# Set of streams that we've caught up with.
self._streams_connected = set() # type: Set[str]
self._streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
@ -91,7 +91,7 @@ class ReplicationCommandHandler:
logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
raise
if cmd.token is None or stream_name in self._streams_connecting:
if cmd.token is None or stream_name not in self._streams_connected:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self._pending_batches.setdefault(stream_name, []).append(row)
@ -120,8 +120,11 @@ class ReplicationCommandHandler:
return
# We're about to go and catch up with the stream, so mark as connecting
# to stop RDATA being handled at the same time.
self._streams_connecting.add(cmd.stream_name)
# to stop RDATA being handled at the same time by removing stream from
# list of connected streams. We also clear any batched up RDATA from
# before we got the POSITION.
self._streams_connected.discard(cmd.stream_name)
self._pending_batches.clear()
# We protect catching up with a linearizer in case the replicaiton
# connection reconnects under us.
@ -153,7 +156,7 @@ class ReplicationCommandHandler:
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
self._streams_connecting.discard(cmd.stream_name)
self._streams_connected.add(cmd.stream_name)
# Handle any RDATA that came in while we were catching up.
rows = self._pending_batches.pop(cmd.stream_name, [])