diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 4f72850543..e03b08371c 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,27 +16,12 @@ """ import logging -from typing import Dict, List, Optional +from typing import Dict -from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.tcp.handler import ReplicationCommandHandler -from synapse.replication.tcp.protocol import ( - AbstractReplicationClientHandler, - ClientReplicationStreamProtocol, -) - -from .commands import ( - Command, - FederationAckCommand, - InvalidateCacheCommand, - RemoteServerUpCommand, - RemovePusherCommand, - UserIpCommand, - UserSyncCommand, -) +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol MYPY = False if MYPY: @@ -56,9 +41,9 @@ class ReplicationClientFactory(ReconnectingClientFactory): initialDelay = 0.1 maxDelay = 1 # Try at least once every N seconds - def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): + def __init__(self, hs: "HomeServer", client_name, command_handler): self.client_name = client_name - self.command_handler = ReplicationCommandHandler(hs, handler) + self.command_handler = command_handler self.server_name = hs.config.server_name self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class @@ -87,169 +72,6 @@ class ReplicationClientFactory(ReconnectingClientFactory): ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(AbstractReplicationClientHandler): - """A base handler that can be passed to the ReplicationClientFactory. - - By default proxies incoming replication data to the SlaveStore. 
- """ - - def __init__(self, hs: "HomeServer"): - self.presence_handler = hs.get_presence_handler() - self.data_handler = hs.get_replication_data_handler() - - # The current connection. None if we are currently (re)connecting - self.connection = None - - # Any pending commands to be sent once a new connection has been - # established - self.pending_commands = [] # type: List[Command] - - # Map from string -> deferred, to wake up when receiveing a SYNC with - # the given string. - # Used for tests. - self.awaiting_syncs = {} # type: Dict[str, defer.Deferred] - - # The factory used to create connections. - self.factory = None # type: Optional[ReplicationClientFactory] - - def start_replication(self, hs): - """Helper method to start a replication connection to the remote server - using TCP. - """ - client_name = hs.config.worker_name - self.factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self.factory) - - async def on_rdata(self, stream_name, token, rows): - """Called to handle a batch of replication data with a given stream token. - - By default this just pokes the slave store. Can be overridden in subclasses to - handle more. - - Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. - """ - logger.debug("Received rdata %s -> %s", stream_name, token) - await self.data_handler.on_rdata(stream_name, token, rows) - - async def on_position(self, stream_name, token): - """Called when we get new position data. By default this just pokes - the slave store. - - Can be overriden in subclasses to handle more. 
- """ - await self.data_handler.on_position(stream_name, token) - - def on_sync(self, data): - """When we received a SYNC we wake up any deferreds that were waiting - for the sync with the given data. - - Used by tests. - """ - d = self.awaiting_syncs.pop(data, None) - if d: - d.callback(data) - - def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - - def get_streams_to_replicate(self) -> Dict[str, int]: - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - - return self.data_handler.get_streams_to_replicate() - - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users. (Overriden by the synchrotron's only) - """ - return self.presence_handler.get_currently_syncing_users() - - def send_command(self, cmd): - """Send a command to master (when we get establish a connection if we - don't have one already.) - """ - if self.connection: - self.connection.send_command(cmd) - else: - logger.warning("Queuing command as not connected: %r", cmd.NAME) - self.pending_commands.append(cmd) - - def send_federation_ack(self, token): - """Ack data for the federation stream. This allows the master to drop - data stored purely in memory. - """ - self.send_command(FederationAckCommand(token)) - - def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): - """Poke the master that a user has started/stopped syncing. 
- """ - self.send_command( - UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) - ) - - def send_remove_pusher(self, app_id, push_key, user_id): - """Poke the master to remove a pusher for a user - """ - cmd = RemovePusherCommand(app_id, push_key, user_id) - self.send_command(cmd) - - def send_invalidate_cache(self, cache_func, keys): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - - def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): - """Tell the master that the user made a request. - """ - cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) - self.send_command(cmd) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def await_sync(self, data): - """Returns a deferred that is resolved when we receive a SYNC command - with given data. - - [Not currently] used by tests. - """ - return self.awaiting_syncs.setdefault(data, defer.Deferred()) - - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - self.connection = connection - if connection: - for cmd in self.pending_commands: - connection.send_command(cmd) - self.pending_commands = [] - - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - logger.info("Finished connecting to server") - - # We don't reset the delay any earlier as otherwise if there is a - # problem during start up we'll end up tight looping connecting to the - # server. - if self.factory: - self.factory.resetDelay() - - class ReplicationDataHandler: """A replication data handler that calls slave data stores. 
""" diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 03ec4f1381..0b42339142 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -17,15 +17,22 @@ """ import logging -from typing import Any, Dict, List, Set +from typing import Any, Callable, Dict, List, Optional, Set from prometheus_client import Counter +from synapse.replication.tcp.client import ReplicationClientFactory from synapse.replication.tcp.commands import ( + Command, + FederationAckCommand, + InvalidateCacheCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, + RemovePusherCommand, SyncCommand, + UserIpCommand, + UserSyncCommand, ) from synapse.replication.tcp.streams import STREAMS_MAP, Stream @@ -42,8 +49,9 @@ class ReplicationCommandHandler: """Handles incoming commands from replication. """ - def __init__(self, hs, handler): - self.handler = handler + def __init__(self, hs): + self.replication_data_handler = hs.get_replication_data_handler() + self.presence_handler = hs.get_presence_handler() # Set of streams that we're currently catching up with. self.streams_connecting = set() # type: Set[str] @@ -56,6 +64,22 @@ class ReplicationCommandHandler: # batching works. self.pending_batches = {} # type: Dict[str, List[Any]] + # The factory used to create connections. + self.factory = None # type: Optional[ReplicationClientFactory] + + # The current connection. None if we are currently (re)connecting + self.connection = None + + def start_replication(self, hs): + """Helper method to start a replication connection to the remote server + using TCP. 
+ """ + client_name = hs.config.worker_name + self.factory = ReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self.factory) + async def on_RDATA(self, cmd: RdataCommand): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -74,7 +98,19 @@ class ReplicationCommandHandler: # Check if this is the last of a batch of updates rows = self.pending_batches.pop(stream_name, []) rows.append(row) - await self.handler.on_rdata(stream_name, cmd.token, rows) + await self.on_rdata(stream_name, cmd.token, rows) + + async def on_rdata(self, stream_name: str, token: int, rows: list): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name: name of the replication stream for this batch of rows + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + """ + logger.debug("Received rdata %s -> %s", stream_name, token) + await self.replication_data_handler.on_rdata(stream_name, token, rows) async def on_POSITION(self, cmd: PositionCommand): stream = self.streams.get(cmd.stream_name) @@ -87,7 +123,9 @@ class ReplicationCommandHandler: self.streams_connecting.add(cmd.stream_name) # Find where we previously streamed up to. - current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) + current_token = self.replication_data_handler.get_streams_to_replicate().get( + cmd.stream_name + ) if current_token is None: logger.warning( "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name @@ -101,39 +139,101 @@ class ReplicationCommandHandler: current_token, cmd.token ) if updates: - await self.handler.on_rdata( + await self.on_rdata( cmd.stream_name, current_token, [stream.parse_row(update[1]) for update in updates], ) # We've now caught up to position sent to us, notify handler. 
-        await self.handler.on_position(cmd.stream_name, cmd.token)
+        await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
 
         self.streams_connecting.discard(cmd.stream_name)
 
         # Handle any RDATA that came in while we were catching up.
         rows = self.pending_batches.pop(cmd.stream_name, [])
         if rows:
