Compare commits

..

No commits in common. "48735830229c5cbe9339fb6897bf57c10c98d74b" and "534bd868e50cd1fe2efd52d4ec0ec92452ac6a6b" have entirely different histories.

2 changed files with 46 additions and 21 deletions

View File

@ -77,7 +77,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
class ReplicationDataHandler: class ReplicationDataHandler:
"""Handles incoming stream updates from replication. """A replication data handler handles incoming stream updates from replication.
This instance notifies the slave data store about updates. Can be subclassed This instance notifies the slave data store about updates. Can be subclassed
to handle updates in additional ways. to handle updates in additional ways.

View File

@ -45,8 +45,7 @@ inbound_rdata_count = Counter(
class ReplicationCommandHandler: class ReplicationCommandHandler:
"""Handles incoming commands from replication as well as sending commands """Handles incoming commands from replication.
back out to connections.
""" """
def __init__(self, hs): def __init__(self, hs):
@ -93,9 +92,8 @@ class ReplicationCommandHandler:
raise raise
if cmd.token is None or stream_name not in self._streams_connected: if cmd.token is None or stream_name not in self._streams_connected:
# I.e. either this is part of a batch of updates for this stream (in # I.e. this is part of a batch of updates for this stream. Batch
# which case batch until we get an update for the stream with a non # until we get an update for the stream with a non None token
# None token) or we're currently connecting so we queue up rows.
self._pending_batches.setdefault(stream_name, []).append(row) self._pending_batches.setdefault(stream_name, []).append(row)
else: else:
# Check if this is the last of a batch of updates # Check if this is the last of a batch of updates
@ -121,9 +119,6 @@ class ReplicationCommandHandler:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
return return
# We protect catching up with a linearizer in case the replication
# connection reconnects under us.
with await self._position_linearizer.queue(cmd.stream_name):
# We're about to go and catch up with the stream, so mark as connecting # We're about to go and catch up with the stream, so mark as connecting
# to stop RDATA being handled at the same time by removing stream from # to stop RDATA being handled at the same time by removing stream from
# list of connected streams. We also clear any batched up RDATA from # list of connected streams. We also clear any batched up RDATA from
@ -131,6 +126,9 @@ class ReplicationCommandHandler:
self._streams_connected.discard(cmd.stream_name) self._streams_connected.discard(cmd.stream_name)
self._pending_batches.clear() self._pending_batches.clear()
# We protect catching up with a linearizer in case the replication
# connection reconnects under us.
with await self._position_linearizer.queue(cmd.stream_name):
# Find where we previously streamed up to. # Find where we previously streamed up to.
current_token = self._replication_data_handler.get_streams_to_replicate().get( current_token = self._replication_data_handler.get_streams_to_replicate().get(
cmd.stream_name cmd.stream_name
@ -158,14 +156,41 @@ class ReplicationCommandHandler:
# We've now caught up to position sent to us, notify handler. # We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
self._streams_connected.add(cmd.stream_name)
# Handle any RDATA that came in while we were catching up. # Handle any RDATA that came in while we were catching up.
rows = self._pending_batches.pop(cmd.stream_name, []) rows = self._pending_batches.pop(cmd.stream_name, [])
if rows: if rows:
await self._replication_data_handler.on_rdata( # We need to make sure we filter out RDATA rows with a token less
cmd.stream_name, rows[-1].token, rows # than what we've caught up to. This is slightly fiddly because of
) # "batched" rows which have a `None` token, indicating that they
# have the same token as the next row with a non-None token.
#
# We do this by walking the list backwards, first removing any RDATA
# rows that are part of an uncompleted batch, then taking rows while
# their token is either None or greater than where we've caught up
# to.
uncompleted_batch = []
unfinished_batch = True
filtered_rows = []
for row in reversed(rows):
if row.token is not None:
unfinished_batch = False
if cmd.token < row.token:
filtered_rows.append(row)
else:
break
elif unfinished_batch:
uncompleted_batch.append(row)
else:
filtered_rows.append(row)
self._streams_connected.add(cmd.stream_name) filtered_rows.reverse()
uncompleted_batch.reverse()
if uncompleted_batch:
self._pending_batches[cmd.stream_name] = uncompleted_batch
await self.on_rdata(cmd.stream_name, rows[-1].token, filtered_rows)
async def on_SYNC(self, cmd: SyncCommand): async def on_SYNC(self, cmd: SyncCommand):
pass pass