Merge server command processing into ReplicationCommandHandler
							parent
							
								
									e6c25e0858
								
							
						
					
					
						commit
						cadb3f57dd
					
				|  | @ -19,8 +19,10 @@ from typing import Any, Callable, Dict, List, Optional, Set | |||
| 
 | ||||
| from prometheus_client import Counter | ||||
| 
 | ||||
| from synapse.metrics import LaterGauge | ||||
| from synapse.replication.tcp.client import ReplicationClientFactory | ||||
| from synapse.replication.tcp.commands import ( | ||||
|     ClearUserSyncsCommand, | ||||
|     Command, | ||||
|     FederationAckCommand, | ||||
|     InvalidateCacheCommand, | ||||
|  | @ -28,6 +30,7 @@ from synapse.replication.tcp.commands import ( | |||
|     RdataCommand, | ||||
|     RemoteServerUpCommand, | ||||
|     RemovePusherCommand, | ||||
|     ReplicateCommand, | ||||
|     SyncCommand, | ||||
|     UserIpCommand, | ||||
|     UserSyncCommand, | ||||
|  | @ -42,6 +45,13 @@ logger = logging.getLogger(__name__) | |||
| inbound_rdata_count = Counter( | ||||
|     "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] | ||||
| ) | ||||
| user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") | ||||
| federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") | ||||
| remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") | ||||
| invalidate_cache_counter = Counter( | ||||
|     "synapse_replication_tcp_resource_invalidate_cache", "" | ||||
| ) | ||||
| user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") | ||||
| 
 | ||||
| 
 | ||||
| class ReplicationCommandHandler: | ||||
|  | @ -52,6 +62,8 @@ class ReplicationCommandHandler: | |||
|     def __init__(self, hs): | ||||
|         self._replication_data_handler = hs.get_replication_data_handler() | ||||
|         self._presence_handler = hs.get_presence_handler() | ||||
|         self._store = hs.get_datastore() | ||||
|         self._notifier = hs.get_notifier() | ||||
| 
 | ||||
|         # Set of streams that we've caught up with. | ||||
|         self._streams_connected = set()  # type: Set[str] | ||||
|  | @ -70,7 +82,25 @@ class ReplicationCommandHandler: | |||
|         self._factory = None  # type: Optional[ReplicationClientFactory] | ||||
| 
 | ||||
|         # The currently connected connections. | ||||
|         self._connections = [] | ||||
|         self._connections = []  # type: List[Any] | ||||
| 
 | ||||
|         LaterGauge( | ||||
|             "synapse_replication_tcp_resource_total_connections", | ||||
|             "", | ||||
|             [], | ||||
|             lambda: len(self._connections), | ||||
|         ) | ||||
| 
 | ||||
|         self._is_master = hs.config.worker_app is None | ||||
| 
 | ||||
|         self._federation_sender = None | ||||
|         if self._is_master and not hs.config.send_federation: | ||||
|             self._federation_sender = hs.get_federation_sender() | ||||
| 
 | ||||
|         self._server_notices_sender = None | ||||
|         if self._is_master: | ||||
|             self._server_notices_sender = hs.get_server_notices_sender() | ||||
|             self._notifier.add_remote_server_up_callback(self.send_remote_server_up) | ||||
| 
 | ||||
|     def start_replication(self, hs): | ||||
|         """Helper method to start a replication connection to the remote server | ||||
|  | @ -82,6 +112,73 @@ class ReplicationCommandHandler: | |||
|         port = hs.config.worker_replication_port | ||||
|         hs.get_reactor().connectTCP(host, port, self._factory) | ||||
| 
 | ||||
|     async def on_REPLICATE(self, cmd: ReplicateCommand): | ||||
|         # We only want to announce positions by the writer of the streams. | ||||
|         # Currently this is just the master process. | ||||
|         if not self._is_master: | ||||
|             return | ||||
| 
 | ||||
|         if not self._connections: | ||||
|             raise Exception("Not connected") | ||||
| 
 | ||||
|         for stream_name, stream in self._streams.items(): | ||||
|             current_token = stream.current_token() | ||||
|             self.send_command(PositionCommand(stream_name, current_token)) | ||||
| 
 | ||||
|     async def on_USER_SYNC(self, cmd: UserSyncCommand): | ||||
|         user_sync_counter.inc() | ||||
| 
 | ||||
|         if self._is_master: | ||||
|             await self._presence_handler.update_external_syncs_row( | ||||
|                 cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms | ||||
|             ) | ||||
| 
 | ||||
|     async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand): | ||||
|         if self._is_master: | ||||
|             await self._presence_handler.update_external_syncs_clear(cmd.instance_id) | ||||
| 
 | ||||
|     async def on_FEDERATION_ACK(self, cmd: FederationAckCommand): | ||||
|         federation_ack_counter.inc() | ||||
| 
 | ||||
|         if self._federation_sender: | ||||
|             self._federation_sender.federation_ack(cmd.token) | ||||
| 
 | ||||
|     async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand): | ||||
|         remove_pusher_counter.inc() | ||||
| 
 | ||||
|         if self._is_master: | ||||
|             await self._store.delete_pusher_by_app_id_pushkey_user_id( | ||||
|                 app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id | ||||
|             ) | ||||
| 
 | ||||
|             self._notifier.on_new_replication_data() | ||||
| 
 | ||||
|     async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand): | ||||
|         invalidate_cache_counter.inc() | ||||
| 
 | ||||
|         if self._is_master: | ||||
|             # We invalidate the cache locally, but then also stream that to other | ||||
|             # workers. | ||||
|             await self._store.invalidate_cache_and_stream( | ||||
|                 cmd.cache_func, tuple(cmd.keys) | ||||
|             ) | ||||
| 
 | ||||
|     async def on_USER_IP(self, cmd: UserIpCommand): | ||||
|         user_ip_cache_counter.inc() | ||||
| 
 | ||||
|         if self._is_master: | ||||
|             await self._store.insert_client_ip( | ||||
|                 cmd.user_id, | ||||
|                 cmd.access_token, | ||||
|                 cmd.ip, | ||||
|                 cmd.user_agent, | ||||
|                 cmd.device_id, | ||||
|                 cmd.last_seen, | ||||
|             ) | ||||
| 
 | ||||
|         if self._server_notices_sender: | ||||
|             await self._server_notices_sender.on_user_ip(cmd.user_id) | ||||
| 
 | ||||
|     async def on_RDATA(self, cmd: RdataCommand): | ||||
|         stream_name = cmd.stream_name | ||||
|         inbound_rdata_count.labels(stream_name).inc() | ||||
|  | @ -174,6 +271,9 @@ class ReplicationCommandHandler: | |||
|         """"Called when get a new REMOTE_SERVER_UP command.""" | ||||
|         self._replication_data_handler.on_remote_server_up(cmd.data) | ||||
| 
 | ||||
|         if self._is_master: | ||||
|             self._notifier.notify_remote_server_up(cmd.data) | ||||
| 
 | ||||
|     def get_currently_syncing_users(self): | ||||
|         """Get the list of currently syncing users (if any). This is called | ||||
|         when a connection has been established and we need to send the | ||||
|  | @ -261,3 +361,10 @@ class ReplicationCommandHandler: | |||
| 
 | ||||
|     def send_remote_server_up(self, server: str): | ||||
|         self.send_command(RemoteServerUpCommand(server)) | ||||
| 
 | ||||
|     def stream_update(self, stream_name: str, token: str, data: Any): | ||||
|         """Called when a new update is available to stream to clients. | ||||
| 
 | ||||
|         We need to check if the client is interested in the stream or not | ||||
|         """ | ||||
|         self.send_command(RdataCommand(stream_name, token, data)) | ||||
|  |  | |||
|  | @ -69,12 +69,8 @@ from synapse.replication.tcp.commands import ( | |||
|     ErrorCommand, | ||||
|     NameCommand, | ||||
|     PingCommand, | ||||
|     PositionCommand, | ||||
|     RdataCommand, | ||||
|     RemoteServerUpCommand, | ||||
|     ReplicateCommand, | ||||
|     ServerCommand, | ||||
|     SyncCommand, | ||||
|     UserSyncCommand, | ||||
| ) | ||||
| from synapse.types import Collection | ||||
|  | @ -134,8 +130,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
| 
 | ||||
|     max_line_buffer = 10000 | ||||
| 
 | ||||
|     def __init__(self, clock): | ||||
|     def __init__(self, clock, handler): | ||||
|         self.clock = clock | ||||
|         self.handler = handler | ||||
| 
 | ||||
|         self.last_received_command = self.clock.time_msec() | ||||
|         self.last_sent_command = 0 | ||||
|  | @ -175,6 +172,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
|         # can time us out. | ||||
|         self.send_command(PingCommand(self.clock.time_msec())) | ||||
| 
 | ||||
|         self.handler.new_connection(self) | ||||
| 
 | ||||
|     def send_ping(self): | ||||
|         """Periodically sends a ping and checks if we should close the connection | ||||
|         due to the other side timing out. | ||||
|  | @ -248,8 +247,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
|         Args: | ||||
|             cmd: received command | ||||
|         """ | ||||
|         handler = getattr(self, "on_%s" % (cmd.NAME,)) | ||||
|         await handler(cmd) | ||||
|         handled = False | ||||
| 
 | ||||
|         # First call any command handlers on this instance. These are for TCP | ||||
|         # specific handling. | ||||
|         cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) | ||||
|         if cmd_func: | ||||
|             await cmd_func(cmd) | ||||
|             handled = True | ||||
| 
 | ||||
|         # Then call out to the handler. | ||||
|         cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) | ||||
|         if cmd_func: | ||||
|             await cmd_func(cmd) | ||||
|             handled = True | ||||
| 
 | ||||
|         if not handled: | ||||
|             logger.warning("Unhandled command: %r", cmd) | ||||
| 
 | ||||
|     def close(self): | ||||
|         logger.warning("[%s] Closing connection", self.id()) | ||||
|  | @ -378,6 +392,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
|         self.state = ConnectionStates.CLOSED | ||||
|         self.pending_commands = [] | ||||
| 
 | ||||
|         self.handler.lost_connection(self) | ||||
| 
 | ||||
|         if self.transport: | ||||
|             self.transport.unregisterProducer() | ||||
| 
 | ||||
|  | @ -404,74 +420,19 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|     VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS | ||||
|     VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS | ||||
| 
 | ||||
|     def __init__(self, server_name, clock, streamer): | ||||
|         BaseReplicationStreamProtocol.__init__(self, clock)  # Old style class | ||||
|     def __init__(self, server_name, clock, handler): | ||||
|         BaseReplicationStreamProtocol.__init__(self, clock, handler)  # Old style class | ||||
| 
 | ||||
|         self.server_name = server_name | ||||
|         self.streamer = streamer | ||||
| 
 | ||||
|     def connectionMade(self): | ||||
|         self.send_command(ServerCommand(self.server_name)) | ||||
|         BaseReplicationStreamProtocol.connectionMade(self) | ||||
|         self.streamer.new_connection(self) | ||||
| 
 | ||||
|     async def on_NAME(self, cmd): | ||||
|         logger.info("[%s] Renamed to %r", self.id(), cmd.data) | ||||
|         self.name = cmd.data | ||||
| 
 | ||||
|     async def on_USER_SYNC(self, cmd): | ||||
|         await self.streamer.on_user_sync( | ||||
|             cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms | ||||
|         ) | ||||
| 
 | ||||
|     async def on_CLEAR_USER_SYNC(self, cmd): | ||||
|         await self.streamer.on_clear_user_syncs(cmd.instance_id) | ||||
| 
 | ||||
|     async def on_REPLICATE(self, cmd): | ||||
|         # Subscribe to all streams we're publishing to. | ||||
|         for stream_name in self.streamer.streams_by_name: | ||||
|             current_token = self.streamer.get_stream_token(stream_name) | ||||
|             self.send_command(PositionCommand(stream_name, current_token)) | ||||
| 
 | ||||
|     async def on_FEDERATION_ACK(self, cmd): | ||||
|         self.streamer.federation_ack(cmd.token) | ||||
| 
 | ||||
|     async def on_REMOVE_PUSHER(self, cmd): | ||||
|         await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) | ||||
| 
 | ||||
|     async def on_INVALIDATE_CACHE(self, cmd): | ||||
|         await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) | ||||
| 
 | ||||
|     async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): | ||||
|         self.streamer.on_remote_server_up(cmd.data) | ||||
| 
 | ||||
|     async def on_USER_IP(self, cmd): | ||||
|         self.streamer.on_user_ip( | ||||
|             cmd.user_id, | ||||
|             cmd.access_token, | ||||
|             cmd.ip, | ||||
|             cmd.user_agent, | ||||
|             cmd.device_id, | ||||
|             cmd.last_seen, | ||||
|         ) | ||||
| 
 | ||||
|     def stream_update(self, stream_name, token, data): | ||||
|         """Called when a new update is available to stream to clients. | ||||
| 
 | ||||
|         We need to check if the client is interested in the stream or not | ||||
|         """ | ||||
|         self.send_command(RdataCommand(stream_name, token, data)) | ||||
| 
 | ||||
|     def send_sync(self, data): | ||||
|         self.send_command(SyncCommand(data)) | ||||
| 
 | ||||
|     def send_remote_server_up(self, server: str): | ||||
|         self.send_command(RemoteServerUpCommand(server)) | ||||
| 
 | ||||
|     def on_connection_closed(self): | ||||
|         BaseReplicationStreamProtocol.on_connection_closed(self) | ||||
|         self.streamer.lost_connection(self) | ||||
| 
 | ||||
| 
 | ||||
| class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | ||||
|     VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS | ||||
|  | @ -485,13 +446,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         clock: Clock, | ||||
|         command_handler: "ReplicationCommandHandler", | ||||
|     ): | ||||
|         BaseReplicationStreamProtocol.__init__(self, clock) | ||||
|         BaseReplicationStreamProtocol.__init__(self, clock, command_handler) | ||||
| 
 | ||||
|         self.instance_id = hs.get_instance_id() | ||||
| 
 | ||||
|         self.client_name = client_name | ||||
|         self.server_name = server_name | ||||
|         self.handler = command_handler | ||||
| 
 | ||||
|     def connectionMade(self): | ||||
|         self.send_command(NameCommand(self.client_name)) | ||||
|  | @ -507,36 +467,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         for user_id in currently_syncing: | ||||
|             self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) | ||||
| 
 | ||||
|         # We've now finished connecting to so inform the client handler | ||||
|         self.handler.new_connection(self) | ||||
| 
 | ||||
|     async def handle_command(self, cmd: Command): | ||||
|         """Handle a command we have received over the replication stream. | ||||
| 
 | ||||
|         Delegates to `command_handler.on_<COMMAND>`, which must return an | ||||
|         awaitable. | ||||
| 
 | ||||
|         Args: | ||||
|             cmd: received command | ||||
|         """ | ||||
|         handled = False | ||||
| 
 | ||||
|         # First call any command handlers on this instance. These are for TCP | ||||
|         # specific handling. | ||||
|         cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) | ||||
|         if cmd_func: | ||||
|             await cmd_func(cmd) | ||||
|             handled = True | ||||
| 
 | ||||
|         # Then call out to the handler. | ||||
|         cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) | ||||
|         if cmd_func: | ||||
|             await cmd_func(cmd) | ||||
|             handled = True | ||||
| 
 | ||||
|         if not handled: | ||||
|             logger.warning("Unhandled command: %r", cmd) | ||||
| 
 | ||||
|     async def on_SERVER(self, cmd): | ||||
|         if cmd.data != self.server_name: | ||||
|             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) | ||||
|  | @ -549,10 +479,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
| 
 | ||||
|         self.send_command(ReplicateCommand()) | ||||
| 
 | ||||
|     def on_connection_closed(self): | ||||
|         BaseReplicationStreamProtocol.on_connection_closed(self) | ||||
|         self.handler.lost_connection(self) | ||||
| 
 | ||||
| 
 | ||||
| # The following simply registers metrics for the replication connections | ||||
| 
 | ||||
|  |  | |||
|  | @ -17,7 +17,7 @@ | |||
| 
 | ||||
| import logging | ||||
| import random | ||||
| from typing import Any, Dict, List | ||||
| from typing import Dict | ||||
| 
 | ||||
| from six import itervalues | ||||
| 
 | ||||
|  | @ -25,24 +25,14 @@ from prometheus_client import Counter | |||
| 
 | ||||
| from twisted.internet.protocol import Factory | ||||
| 
 | ||||
| from synapse.metrics import LaterGauge | ||||
| from synapse.metrics.background_process_metrics import run_as_background_process | ||||
| from synapse.util.metrics import Measure, measure_func | ||||
| 
 | ||||
| from .protocol import ServerReplicationStreamProtocol | ||||
| from .streams import STREAMS_MAP, Stream | ||||
| from .streams.federation import FederationStream | ||||
| from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol | ||||
| from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream | ||||
| from synapse.util.metrics import Measure | ||||
| 
 | ||||
| stream_updates_counter = Counter( | ||||
|     "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] | ||||
| ) | ||||
| user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") | ||||
| federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") | ||||
| remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") | ||||
| invalidate_cache_counter = Counter( | ||||
|     "synapse_replication_tcp_resource_invalidate_cache", "" | ||||
| ) | ||||
| user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|  | @ -52,13 +42,18 @@ class ReplicationStreamProtocolFactory(Factory): | |||
|     """ | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         self.streamer = hs.get_replication_streamer() | ||||
|         self.handler = hs.get_tcp_replication() | ||||
|         self.clock = hs.get_clock() | ||||
|         self.server_name = hs.config.server_name | ||||
|         self.hs = hs | ||||
| 
 | ||||
|         # Ensure the replication streamer is started if we register a | ||||
|         # replication server endpoint. | ||||
|         hs.get_replication_streamer() | ||||
| 
 | ||||
|     def buildProtocol(self, addr): | ||||
|         return ServerReplicationStreamProtocol( | ||||
|             self.server_name, self.clock, self.streamer | ||||
|             self.server_name, self.clock, self.handler | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -78,16 +73,6 @@ class ReplicationStreamer(object): | |||
| 
 | ||||
|         self._replication_torture_level = hs.config.replication_torture_level | ||||
| 
 | ||||
|         # Current connections. | ||||
|         self.connections = []  # type: List[ServerReplicationStreamProtocol] | ||||
| 
 | ||||
|         LaterGauge( | ||||
|             "synapse_replication_tcp_resource_total_connections", | ||||
|             "", | ||||
|             [], | ||||
|             lambda: len(self.connections), | ||||
|         ) | ||||
| 
 | ||||
|         # List of streams that clients can subscribe to. | ||||
|         # We only support federation stream if federation sending hase been | ||||
|         # disabled on the master. | ||||
|  | @ -104,18 +89,12 @@ class ReplicationStreamer(object): | |||
|             self.federation_sender = hs.get_federation_sender() | ||||
| 
 | ||||
|         self.notifier.add_replication_callback(self.on_notifier_poke) | ||||
|         self.notifier.add_remote_server_up_callback(self.send_remote_server_up) | ||||
| 
 | ||||
|         # Keeps track of whether we are currently checking for updates | ||||
|         self.is_looping = False | ||||
|         self.pending_updates = False | ||||
| 
 | ||||
|         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) | ||||
| 
 | ||||
|     def on_shutdown(self): | ||||
|         # close all connections on shutdown | ||||
|         for conn in self.connections: | ||||
|             conn.send_error("server shutting down") | ||||
|         self.client = hs.get_tcp_replication() | ||||
| 
 | ||||
|     def get_streams(self) -> Dict[str, Stream]: | ||||
|         """Get a mapp from stream name to stream instance. | ||||
|  | @ -129,7 +108,7 @@ class ReplicationStreamer(object): | |||
|         This should get called each time new data is available, even if it | ||||
|         is currently being executed, so that nothing gets missed | ||||
|         """ | ||||
|         if not self.connections: | ||||
|         if not self.client.connected(): | ||||
|             # Don't bother if nothing is listening. We still need to advance | ||||
|             # the stream tokens otherwise they'll fall beihind forever | ||||
|             for stream in self.streams: | ||||
|  | @ -186,9 +165,7 @@ class ReplicationStreamer(object): | |||
|                             raise | ||||
| 
 | ||||
|                         logger.debug( | ||||
|                             "Sending %d updates to %d connections", | ||||
|                             len(updates), | ||||
|                             len(self.connections), | ||||
|                             "Sending %d updates", len(updates), | ||||
|                         ) | ||||
| 
 | ||||
|                         if updates: | ||||
|  | @ -204,112 +181,17 @@ class ReplicationStreamer(object): | |||
|                         # token. See RdataCommand for more details. | ||||
|                         batched_updates = _batch_updates(updates) | ||||
| 
 | ||||
|                         for conn in self.connections: | ||||
|                             for token, row in batched_updates: | ||||
|                                 try: | ||||
|                                     conn.stream_update(stream.NAME, token, row) | ||||
|                                 except Exception: | ||||
|                                     logger.exception("Failed to replicate") | ||||
|                         for token, row in batched_updates: | ||||
|                             try: | ||||
|                                 self.client.stream_update(stream.NAME, token, row) | ||||
|                             except Exception: | ||||
|                                 logger.exception("Failed to replicate") | ||||
| 
 | ||||
|             logger.debug("No more pending updates, breaking poke loop") | ||||
|         finally: | ||||
|             self.pending_updates = False | ||||
|             self.is_looping = False | ||||
| 
 | ||||
|     def get_stream_token(self, stream_name): | ||||
|         """For a given stream get all updates since token. This is called when | ||||
|         a client first subscribes to a stream. | ||||
|         """ | ||||
|         stream = self.streams_by_name.get(stream_name, None) | ||||
|         if not stream: | ||||
|             raise Exception("unknown stream %s", stream_name) | ||||
| 
 | ||||
|         return stream.current_token() | ||||
| 
 | ||||
|     @measure_func("repl.federation_ack") | ||||
|     def federation_ack(self, token): | ||||
|         """We've received an ack for federation stream from a client. | ||||
|         """ | ||||
|         federation_ack_counter.inc() | ||||
|         if self.federation_sender: | ||||
|             self.federation_sender.federation_ack(token) | ||||
| 
 | ||||
|     @measure_func("repl.on_user_sync") | ||||
|     async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): | ||||
|         """A client has started/stopped syncing on a worker. | ||||
|         """ | ||||
|         user_sync_counter.inc() | ||||
|         await self.presence_handler.update_external_syncs_row( | ||||
|             instance_id, user_id, is_syncing, last_sync_ms | ||||
|         ) | ||||
| 
 | ||||
|     async def on_clear_user_syncs(self, instance_id): | ||||
|         """A replication client wants us to drop all their UserSync data. | ||||
|         """ | ||||
|         await self.presence_handler.update_external_syncs_clear(instance_id) | ||||
| 
 | ||||
|     @measure_func("repl.on_remove_pusher") | ||||
|     async def on_remove_pusher(self, app_id, push_key, user_id): | ||||
|         """A client has asked us to remove a pusher | ||||
|         """ | ||||
|         remove_pusher_counter.inc() | ||||
|         await self.store.delete_pusher_by_app_id_pushkey_user_id( | ||||
|             app_id=app_id, pushkey=push_key, user_id=user_id | ||||
|         ) | ||||
| 
 | ||||
|         self.notifier.on_new_replication_data() | ||||
| 
 | ||||
|     @measure_func("repl.on_invalidate_cache") | ||||
|     async def on_invalidate_cache(self, cache_func: str, keys: List[Any]): | ||||
|         """The client has asked us to invalidate a cache | ||||
|         """ | ||||
|         invalidate_cache_counter.inc() | ||||
| 
 | ||||
|         # We invalidate the cache locally, but then also stream that to other | ||||
|         # workers. | ||||
|         await self.store.invalidate_cache_and_stream(cache_func, tuple(keys)) | ||||
| 
 | ||||
|     @measure_func("repl.on_user_ip") | ||||
|     async def on_user_ip( | ||||
|         self, user_id, access_token, ip, user_agent, device_id, last_seen | ||||
|     ): | ||||
|         """The client saw a user request | ||||
|         """ | ||||
|         user_ip_cache_counter.inc() | ||||
|         await self.store.insert_client_ip( | ||||
|             user_id, access_token, ip, user_agent, device_id, last_seen | ||||
|         ) | ||||
|         await self._server_notices_sender.on_user_ip(user_id) | ||||
| 
 | ||||
|     @measure_func("repl.on_remote_server_up") | ||||
|     def on_remote_server_up(self, server: str): | ||||
|         self.notifier.notify_remote_server_up(server) | ||||
| 
 | ||||
|     def send_remote_server_up(self, server: str): | ||||
|         for conn in self.connections: | ||||
|             conn.send_remote_server_up(server) | ||||
| 
 | ||||
|     def send_sync_to_all_connections(self, data): | ||||
|         """Sends a SYNC command to all clients. | ||||
| 
 | ||||
|         Used in tests. | ||||
|         """ | ||||
|         for conn in self.connections: | ||||
|             conn.send_sync(data) | ||||
| 
 | ||||
|     def new_connection(self, connection): | ||||
|         """A new client connection has been established | ||||
|         """ | ||||
|         self.connections.append(connection) | ||||
| 
 | ||||
|     def lost_connection(self, connection): | ||||
|         """A client connection has been lost | ||||
|         """ | ||||
|         try: | ||||
|             self.connections.remove(connection) | ||||
|         except ValueError: | ||||
|             pass | ||||
| 
 | ||||
| 
 | ||||
| def _batch_updates(updates): | ||||
|     """Takes a list of updates of form [(token, row)] and sets the token to | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Erik Johnston
						Erik Johnston