diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 494c5ddba6..ce14c6f2c4 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -35,6 +35,7 @@ from synapse.replication.tcp.commands import ( UserIpCommand, UserSyncCommand, ) +from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.util.async_helpers import Linearizer @@ -82,7 +83,7 @@ class ReplicationCommandHandler: self._factory = None # type: Optional[ReplicationClientFactory] # The currently connected connections. - self._connections = [] # type: List[Any] + self._connections = [] # type: List[AbstractConnection] LaterGauge( "synapse_replication_tcp_resource_total_connections", @@ -278,7 +279,7 @@ class ReplicationCommandHandler: """ return self._presence_handler.get_currently_syncing_users() - def new_connection(self, connection): + def new_connection(self, connection: AbstractConnection): """Called when we have a new connection. """ self._connections.append(connection) @@ -295,7 +296,7 @@ class ReplicationCommandHandler: if self._factory: self._factory.resetDelay() - def lost_connection(self, connection): + def lost_connection(self, connection: AbstractConnection): """Called when a connection is closed/lost. """ try: diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index d74fde7e0b..bb12d6a14b 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,6 +46,7 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ +import abc import fcntl import logging import struct @@ -485,6 +486,22 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ReplicateCommand()) +class AbstractConnection(abc.ABC): + """An interface for replication connections. + """ + + @abc.abstractmethod + def send_command(self, cmd: Command): + """Send the command down the connection + """ + pass + + +# This tells python that `BaseReplicationStreamProtocol` implements the +# interface. +AbstractConnection.register(BaseReplicationStreamProtocol) + + # The following simply registers metrics for the replication connections pending_commands = LaterGauge(