-            await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
+            await self.on_rdata(cmd.stream_name, rows[-1].token, rows)
 
     async def on_SYNC(self, cmd: SyncCommand):
-        self.handler.on_sync(cmd.data)
+        pass
 
     async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.handler.on_remote_server_up(cmd.data)
+        """Called when we get a new REMOTE_SERVER_UP command."""
 
     def get_currently_syncing_users(self):
        """Get the list of currently syncing users (if any). This is called
        when a connection has been established and we need to send the
        currently syncing users. (Overriden by the synchrotron's only)
        """
-        return self.handler.get_currently_syncing_users()
+        return self.presence_handler.get_currently_syncing_users()
 
     def update_connection(self, connection):
         """Called when a connection has been established (or lost with None).
         """
-        return self.handler.update_connection(connection)
+        self.connection = connection
 
     def finished_connecting(self):
-        return self.handler.finished_connecting()
+        """Called when we have successfully subscribed and caught up to all
+        streams we're interested in.
+        """
+        logger.info("Finished connecting to server")
+
+        # We don't reset the delay any earlier as otherwise if there is a
+        # problem during start up we'll end up tight looping connecting to the
+        # server.
+        if self.factory:
+            self.factory.resetDelay()
+
+    def send_command(self, cmd: Command):
+        """Send a command to master if we currently have a connection;
+        otherwise the command is dropped with a warning.
+ """ + if self.connection: + self.connection.send_command(cmd) + else: + logger.warning("Dropping command as not connected: %r", cmd.NAME) + + def send_federation_ack(self, token: int): + """Ack data for the federation stream. This allows the master to drop + data stored purely in memory. + """ + self.send_command(FederationAckCommand(token)) + + def send_user_sync( + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + ): + """Poke the master that a user has started/stopped syncing. + """ + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) + + def send_remove_pusher(self, app_id: str, push_key: str, user_id: str): + """Poke the master to remove a pusher for a user + """ + cmd = RemovePusherCommand(app_id, push_key, user_id) + self.send_command(cmd) + + def send_invalidate_cache(self, cache_func: Callable, keys: tuple): + """Poke the master to invalidate a cache. + """ + cmd = InvalidateCacheCommand(cache_func.__name__, keys) + self.send_command(cmd) + + def send_user_ip( + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: str, + last_seen: int, + ): + """Tell the master that the user made a request. 
+ """ + cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) + self.send_command(cmd) + + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) diff --git a/synapse/server.py b/synapse/server.py index b828be913c..d54e4cedf8 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -87,10 +87,8 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool -from synapse.replication.tcp.client import ( - ReplicationClientHandler, - ReplicationDataHandler, -) +from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.rest.media.v1.media_repository import ( MediaRepository, @@ -473,7 +471,7 @@ class HomeServer(object): return ReadMarkerHandler(self) def build_tcp_replication(self): - return ReplicationClientHandler(self) + return ReplicationCommandHandler(self) def build_action_generator(self): return ActionGenerator(self) diff --git a/synapse/server.pyi b/synapse/server.pyi index 5e6298ab13..9013e9bac9 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -19,6 +19,7 @@ import synapse.handlers.set_password import synapse.http.client import synapse.notifier import synapse.replication.tcp.client +import synapse.replication.tcp.handler import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -106,7 +107,7 @@ class HomeServer(object): pass def get_tcp_replication( self, - ) -> synapse.replication.tcp.client.ReplicationClientHandler: + ) -> synapse.replication.tcp.handler.ReplicationCommandHandler: pass def get_replication_data_handler( self,