Allow ReplicationCommandHandler to have multiple connections

This will allow the server replication component to use
ReplicationCommandHandler.
pull/7187/head
Erik Johnston 2020-03-31 13:53:08 +01:00
parent 5016b162fc
commit e6c25e0858
2 changed files with 29 additions and 19 deletions

View File

@ -69,8 +69,8 @@ class ReplicationCommandHandler:
# The factory used to create connections.
self._factory = None # type: Optional[ReplicationClientFactory]
# The current connection. None if we are currently (re)connecting
self._connection = None
# The currently connected connections.
self._connections = []
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@ -181,29 +181,40 @@ class ReplicationCommandHandler:
"""
return self._presence_handler.get_currently_syncing_users()
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
def new_connection(self, connection):
"""Called when we have a new connection.
"""
self._connection = connection
self._connections.append(connection)
def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
logger.info("Finished connecting to server")
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
# If we're using a ReplicationClientFactory then we reset the connection
# delay now. We don't reset the delay any earlier as otherwise if there
# is a problem during start up we'll end up tight looping connecting to
# the server.
if self._factory:
self._factory.resetDelay()
def lost_connection(self, connection):
"""Called when a connection is closed/lost.
"""
try:
self._connections.remove(connection)
except ValueError:
pass
def connected(self) -> bool:
"""Do we have any replication connections open?
Used to no-op if nothing is connected.
"""
return bool(self._connections)
def send_command(self, cmd: Command):
"""Send a command to master (when we get establish a connection if we
don't have one already.)
"""
if self._connection:
self._connection.send_command(cmd)
if self._connections:
for connection in self._connections:
connection.send_command(cmd)
else:
logger.warning("Dropping command as not connected: %r", cmd.NAME)

View File

@ -508,8 +508,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
# We've now finished connecting to so inform the client handler
self.handler.update_connection(self)
self.handler.finished_connecting()
self.handler.new_connection(self)
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
@ -552,7 +551,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.handler.update_connection(None)
self.handler.lost_connection(self)
# The following simply registers metrics for the replication connections