Move catchup of replication streams to worker. (#7024)
This changes the replication protocol so that the server does not send down `RDATA` for rows that happened before the client connected. Instead, the server will send a `POSITION` and clients then query the database (or master out of band) to get up to date.
pull/7147/head
							parent
							
								
									7bab642707
								
							
						
					
					
						commit
						4cff617df1
					
				|  | @ -0,0 +1 @@ | |||
| Move catchup of replication streams logic to worker. | ||||
|  | @ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and | |||
| '<' worker to master flows): | ||||
| 
 | ||||
|     > SERVER example.com | ||||
|     < REPLICATE events 53 | ||||
|     < REPLICATE | ||||
|     > POSITION events 53 | ||||
|     > RDATA events 54 ["$foo1:bar.com", ...] | ||||
|     > RDATA events 55 ["$foo4:bar.com", ...] | ||||
| 
 | ||||
| The example shows the server accepting a new connection and sending its | ||||
| identity with the `SERVER` command, followed by the client asking to | ||||
| subscribe to the `events` stream from the token `53`. The server then | ||||
| periodically sends `RDATA` commands which have the format | ||||
| `RDATA <stream_name> <token> <row>`, where the format of `<row>` is | ||||
| defined by the individual streams. | ||||
| The example shows the server accepting a new connection and sending its identity | ||||
| with the `SERVER` command, followed by the client asking the server to respond with the | ||||
| position of all streams. The server then periodically sends `RDATA` commands | ||||
| which have the format `RDATA <stream_name> <token> <row>`, where the format of | ||||
| `<row>` is defined by the individual streams. | ||||
| 
 | ||||
| Error reporting happens by either the client or server sending an ERROR | ||||
| command, and usually the connection will be closed. | ||||
|  | @ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually | |||
| connect to the server using a tool like netcat. A few things should be | ||||
| noted when manually using the protocol: | ||||
| 
 | ||||
| -   When subscribing to a stream using `REPLICATE`, the special token | ||||
|     `NOW` can be used to get all future updates. The special stream name | ||||
|     `ALL` can be used with `NOW` to subscribe to all available streams. | ||||
| -   The federation stream is only available if federation sending has | ||||
|     been disabled on the main process. | ||||
| -   The server will only time connections out that have sent a `PING` | ||||
|  | @ -91,9 +88,7 @@ The client: | |||
| -   Sends a `NAME` command, allowing the server to associate a human | ||||
|     friendly name with the connection. This is optional. | ||||
| -   Sends a `PING` as above | ||||
| -   For each stream the client wishes to subscribe to it sends a | ||||
|     `REPLICATE` with the `stream_name` and token it wants to subscribe | ||||
|     from. | ||||
| -   Sends a `REPLICATE` to get the current position of all streams. | ||||
| -   On receipt of a `SERVER` command, checks that the server name | ||||
|     matches the expected server name. | ||||
| 
 | ||||
|  | @ -140,9 +135,7 @@ the wire: | |||
|     > PING 1490197665618 | ||||
|     < NAME synapse.app.appservice | ||||
|     < PING 1490197665618 | ||||
|     < REPLICATE events 1 | ||||
|     < REPLICATE backfill 1 | ||||
|     < REPLICATE caches 1 | ||||
|     < REPLICATE | ||||
|     > POSITION events 1 | ||||
|     > POSITION backfill 1 | ||||
|     > POSITION caches 1 | ||||
|  | @ -181,9 +174,9 @@ client (C): | |||
| 
 | ||||
| #### POSITION (S) | ||||
| 
 | ||||
|    The position of the stream has been updated. Sent to the client | ||||
|     after all missing updates for a stream have been sent to the client | ||||
|     and they're now up to date. | ||||
|    On receipt of a POSITION command clients should check if they have missed any | ||||
|    updates, and if so then fetch them out of band. Sent in response to a | ||||
|    REPLICATE command (but can happen at any time). | ||||
| 
 | ||||
| #### ERROR (S, C) | ||||
| 
 | ||||
|  | @ -199,20 +192,7 @@ client (C): | |||
| 
 | ||||
| #### REPLICATE (C) | ||||
| 
 | ||||
| Asks the server to replicate a given stream. The syntax is: | ||||
| 
 | ||||
| ``` | ||||
|     REPLICATE <stream_name> <token> | ||||
| ``` | ||||
| 
 | ||||
| Where `<token>` may be either: | ||||
|  * a numeric stream_id to stream updates since (exclusive) | ||||
|  * `NOW` to stream all subsequent updates. | ||||
| 
 | ||||
| The `<stream_name>` is the name of a replication stream to subscribe | ||||
| to (see [here](../synapse/replication/tcp/streams/_base.py) for a list | ||||
| of streams). It can also be `ALL` to subscribe to all known streams, | ||||
| in which case the `<token>` must be set to `NOW`. | ||||
| Asks the server for the current position of all streams. | ||||
| 
 | ||||
| #### USER_SYNC (C) | ||||
| 
 | ||||
|  |  | |||
|  | @ -401,6 +401,9 @@ class GenericWorkerTyping(object): | |||
|             self._room_serials[row.room_id] = token | ||||
|             self._room_typing[row.room_id] = row.user_ids | ||||
| 
 | ||||
|     def get_current_token(self) -> int: | ||||
|         return self._latest_room_serial | ||||
| 
 | ||||
| 
 | ||||
| class GenericWorkerSlavedStore( | ||||
|     # FIXME(#3714): We need to add UserDirectoryStore as we write directly | ||||
|  |  | |||
|  | @ -499,4 +499,13 @@ class FederationSender(object): | |||
|         self._get_per_destination_queue(destination).attempt_new_transaction() | ||||
| 
 | ||||
|     def get_current_token(self) -> int: | ||||
|         # Dummy implementation for case where federation sender isn't offloaded | ||||
|         # to a worker. | ||||
|         return 0 | ||||
| 
 | ||||
|     async def get_replication_rows( | ||||
|         self, from_token, to_token, limit, federation_ack=None | ||||
|     ): | ||||
|         # Dummy implementation for case where federation sender isn't offloaded | ||||
|         # to a worker. | ||||
|         return [] | ||||
|  |  | |||
|  | @ -21,6 +21,7 @@ from synapse.replication.http import ( | |||
|     membership, | ||||
|     register, | ||||
|     send_event, | ||||
|     streams, | ||||
| ) | ||||
| 
 | ||||
| REPLICATION_PREFIX = "/_synapse/replication" | ||||
|  | @ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource): | |||
|         login.register_servlets(hs, self) | ||||
|         register.register_servlets(hs, self) | ||||
|         devices.register_servlets(hs, self) | ||||
|         streams.register_servlets(hs, self) | ||||
|  |  | |||
|  | @ -0,0 +1,78 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2020 The Matrix.org Foundation C.I.C. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.http.servlet import parse_integer | ||||
| from synapse.replication.http._base import ReplicationEndpoint | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class ReplicationGetStreamUpdates(ReplicationEndpoint): | ||||
|     """Fetches stream updates from a server. Used for streams not persisted to | ||||
|     the database, e.g. typing notifications. | ||||
| 
 | ||||
|     The API looks like: | ||||
| 
 | ||||
|         GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&upto_token=10&limit=100 | ||||
| 
 | ||||
|         200 OK | ||||
| 
 | ||||
|         { | ||||
|             updates: [ ... ], | ||||
|             upto_token: 10, | ||||
|             limited: False, | ||||
|         } | ||||
| 
 | ||||
|     """ | ||||
| 
 | ||||
|     NAME = "get_repl_stream_updates" | ||||
|     PATH_ARGS = ("stream_name",) | ||||
|     METHOD = "GET" | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super().__init__(hs) | ||||
| 
 | ||||
|         # We pull the streams from the replication streamer (if we try and make | ||||
|         # them ourselves we end up in an import loop). | ||||
|         self.streams = hs.get_replication_streamer().get_streams() | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _serialize_payload(stream_name, from_token, upto_token, limit): | ||||
|         return {"from_token": from_token, "upto_token": upto_token, "limit": limit} | ||||
| 
 | ||||
|     async def _handle_request(self, request, stream_name): | ||||
|         stream = self.streams.get(stream_name) | ||||
|         if stream is None: | ||||
|             raise SynapseError(400, "Unknown stream") | ||||
| 
 | ||||
|         from_token = parse_integer(request, "from_token", required=True) | ||||
|         upto_token = parse_integer(request, "upto_token", required=True) | ||||
|         limit = parse_integer(request, "limit", required=True) | ||||
| 
 | ||||
|         updates, upto_token, limited = await stream.get_updates_since( | ||||
|             from_token, upto_token, limit | ||||
|         ) | ||||
| 
 | ||||
|         return ( | ||||
|             200, | ||||
|             {"updates": updates, "upto_token": upto_token, "limited": limited}, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def register_servlets(hs, http_server): | ||||
|     ReplicationGetStreamUpdates(hs).register(http_server) | ||||
|  | @ -18,8 +18,10 @@ from typing import Dict, Optional | |||
| 
 | ||||
| import six | ||||
| 
 | ||||
| from synapse.storage._base import SQLBaseStore | ||||
| from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME | ||||
| from synapse.storage.data_stores.main.cache import ( | ||||
|     CURRENT_STATE_CACHE_NAME, | ||||
|     CacheInvalidationWorkerStore, | ||||
| ) | ||||
| from synapse.storage.database import Database | ||||
| from synapse.storage.engines import PostgresEngine | ||||
| 
 | ||||
|  | @ -35,7 +37,7 @@ def __func__(inp): | |||
|         return inp.__func__ | ||||
| 
 | ||||
| 
 | ||||
| class BaseSlavedStore(SQLBaseStore): | ||||
| class BaseSlavedStore(CacheInvalidationWorkerStore): | ||||
|     def __init__(self, database: Database, db_conn, hs): | ||||
|         super(BaseSlavedStore, self).__init__(database, db_conn, hs) | ||||
|         if isinstance(self.database_engine, PostgresEngine): | ||||
|  | @ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore): | |||
|             pos["caches"] = self._cache_id_gen.get_current_token() | ||||
|         return pos | ||||
| 
 | ||||
|     def get_cache_stream_token(self): | ||||
|         if self._cache_id_gen: | ||||
|             return self._cache_id_gen.get_current_token() | ||||
|         else: | ||||
|             return 0 | ||||
| 
 | ||||
|     def process_replication_rows(self, stream_name, token, rows): | ||||
|         if stream_name == "caches": | ||||
|             if self._cache_id_gen: | ||||
|  |  | |||
|  | @ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): | |||
|         result["pushers"] = self._pushers_id_gen.get_current_token() | ||||
|         return result | ||||
| 
 | ||||
|     def get_pushers_stream_token(self): | ||||
|         return self._pushers_id_gen.get_current_token() | ||||
| 
 | ||||
|     def process_replication_rows(self, stream_name, token, rows): | ||||
|         if stream_name == "pushers": | ||||
|             self._pushers_id_gen.advance(token) | ||||
|  |  | |||
|  | @ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): | |||
|         self.client_name = client_name | ||||
|         self.handler = handler | ||||
|         self.server_name = hs.config.server_name | ||||
|         self.hs = hs | ||||
|         self._clock = hs.get_clock()  # As self.clock is defined in super class | ||||
| 
 | ||||
|         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) | ||||
|  | @ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): | |||
|     def buildProtocol(self, addr): | ||||
|         logger.info("Connected to replication: %r", addr) | ||||
|         return ClientReplicationStreamProtocol( | ||||
|             self.client_name, self.server_name, self._clock, self.handler | ||||
|             self.hs, self.client_name, self.server_name, self._clock, self.handler, | ||||
|         ) | ||||
| 
 | ||||
|     def clientConnectionLost(self, connector, reason): | ||||
|  |  | |||
|  | @ -136,8 +136,8 @@ class PositionCommand(Command): | |||
|     """Sent by the server to tell the client the stream position without | ||||
|     needing to send an RDATA. | ||||
| 
 | ||||
|     Sent to the client after all missing updates for a stream have been sent | ||||
|     to the client and they're now up to date. | ||||
|     On receipt of a POSITION command clients should check if they have missed | ||||
|     any updates, and if so then fetch them out of band. | ||||
|     """ | ||||
| 
 | ||||
|     NAME = "POSITION" | ||||
|  | @ -179,42 +179,24 @@ class NameCommand(Command): | |||
| 
 | ||||
| 
 | ||||
| class ReplicateCommand(Command): | ||||
|     """Sent by the client to subscribe to the stream. | ||||
|     """Sent by the client to subscribe to streams. | ||||
| 
 | ||||
|     Format:: | ||||
| 
 | ||||
|         REPLICATE <stream_name> <token> | ||||
| 
 | ||||
|     Where <token> may be either: | ||||
|         * a numeric stream_id to stream updates from | ||||
|         * "NOW" to stream all subsequent updates. | ||||
| 
 | ||||
|     The <stream_name> can be "ALL" to subscribe to all known streams, in which | ||||
|     case the <token> must be set to "NOW", i.e.:: | ||||
| 
 | ||||
|         REPLICATE ALL NOW | ||||
|         REPLICATE | ||||
|     """ | ||||
| 
 | ||||
|     NAME = "REPLICATE" | ||||
| 
 | ||||
|     def __init__(self, stream_name, token): | ||||
|         self.stream_name = stream_name | ||||
|         self.token = token | ||||
|     def __init__(self): | ||||
|         pass | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_line(cls, line): | ||||
|         stream_name, token = line.split(" ", 1) | ||||
|         if token in ("NOW", "now"): | ||||
|             token = "NOW" | ||||
|         else: | ||||
|             token = int(token) | ||||
|         return cls(stream_name, token) | ||||
|         return cls() | ||||
| 
 | ||||
|     def to_line(self): | ||||
|         return " ".join((self.stream_name, str(self.token))) | ||||
| 
 | ||||
|     def get_logcontext_id(self): | ||||
|         return "REPLICATE-" + self.stream_name | ||||
|         return "" | ||||
| 
 | ||||
| 
 | ||||
| class UserSyncCommand(Command): | ||||
|  |  | |||
|  | @ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire:: | |||
|     > PING 1490197665618 | ||||
|     < NAME synapse.app.appservice | ||||
|     < PING 1490197665618 | ||||
|     < REPLICATE events 1 | ||||
|     < REPLICATE backfill 1 | ||||
|     < REPLICATE caches 1 | ||||
|     < REPLICATE | ||||
|     > POSITION events 1 | ||||
|     > POSITION backfill 1 | ||||
|     > POSITION caches 1 | ||||
|  | @ -53,17 +51,15 @@ import fcntl | |||
| import logging | ||||
| import struct | ||||
| from collections import defaultdict | ||||
| from typing import Any, DefaultDict, Dict, List, Set, Tuple | ||||
| from typing import Any, DefaultDict, Dict, List, Set | ||||
| 
 | ||||
| from six import iteritems, iterkeys | ||||
| from six import iteritems | ||||
| 
 | ||||
| from prometheus_client import Counter | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| from twisted.protocols.basic import LineOnlyReceiver | ||||
| from twisted.python.failure import Failure | ||||
| 
 | ||||
| from synapse.logging.context import make_deferred_yieldable, run_in_background | ||||
| from synapse.metrics import LaterGauge | ||||
| from synapse.metrics.background_process_metrics import run_as_background_process | ||||
| from synapse.replication.tcp.commands import ( | ||||
|  | @ -82,11 +78,16 @@ from synapse.replication.tcp.commands import ( | |||
|     SyncCommand, | ||||
|     UserSyncCommand, | ||||
| ) | ||||
| from synapse.replication.tcp.streams import STREAMS_MAP | ||||
| from synapse.replication.tcp.streams import STREAMS_MAP, Stream | ||||
| from synapse.types import Collection | ||||
| from synapse.util import Clock | ||||
| from synapse.util.stringutils import random_string | ||||
| 
 | ||||
| MYPY = False | ||||
| if MYPY: | ||||
|     from synapse.server import HomeServer | ||||
| 
 | ||||
| 
 | ||||
| connection_close_counter = Counter( | ||||
|     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] | ||||
| ) | ||||
|  | @ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         self.server_name = server_name | ||||
|         self.streamer = streamer | ||||
| 
 | ||||
|         # The streams the client has subscribed to and is up to date with | ||||
|         self.replication_streams = set()  # type: Set[str] | ||||
| 
 | ||||
|         # The streams the client is currently subscribing to. | ||||
|         self.connecting_streams = set()  # type:  Set[str] | ||||
| 
 | ||||
|         # Map from stream name to list of updates to send once we've finished | ||||
|         # subscribing the client to the stream. | ||||
|         self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]] | ||||
| 
 | ||||
|     def connectionMade(self): | ||||
|         self.send_command(ServerCommand(self.server_name)) | ||||
|         BaseReplicationStreamProtocol.connectionMade(self) | ||||
|  | @ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         ) | ||||
| 
 | ||||
|     async def on_REPLICATE(self, cmd): | ||||
|         stream_name = cmd.stream_name | ||||
|         token = cmd.token | ||||
| 
 | ||||
|         if stream_name == "ALL": | ||||
|             # Subscribe to all streams we're publishing to. | ||||
|             deferreds = [ | ||||
|                 run_in_background(self.subscribe_to_stream, stream, token) | ||||
|                 for stream in iterkeys(self.streamer.streams_by_name) | ||||
|             ] | ||||
| 
 | ||||
|             await make_deferred_yieldable( | ||||
|                 defer.gatherResults(deferreds, consumeErrors=True) | ||||
|             ) | ||||
|         else: | ||||
|             await self.subscribe_to_stream(stream_name, token) | ||||
|         # Subscribe to all streams we're publishing to. | ||||
|         for stream_name in self.streamer.streams_by_name: | ||||
|             current_token = self.streamer.get_stream_token(stream_name) | ||||
|             self.send_command(PositionCommand(stream_name, current_token)) | ||||
| 
 | ||||
|     async def on_FEDERATION_ACK(self, cmd): | ||||
|         self.streamer.federation_ack(cmd.token) | ||||
|  | @ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|             cmd.last_seen, | ||||
|         ) | ||||
| 
 | ||||
|     async def subscribe_to_stream(self, stream_name, token): | ||||
|         """Subscribe the remote to a stream. | ||||
| 
 | ||||
|         This involves checking if they've missed anything and sending those | ||||
|         updates down if they have. During that time new updates for the stream | ||||
|         are queued and sent once we've sent down any missed updates. | ||||
|         """ | ||||
|         self.replication_streams.discard(stream_name) | ||||
|         self.connecting_streams.add(stream_name) | ||||
| 
 | ||||
|         try: | ||||
|             # Get missing updates | ||||
|             updates, current_token = await self.streamer.get_stream_updates( | ||||
|                 stream_name, token | ||||
|             ) | ||||
| 
 | ||||
|             # Send all the missing updates | ||||
|             for update in updates: | ||||
|                 token, row = update[0], update[1] | ||||
|                 self.send_command(RdataCommand(stream_name, token, row)) | ||||
| 
 | ||||
|             # We send a POSITION command to ensure that they have an up to | ||||
|             # date token (especially useful if we didn't send any updates | ||||
|             # above) | ||||
|             self.send_command(PositionCommand(stream_name, current_token)) | ||||
| 
 | ||||
|             # Now we can send any updates that came in while we were subscribing | ||||
|             pending_rdata = self.pending_rdata.pop(stream_name, []) | ||||
|             updates = [] | ||||
|             for token, update in pending_rdata: | ||||
|                 # If the token is null, it is part of a batch update. Batches | ||||
|                 # are multiple updates that share a single token. To denote | ||||
|                 # this, the token is set to None for all tokens in the batch | ||||
|                 # except for the last. If we find a None token, we keep looking | ||||
|                 # through tokens until we find one that is not None and then | ||||
|                 # process all previous updates in the batch as if they had the | ||||
|                 # final token. | ||||
|                 if token is None: | ||||
|                     # Store this update as part of a batch | ||||
|                     updates.append(update) | ||||
|                     continue | ||||
| 
 | ||||
|                 if token <= current_token: | ||||
|                     # This update or batch of updates is older than | ||||
|                     # current_token, dismiss it | ||||
|                     updates = [] | ||||
|                     continue | ||||
| 
 | ||||
|                 updates.append(update) | ||||
| 
 | ||||
|                 # Send all updates that are part of this batch with the | ||||
|                 # found token | ||||
|                 for update in updates: | ||||
|                     self.send_command(RdataCommand(stream_name, token, update)) | ||||
| 
 | ||||
|                 # Clear stored updates | ||||
|                 updates = [] | ||||
| 
 | ||||
|             # They're now fully subscribed | ||||
|             self.replication_streams.add(stream_name) | ||||
|         except Exception as e: | ||||
|             logger.exception("[%s] Failed to handle REPLICATE command", self.id()) | ||||
|             self.send_error("failed to handle replicate: %r", e) | ||||
|         finally: | ||||
|             self.connecting_streams.discard(stream_name) | ||||
| 
 | ||||
|     def stream_update(self, stream_name, token, data): | ||||
|         """Called when a new update is available to stream to clients. | ||||
| 
 | ||||
|         We need to check if the client is interested in the stream or not | ||||
|         """ | ||||
|         if stream_name in self.replication_streams: | ||||
|             # The client is subscribed to the stream | ||||
|             self.send_command(RdataCommand(stream_name, token, data)) | ||||
|         elif stream_name in self.connecting_streams: | ||||
|             # The client is being subscribed to the stream | ||||
|             logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) | ||||
|             self.pending_rdata.setdefault(stream_name, []).append((token, data)) | ||||
|         else: | ||||
|             # The client isn't subscribed | ||||
|             logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) | ||||
|         self.send_command(RdataCommand(stream_name, token, data)) | ||||
| 
 | ||||
|     def send_sync(self, data): | ||||
|         self.send_command(SyncCommand(data)) | ||||
|  | @ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         hs: "HomeServer", | ||||
|         client_name: str, | ||||
|         server_name: str, | ||||
|         clock: Clock, | ||||
|  | @ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         self.server_name = server_name | ||||
|         self.handler = handler | ||||
| 
 | ||||
|         self.streams = { | ||||
|             stream.NAME: stream(hs) for stream in STREAMS_MAP.values() | ||||
|         }  # type: Dict[str, Stream] | ||||
| 
 | ||||
|         # Set of stream names that have been subscribed to, but haven't yet | ||||
|         # caught up with. This is used to track when the client has been fully | ||||
|         # connected to the remote. | ||||
|         self.streams_connecting = set()  # type: Set[str] | ||||
|         self.streams_connecting = set(STREAMS_MAP)  # type: Set[str] | ||||
| 
 | ||||
|         # Map of stream to batched updates. See RdataCommand for info on how | ||||
|         # batching works. | ||||
|         self.pending_batches = {}  # type: Dict[str, Any] | ||||
|         self.pending_batches = {}  # type: Dict[str, List[Any]] | ||||
| 
 | ||||
|     def connectionMade(self): | ||||
|         self.send_command(NameCommand(self.client_name)) | ||||
|         BaseReplicationStreamProtocol.connectionMade(self) | ||||
| 
 | ||||
|         # Once we've connected subscribe to the necessary streams | ||||
|         for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): | ||||
|             self.replicate(stream_name, token) | ||||
|         self.replicate() | ||||
| 
 | ||||
|         # Tell the server if we have any users currently syncing (should only | ||||
|         # happen on synchrotrons) | ||||
|  | @ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         # We've now finished connecting, so inform the client handler | ||||
|         self.handler.update_connection(self) | ||||
| 
 | ||||
|         # This will happen if we don't actually subscribe to any streams | ||||
|         if not self.streams_connecting: | ||||
|             self.handler.finished_connecting() | ||||
| 
 | ||||
|     async def on_SERVER(self, cmd): | ||||
|         if cmd.data != self.server_name: | ||||
|             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) | ||||
|  | @ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|             ) | ||||
|             raise | ||||
| 
 | ||||
|         if cmd.token is None: | ||||
|         if cmd.token is None or stream_name in self.streams_connecting: | ||||
|             # I.e. this is part of a batch of updates for this stream. Batch | ||||
|             # until we get an update for the stream with a non None token | ||||
|             self.pending_batches.setdefault(stream_name, []).append(row) | ||||
|  | @ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|             rows.append(row) | ||||
|             await self.handler.on_rdata(stream_name, cmd.token, rows) | ||||
| 
 | ||||
|     async def on_POSITION(self, cmd): | ||||
|         # When we get a `POSITION` command it means we've finished getting | ||||
|         # missing updates for the given stream, and are now up to date. | ||||
|     async def on_POSITION(self, cmd: PositionCommand): | ||||
|         stream = self.streams.get(cmd.stream_name) | ||||
|         if not stream: | ||||
|             logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) | ||||
|             return | ||||
| 
 | ||||
|         # Find where we previously streamed up to. | ||||
|         current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) | ||||
|         if current_token is None: | ||||
|             logger.warning( | ||||
|                 "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name | ||||
|             ) | ||||
|             return | ||||
| 
 | ||||
|         # Fetch all updates between then and now. | ||||
|         limited = True | ||||
|         while limited: | ||||
|             updates, current_token, limited = await stream.get_updates_since( | ||||
|                 current_token, cmd.token | ||||
|             ) | ||||
| 
 | ||||
|             # Check if the connection was closed underneath us, if so we bail | ||||
|             # rather than risk having concurrent catch ups going on. | ||||
|             if self.state == ConnectionStates.CLOSED: | ||||
|                 return | ||||
| 
 | ||||
|             if updates: | ||||
|                 await self.handler.on_rdata( | ||||
|                     cmd.stream_name, | ||||
|                     current_token, | ||||
|                     [stream.parse_row(update[1]) for update in updates], | ||||
|                 ) | ||||
| 
 | ||||
|         # We've now caught up to position sent to us, notify handler. | ||||
|         await self.handler.on_position(cmd.stream_name, cmd.token) | ||||
| 
 | ||||
|         self.streams_connecting.discard(cmd.stream_name) | ||||
|         if not self.streams_connecting: | ||||
|             self.handler.finished_connecting() | ||||
| 
 | ||||
|         await self.handler.on_position(cmd.stream_name, cmd.token) | ||||
|         # Check if the connection was closed underneath us, if so we bail | ||||
|         # rather than risk having concurrent catch ups going on. | ||||
|         if self.state == ConnectionStates.CLOSED: | ||||
|             return | ||||
| 
 | ||||
|         # Handle any RDATA that came in while we were catching up. | ||||
|         rows = self.pending_batches.pop(cmd.stream_name, []) | ||||
|         if rows: | ||||
|             await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) | ||||
| 
 | ||||
|     async def on_SYNC(self, cmd): | ||||
|         self.handler.on_sync(cmd.data) | ||||
|  | @ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|     async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): | ||||
|         self.handler.on_remote_server_up(cmd.data) | ||||
| 
 | ||||
|     def replicate(self, stream_name, token): | ||||
|     def replicate(self): | ||||
|         """Send the subscription request to the server | ||||
|         """ | ||||
|         if stream_name not in STREAMS_MAP: | ||||
|             raise Exception("Invalid stream name %r" % (stream_name,)) | ||||
|         logger.info("[%s] Subscribing to replication streams", self.id()) | ||||
| 
 | ||||
|         logger.info( | ||||
|             "[%s] Subscribing to replication stream: %r from %r", | ||||
|             self.id(), | ||||
|             stream_name, | ||||
|             token, | ||||
|         ) | ||||
| 
 | ||||
|         self.streams_connecting.add(stream_name) | ||||
| 
 | ||||
|         self.send_command(ReplicateCommand(stream_name, token)) | ||||
|         self.send_command(ReplicateCommand()) | ||||
| 
 | ||||
|     def on_connection_closed(self): | ||||
|         BaseReplicationStreamProtocol.on_connection_closed(self) | ||||
|  |  | |||
|  | @ -17,7 +17,7 @@ | |||
| 
 | ||||
| import logging | ||||
| import random | ||||
| from typing import Any, List | ||||
| from typing import Any, Dict, List | ||||
| 
 | ||||
| from six import itervalues | ||||
| 
 | ||||
|  | @ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process | |||
| from synapse.util.metrics import Measure, measure_func | ||||
| 
 | ||||
| from .protocol import ServerReplicationStreamProtocol | ||||
| from .streams import STREAMS_MAP | ||||
| from .streams import STREAMS_MAP, Stream | ||||
| from .streams.federation import FederationStream | ||||
| 
 | ||||
| stream_updates_counter = Counter( | ||||
|  | @ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory): | |||
|     """ | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         self.streamer = ReplicationStreamer(hs) | ||||
|         self.streamer = hs.get_replication_streamer() | ||||
|         self.clock = hs.get_clock() | ||||
|         self.server_name = hs.config.server_name | ||||
| 
 | ||||
|  | @ -133,6 +133,11 @@ class ReplicationStreamer(object): | |||
|         for conn in self.connections: | ||||
|             conn.send_error("server shutting down") | ||||
| 
 | ||||
|     def get_streams(self) -> Dict[str, Stream]: | ||||
|         """Get a map from stream name to stream instance. | ||||
|         """ | ||||
|         return self.streams_by_name | ||||
| 
 | ||||
|     def on_notifier_poke(self): | ||||
|         """Checks if there is actually any new data and sends it to the | ||||
|         connections if there are. | ||||
|  | @ -190,7 +195,8 @@ class ReplicationStreamer(object): | |||
|                             stream.current_token(), | ||||
|                         ) | ||||
|                         try: | ||||
|                             updates, current_token = await stream.get_updates() | ||||
|                             updates, current_token, limited = await stream.get_updates() | ||||
|                             self.pending_updates |= limited | ||||
|                         except Exception: | ||||
|                             logger.info("Failed to handle stream %s", stream.NAME) | ||||
|                             raise | ||||
|  | @ -226,8 +232,7 @@ class ReplicationStreamer(object): | |||
|             self.pending_updates = False | ||||
|             self.is_looping = False | ||||
| 
 | ||||
|     @measure_func("repl.get_stream_updates") | ||||
|     async def get_stream_updates(self, stream_name, token): | ||||
|     def get_stream_token(self, stream_name): | ||||
|         """For a given stream get the stream's current token. This is called | ||||
|         when a client first subscribes to a stream. | ||||
|         """ | ||||
|  | @ -235,7 +240,7 @@ class ReplicationStreamer(object): | |||
|         if not stream: | ||||
|             raise Exception("unknown stream %s", stream_name) | ||||
| 
 | ||||
|         return await stream.get_updates_since(token) | ||||
|         return stream.current_token() | ||||
| 
 | ||||
|     @measure_func("repl.federation_ack") | ||||
|     def federation_ack(self, token): | ||||
|  |  | |||
|  | @ -24,6 +24,9 @@ Each stream is defined by the following information: | |||
|     current_token:      The function that returns the current token for the stream | ||||
|     update_function:    The function that returns a list of updates between two tokens | ||||
| """ | ||||
| 
 | ||||
| from typing import Dict, Type | ||||
| 
 | ||||
| from synapse.replication.tcp.streams._base import ( | ||||
|     AccountDataStream, | ||||
|     BackfillStream, | ||||
|  | @ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import ( | |||
|     PushersStream, | ||||
|     PushRulesStream, | ||||
|     ReceiptsStream, | ||||
|     Stream, | ||||
|     TagAccountDataStream, | ||||
|     ToDeviceStream, | ||||
|     TypingStream, | ||||
|  | @ -63,10 +67,12 @@ STREAMS_MAP = { | |||
|         GroupServerStream, | ||||
|         UserSignatureStream, | ||||
|     ) | ||||
| } | ||||
| }  # type: Dict[str, Type[Stream]] | ||||
| 
 | ||||
| 
 | ||||
| __all__ = [ | ||||
|     "STREAMS_MAP", | ||||
|     "Stream", | ||||
|     "BackfillStream", | ||||
|     "PresenceStream", | ||||
|     "TypingStream", | ||||
|  |  | |||
|  | @ -14,13 +14,13 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import itertools | ||||
| import logging | ||||
| from collections import namedtuple | ||||
| from typing import Any, List, Optional, Tuple | ||||
| from typing import Any, Awaitable, Callable, List, Optional, Tuple | ||||
| 
 | ||||
| import attr | ||||
| 
 | ||||
| from synapse.replication.http.streams import ReplicationGetStreamUpdates | ||||
| from synapse.types import JsonDict | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
|  | @ -29,6 +29,15 @@ logger = logging.getLogger(__name__) | |||
| MAX_EVENTS_BEHIND = 500000 | ||||
| 
 | ||||
| 
 | ||||
| # Some type aliases to make things a bit easier. | ||||
| 
 | ||||
| # A stream position token | ||||
| Token = int | ||||
| 
 | ||||
| # A pair of position in stream and args used to create an instance of `ROW_TYPE`. | ||||
| StreamRow = Tuple[Token, tuple] | ||||
| 
 | ||||
| 
 | ||||
| class Stream(object): | ||||
|     """Base class for the streams. | ||||
| 
 | ||||
|  | @ -56,6 +65,7 @@ class Stream(object): | |||
|         return cls.ROW_TYPE(*row) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
| 
 | ||||
|         # The token from which we last asked for updates | ||||
|         self.last_token = self.current_token() | ||||
| 
 | ||||
|  | @ -65,61 +75,46 @@ class Stream(object): | |||
|         """ | ||||
|         self.last_token = self.current_token() | ||||
| 
 | ||||
|     async def get_updates(self): | ||||
|     async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: | ||||
|         """Gets all updates since the last time this function was called (or | ||||
|         since the stream was constructed if it hadn't been called before). | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[Tuple[List[Tuple[int, Any]], int]: | ||||
|                 Resolves to a pair ``(updates, current_token)``, where ``updates`` is a | ||||
|                 list of ``(token, row)`` entries. ``row`` will be json-serialised and | ||||
|                 sent over the replication steam. | ||||
|             A triplet `(updates, new_last_token, limited)`, where `updates` is | ||||
|             a list of `(token, row)` entries, `new_last_token` is the new | ||||
|             position in stream, and `limited` is whether there are more updates | ||||
|             to fetch. | ||||
|         """ | ||||
|         updates, current_token = await self.get_updates_since(self.last_token) | ||||
|         current_token = self.current_token() | ||||
|         updates, current_token, limited = await self.get_updates_since( | ||||
|             self.last_token, current_token | ||||
|         ) | ||||
|         self.last_token = current_token | ||||
| 
 | ||||
|         return updates, current_token | ||||
|         return updates, current_token, limited | ||||
| 
 | ||||
|     async def get_updates_since( | ||||
|         self, from_token: int | ||||
|     ) -> Tuple[List[Tuple[int, JsonDict]], int]: | ||||
|         self, from_token: Token, upto_token: Token, limit: int = 100 | ||||
|     ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: | ||||
|         """Like get_updates except allows specifying from when we should | ||||
|         stream updates | ||||
| 
 | ||||
|         Returns: | ||||
|             Resolves to a pair `(updates, new_last_token)`, where `updates` is | ||||
|             a list of `(token, row)` entries and `new_last_token` is the new | ||||
|             position in stream. | ||||
|             A triplet `(updates, new_last_token, limited)`, where `updates` is | ||||
|             a list of `(token, row)` entries, `new_last_token` is the new | ||||
|             position in stream, and `limited` is whether there are more updates | ||||
|             to fetch. | ||||
|         """ | ||||
| 
 | ||||
|         if from_token in ("NOW", "now"): | ||||
|             return [], self.current_token() | ||||
| 
 | ||||
|         current_token = self.current_token() | ||||
| 
 | ||||
|         from_token = int(from_token) | ||||
| 
 | ||||
|         if from_token == current_token: | ||||
|             return [], current_token | ||||
|         if from_token == upto_token: | ||||
|             return [], upto_token, False | ||||
| 
 | ||||
|         rows = await self.update_function( | ||||
|             from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 | ||||
|         updates, upto_token, limited = await self.update_function( | ||||
|             from_token, upto_token, limit=limit, | ||||
|         ) | ||||
| 
 | ||||
|         # never turn more than MAX_EVENTS_BEHIND + 1 into updates. | ||||
|         rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) | ||||
| 
 | ||||
|         updates = [(row[0], row[1:]) for row in rows] | ||||
| 
 | ||||
|         # check we didn't get more rows than the limit. | ||||
|         # doing it like this allows the update_function to be a generator. | ||||
|         if len(updates) >= MAX_EVENTS_BEHIND: | ||||
|             raise Exception("stream %s has fallen behind" % (self.NAME)) | ||||
| 
 | ||||
|         # The update function didn't hit the limit, so we must have got all | ||||
|         # the updates to `current_token`, and can return that as our new | ||||
|         # stream position. | ||||
|         return updates, current_token | ||||
|         return updates, upto_token, limited | ||||
| 
 | ||||
|     def current_token(self): | ||||
|         """Gets the current token of the underlying streams. Should be provided | ||||
|  | @ -141,6 +136,48 @@ class Stream(object): | |||
|         raise NotImplementedError() | ||||
| 
 | ||||
| 
 | ||||
def db_query_to_update_function(
    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
    """Wraps a db query function which returns a list of rows to make it
    suitable for use as an `update_function` for the Stream class.

    Args:
        query_function: Awaited as `query_function(from_token, upto_token,
            limit)`. Must return rows whose first element is the stream token.
            The token of the last row is used as the new position when the
            limit is hit, so rows are presumably ordered by ascending token.

    Returns:
        An async function `update_function(from_token, upto_token, limit)`
        returning a triplet `(updates, upto_token, limited)`, where `updates`
        is a list of `(token, rest_of_row)` pairs, `upto_token` is the position
        the updates reach, and `limited` is whether there may be more updates
        to fetch.
    """

    async def update_function(from_token, upto_token, limit):
        rows = await query_function(from_token, upto_token, limit)
        updates = [(row[0], row[1:]) for row in rows]

        # Use `>=` rather than `==`: a query function may hand back more rows
        # than `limit`, and that must still count as having hit the limit,
        # otherwise we would claim to be caught up to the requested
        # `upto_token`. The emptiness check avoids indexing an empty list when
        # `limit` is 0.
        limited = bool(updates) and len(updates) >= limit
        if limited:
            # We only know we are complete up to the last row we were given.
            upto_token = updates[-1][0]

        return updates, upto_token, limited

    return update_function
| 
 | ||||
| 
 | ||||
def make_http_update_function(
    hs, stream_name: str
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
    """Makes a suitable function for use as an `update_function` that queries
    the master process for updates.

    Args:
        hs: The homeserver, passed to `ReplicationGetStreamUpdates.make_client`.
        stream_name: Name of the stream to request updates for; sent with
            every request made by the returned function.

    Returns:
        An async function `update_function(from_token, upto_token, limit)`
        which awaits the replication client and returns its result unchanged.
    """

    # Build the client once so every call of the returned function reuses it.
    client = ReplicationGetStreamUpdates.make_client(hs)

    async def update_function(
        from_token: int, upto_token: int, limit: int
    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        return await client(
            stream_name=stream_name,
            from_token=from_token,
            upto_token=upto_token,
            limit=limit,
        )

    return update_function
| 
 | ||||
| 
 | ||||
| class BackfillStream(Stream): | ||||
|     """We fetched some old events and either we had never seen that event before | ||||
|     or it went from being an outlier to not. | ||||
|  | @ -164,7 +201,7 @@ class BackfillStream(Stream): | |||
|     def __init__(self, hs): | ||||
|         store = hs.get_datastore() | ||||
|         self.current_token = store.get_current_backfill_token  # type: ignore | ||||
|         self.update_function = store.get_all_new_backfill_event_rows  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows)  # type: ignore | ||||
| 
 | ||||
|         super(BackfillStream, self).__init__(hs) | ||||
| 
 | ||||
|  | @ -190,8 +227,15 @@ class PresenceStream(Stream): | |||
|         store = hs.get_datastore() | ||||
|         presence_handler = hs.get_presence_handler() | ||||
| 
 | ||||
|         self._is_worker = hs.config.worker_app is not None | ||||
| 
 | ||||
|         self.current_token = store.get_current_presence_token  # type: ignore | ||||
|         self.update_function = presence_handler.get_all_presence_updates  # type: ignore | ||||
| 
 | ||||
|         if hs.config.worker_app is None: | ||||
|             self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates)  # type: ignore | ||||
|         else: | ||||
|             # Query master process | ||||
|             self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore | ||||
| 
 | ||||
|         super(PresenceStream, self).__init__(hs) | ||||
| 
 | ||||
|  | @ -208,7 +252,12 @@ class TypingStream(Stream): | |||
|         typing_handler = hs.get_typing_handler() | ||||
| 
 | ||||
|         self.current_token = typing_handler.get_current_token  # type: ignore | ||||
|         self.update_function = typing_handler.get_all_typing_updates  # type: ignore | ||||
| 
 | ||||
|         if hs.config.worker_app is None: | ||||
|             self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates)  # type: ignore | ||||
|         else: | ||||
|             # Query master process | ||||
|             self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore | ||||
| 
 | ||||
|         super(TypingStream, self).__init__(hs) | ||||
| 
 | ||||
|  | @ -232,7 +281,7 @@ class ReceiptsStream(Stream): | |||
|         store = hs.get_datastore() | ||||
| 
 | ||||
|         self.current_token = store.get_max_receipt_stream_id  # type: ignore | ||||
|         self.update_function = store.get_all_updated_receipts  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(store.get_all_updated_receipts)  # type: ignore | ||||
| 
 | ||||
|         super(ReceiptsStream, self).__init__(hs) | ||||
| 
 | ||||
|  | @ -256,7 +305,13 @@ class PushRulesStream(Stream): | |||
| 
 | ||||
|     async def update_function(self, from_token, to_token, limit): | ||||
|         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) | ||||
|         return [(row[0], row[2]) for row in rows] | ||||
| 
 | ||||
|         limited = False | ||||
|         if len(rows) == limit: | ||||
|             to_token = rows[-1][0] | ||||
|             limited = True | ||||
| 
 | ||||
|         return [(row[0], (row[2],)) for row in rows], to_token, limited | ||||
| 
 | ||||
| 
 | ||||
| class PushersStream(Stream): | ||||
|  | @ -275,7 +330,7 @@ class PushersStream(Stream): | |||
|         store = hs.get_datastore() | ||||
| 
 | ||||
|         self.current_token = store.get_pushers_stream_token  # type: ignore | ||||
|         self.update_function = store.get_all_updated_pushers_rows  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows)  # type: ignore | ||||
| 
 | ||||
|         super(PushersStream, self).__init__(hs) | ||||
| 
 | ||||
|  | @ -307,7 +362,7 @@ class CachesStream(Stream): | |||
|         store = hs.get_datastore() | ||||
| 
 | ||||
|         self.current_token = store.get_cache_stream_token  # type: ignore | ||||
|         self.update_function = store.get_all_updated_caches  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(store.get_all_updated_caches)  # type: ignore | ||||
| 
 | ||||
|         super(CachesStream, self).__init__(hs) | ||||
| 
 | ||||
|  | @ -333,7 +388,7 @@ class PublicRoomsStream(Stream): | |||
|         store = hs.get_datastore() | ||||
| 
 | ||||
|         self.current_token = store.get_current_public_room_stream_id  # type: ignore | ||||
|         self.update_function = store.get_all_new_public_rooms  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(store.get_all_new_public_rooms)  # type: ignore | ||||
| 
 | ||||
|         super(PublicRoomsStream, self).__init__(hs) | ||||
| 
 | ||||
|  | @ -354,7 +409,7 @@ class DeviceListsStream(Stream): | |||
|         store = hs.get_datastore() | ||||
| 
 | ||||
|         self.current_token = store.get_device_stream_token  # type: ignore | ||||
|         self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes)  # type: ignore | ||||
| 
 | ||||
|         super(DeviceListsStream, self).__init__(hs) | ||||
| 
 | ||||
|  | @ -372,7 +427,7 @@ class ToDeviceStream(Stream): | |||
|         store = hs.get_datastore() | ||||
| 
 | ||||
|         self.current_token = store.get_to_device_stream_token  # type: ignore | ||||
|         self.update_function = store.get_all_new_device_messages  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(store.get_all_new_device_messages)  # type: ignore | ||||
| 
 | ||||
|         super(ToDeviceStream, self).__init__(hs) | ||||
| 
 | ||||
|  | @ -392,7 +447,7 @@ class TagAccountDataStream(Stream): | |||
|         store = hs.get_datastore() | ||||
| 
 | ||||
|         self.current_token = store.get_max_account_data_stream_id  # type: ignore | ||||
|         self.update_function = store.get_all_updated_tags  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(store.get_all_updated_tags)  # type: ignore | ||||
| 
 | ||||
|         super(TagAccountDataStream, self).__init__(hs) | ||||
| 
 | ||||
|  | @ -412,10 +467,11 @@ class AccountDataStream(Stream): | |||
|         self.store = hs.get_datastore() | ||||
| 
 | ||||
|         self.current_token = self.store.get_max_account_data_stream_id  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(self._update_function)  # type: ignore | ||||
| 
 | ||||
|         super(AccountDataStream, self).__init__(hs) | ||||
| 
 | ||||
|     async def update_function(self, from_token, to_token, limit): | ||||
|     async def _update_function(self, from_token, to_token, limit): | ||||
|         global_results, room_results = await self.store.get_all_updated_account_data( | ||||
|             from_token, from_token, to_token, limit | ||||
|         ) | ||||
|  | @ -442,7 +498,7 @@ class GroupServerStream(Stream): | |||
|         store = hs.get_datastore() | ||||
| 
 | ||||
|         self.current_token = store.get_group_stream_token  # type: ignore | ||||
|         self.update_function = store.get_all_groups_changes  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(store.get_all_groups_changes)  # type: ignore | ||||
| 
 | ||||
|         super(GroupServerStream, self).__init__(hs) | ||||
| 
 | ||||
|  | @ -460,6 +516,6 @@ class UserSignatureStream(Stream): | |||
|         store = hs.get_datastore() | ||||
| 
 | ||||
|         self.current_token = store.get_device_stream_token  # type: ignore | ||||
|         self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes)  # type: ignore | ||||
| 
 | ||||
|         super(UserSignatureStream, self).__init__(hs) | ||||
|  |  | |||
|  | @ -19,7 +19,7 @@ from typing import Tuple, Type | |||
| 
 | ||||
| import attr | ||||
| 
 | ||||
| from ._base import Stream | ||||
| from ._base import Stream, db_query_to_update_function | ||||
| 
 | ||||
| 
 | ||||
| """Handling of the 'events' replication stream | ||||
|  | @ -117,10 +117,11 @@ class EventsStream(Stream): | |||
|     def __init__(self, hs): | ||||
|         self._store = hs.get_datastore() | ||||
|         self.current_token = self._store.get_current_events_token  # type: ignore | ||||
|         self.update_function = db_query_to_update_function(self._update_function)  # type: ignore | ||||
| 
 | ||||
|         super(EventsStream, self).__init__(hs) | ||||
| 
 | ||||
|     async def update_function(self, from_token, current_token, limit=None): | ||||
|     async def _update_function(self, from_token, current_token, limit=None): | ||||
|         event_rows = await self._store.get_all_new_forward_event_rows( | ||||
|             from_token, current_token, limit | ||||
|         ) | ||||
|  |  | |||
|  | @ -15,7 +15,9 @@ | |||
| # limitations under the License. | ||||
| from collections import namedtuple | ||||
| 
 | ||||
| from ._base import Stream | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function | ||||
| 
 | ||||
| 
 | ||||
| class FederationStream(Stream): | ||||
|  | @ -33,11 +35,18 @@ class FederationStream(Stream): | |||
| 
 | ||||
|     NAME = "federation" | ||||
|     ROW_TYPE = FederationStreamRow | ||||
|     _QUERY_MASTER = True | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         federation_sender = hs.get_federation_sender() | ||||
| 
 | ||||
|         self.current_token = federation_sender.get_current_token  # type: ignore | ||||
|         self.update_function = federation_sender.get_replication_rows  # type: ignore | ||||
|         # Not all synapse instances will have a federation sender instance, | ||||
|         # whether that's a `FederationSender` or a `FederationRemoteSendQueue`, | ||||
|         # so we stub the stream out when that is the case. | ||||
|         if hs.config.worker_app is None or hs.should_send_federation(): | ||||
|             federation_sender = hs.get_federation_sender() | ||||
|             self.current_token = federation_sender.get_current_token  # type: ignore | ||||
|             self.update_function = db_query_to_update_function(federation_sender.get_replication_rows)  # type: ignore | ||||
|         else: | ||||
|             self.current_token = lambda: 0  # type: ignore | ||||
|             self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool))  # type: ignore | ||||
| 
 | ||||
|         super(FederationStream, self).__init__(hs) | ||||
|  |  | |||
|  | @ -85,6 +85,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient | |||
| from synapse.notifier import Notifier | ||||
| from synapse.push.action_generator import ActionGenerator | ||||
| from synapse.push.pusherpool import PusherPool | ||||
| from synapse.replication.tcp.resource import ReplicationStreamer | ||||
| from synapse.rest.media.v1.media_repository import ( | ||||
|     MediaRepository, | ||||
|     MediaRepositoryResource, | ||||
|  | @ -199,6 +200,7 @@ class HomeServer(object): | |||
|         "saml_handler", | ||||
|         "event_client_serializer", | ||||
|         "storage", | ||||
|         "replication_streamer", | ||||
|     ] | ||||
| 
 | ||||
|     REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] | ||||
|  | @ -536,6 +538,9 @@ class HomeServer(object): | |||
|     def build_storage(self) -> Storage: | ||||
|         return Storage(self, self.datastores) | ||||
| 
 | ||||
|     def build_replication_streamer(self) -> ReplicationStreamer: | ||||
|         return ReplicationStreamer(self) | ||||
| 
 | ||||
|     def remove_pusher(self, app_id, push_key, user_id): | ||||
|         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) | ||||
| 
 | ||||
|  |  | |||
|  | @ -32,7 +32,29 @@ logger = logging.getLogger(__name__) | |||
| CURRENT_STATE_CACHE_NAME = "cs_cache_fake" | ||||
| 
 | ||||
| 
 | ||||
| class CacheInvalidationStore(SQLBaseStore): | ||||
class CacheInvalidationWorkerStore(SQLBaseStore):
    """Read-side access to the cache invalidation stream."""

    def get_all_updated_caches(self, last_id, current_id, limit):
        """Fetch cache invalidation rows with a stream id after `last_id`.

        Args:
            last_id: Exclusive lower bound on `stream_id`.
            current_id: Only compared against `last_id` to short-circuit when
                nothing can have changed; it does not bound the query.
            limit: Maximum number of rows to return.

        Returns:
            A deferred list of `(stream_id, cache_func, keys, invalidation_ts)`
            rows in ascending `stream_id` order.
        """

        def get_all_updated_caches_txn(txn):
            # We purposefully don't bound by the current token, as we want to
            # send across cache invalidations as quickly as possible. Cache
            # invalidations are idempotent, so duplicates are fine.
            query = (
                "SELECT stream_id, cache_func, keys, invalidation_ts"
                " FROM cache_invalidation_stream"
                " WHERE stream_id > ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(query, (last_id, limit))
            return txn.fetchall()

        if last_id != current_id:
            return self.db.runInteraction(
                "get_all_updated_caches", get_all_updated_caches_txn
            )

        # Positions match, so there is nothing to fetch.
        return defer.succeed([])
| 
 | ||||
| 
 | ||||
| class CacheInvalidationStore(CacheInvalidationWorkerStore): | ||||
|     async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): | ||||
|         """Invalidates the cache and adds it to the cache stream so slaves | ||||
|         will know to invalidate their caches. | ||||
|  | @ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore): | |||
|                 }, | ||||
|             ) | ||||
| 
 | ||||
|     def get_all_updated_caches(self, last_id, current_id, limit): | ||||
|         if last_id == current_id: | ||||
|             return defer.succeed([]) | ||||
| 
 | ||||
|         def get_all_updated_caches_txn(txn): | ||||
|             # We purposefully don't bound by the current token, as we want to | ||||
|             # send across cache invalidations as quickly as possible. Cache | ||||
|             # invalidations are idempotent, so duplicates are fine. | ||||
|             sql = ( | ||||
|                 "SELECT stream_id, cache_func, keys, invalidation_ts" | ||||
|                 " FROM cache_invalidation_stream" | ||||
|                 " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" | ||||
|             ) | ||||
|             txn.execute(sql, (last_id, limit)) | ||||
|             return txn.fetchall() | ||||
| 
 | ||||
|         return self.db.runInteraction( | ||||
|             "get_all_updated_caches", get_all_updated_caches_txn | ||||
|         ) | ||||
| 
 | ||||
|     def get_cache_stream_token(self): | ||||
|         if self._cache_id_gen: | ||||
|             return self._cache_id_gen.get_current_token() | ||||
|  |  | |||
|  | @ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
|             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn | ||||
|         ) | ||||
| 
 | ||||
|     def get_all_new_device_messages(self, last_pos, current_pos, limit): | ||||
|         """ | ||||
|         Args: | ||||
|             last_pos(int): | ||||
|             current_pos(int): | ||||
|             limit(int): | ||||
|         Returns: | ||||
|             A deferred list of rows from the device inbox | ||||
|         """ | ||||
|         if last_pos == current_pos: | ||||
|             return defer.succeed([]) | ||||
| 
 | ||||
|         def get_all_new_device_messages_txn(txn): | ||||
|             # We limit like this as we might have multiple rows per stream_id, and | ||||
|             # we want to make sure we always get all entries for any stream_id | ||||
|             # we return. | ||||
|             upper_pos = min(current_pos, last_pos + limit) | ||||
|             sql = ( | ||||
|                 "SELECT max(stream_id), user_id" | ||||
|                 " FROM device_inbox" | ||||
|                 " WHERE ? < stream_id AND stream_id <= ?" | ||||
|                 " GROUP BY user_id" | ||||
|             ) | ||||
|             txn.execute(sql, (last_pos, upper_pos)) | ||||
|             rows = txn.fetchall() | ||||
| 
 | ||||
|             sql = ( | ||||
|                 "SELECT max(stream_id), destination" | ||||
|                 " FROM device_federation_outbox" | ||||
|                 " WHERE ? < stream_id AND stream_id <= ?" | ||||
|                 " GROUP BY destination" | ||||
|             ) | ||||
|             txn.execute(sql, (last_pos, upper_pos)) | ||||
|             rows.extend(txn) | ||||
| 
 | ||||
|             # Order by ascending stream ordering | ||||
|             rows.sort() | ||||
| 
 | ||||
|             return rows | ||||
| 
 | ||||
|         return self.db.runInteraction( | ||||
|             "get_all_new_device_messages", get_all_new_device_messages_txn | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class DeviceInboxBackgroundUpdateStore(SQLBaseStore): | ||||
|     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" | ||||
|  | @ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) | |||
|                 rows.append((user_id, device_id, stream_id, message_json)) | ||||
| 
 | ||||
|         txn.executemany(sql, rows) | ||||
| 
 | ||||
|     def get_all_new_device_messages(self, last_pos, current_pos, limit): | ||||
|         """ | ||||
|         Args: | ||||
|             last_pos(int): | ||||
|             current_pos(int): | ||||
|             limit(int): | ||||
|         Returns: | ||||
|             A deferred list of rows from the device inbox | ||||
|         """ | ||||
|         if last_pos == current_pos: | ||||
|             return defer.succeed([]) | ||||
| 
 | ||||
|         def get_all_new_device_messages_txn(txn): | ||||
|             # We limit like this as we might have multiple rows per stream_id, and | ||||
|             # we want to make sure we always get all entries for any stream_id | ||||
|             # we return. | ||||
|             upper_pos = min(current_pos, last_pos + limit) | ||||
|             sql = ( | ||||
|                 "SELECT max(stream_id), user_id" | ||||
|                 " FROM device_inbox" | ||||
|                 " WHERE ? < stream_id AND stream_id <= ?" | ||||
|                 " GROUP BY user_id" | ||||
|             ) | ||||
|             txn.execute(sql, (last_pos, upper_pos)) | ||||
|             rows = txn.fetchall() | ||||
| 
 | ||||
|             sql = ( | ||||
|                 "SELECT max(stream_id), destination" | ||||
|                 " FROM device_federation_outbox" | ||||
|                 " WHERE ? < stream_id AND stream_id <= ?" | ||||
|                 " GROUP BY destination" | ||||
|             ) | ||||
|             txn.execute(sql, (last_pos, upper_pos)) | ||||
|             rows.extend(txn) | ||||
| 
 | ||||
|             # Order by ascending stream ordering | ||||
|             rows.sort() | ||||
| 
 | ||||
|             return rows | ||||
| 
 | ||||
|         return self.db.runInteraction( | ||||
|             "get_all_new_device_messages", get_all_new_device_messages_txn | ||||
|         ) | ||||
|  |  | |||
|  | @ -1267,104 +1267,6 @@ class EventsStore( | |||
|         ret = yield self.db.runInteraction("count_daily_active_rooms", _count) | ||||
|         return ret | ||||
| 
 | ||||
|     def get_current_backfill_token(self): | ||||
|         """The current minimum token that backfilled events have reached""" | ||||
|         return -self._backfill_id_gen.get_current_token() | ||||
| 
 | ||||
|     def get_current_events_token(self): | ||||
|         """The current maximum token that events have reached""" | ||||
|         return self._stream_id_gen.get_current_token() | ||||
| 
 | ||||
|     def get_all_new_forward_event_rows(self, last_id, current_id, limit): | ||||
|         if last_id == current_id: | ||||
|             return defer.succeed([]) | ||||
| 
 | ||||
|         def get_all_new_forward_event_rows(txn): | ||||
|             sql = ( | ||||
|                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? < stream_ordering AND stream_ordering <= ?" | ||||
|                 " ORDER BY stream_ordering ASC" | ||||
|                 " LIMIT ?" | ||||
|             ) | ||||
|             txn.execute(sql, (last_id, current_id, limit)) | ||||
|             new_event_updates = txn.fetchall() | ||||
| 
 | ||||
|             if len(new_event_updates) == limit: | ||||
|                 upper_bound = new_event_updates[-1][0] | ||||
|             else: | ||||
|                 upper_bound = current_id | ||||
| 
 | ||||
|             sql = ( | ||||
|                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " INNER JOIN ex_outlier_stream USING (event_id)" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? < event_stream_ordering" | ||||
|                 " AND event_stream_ordering <= ?" | ||||
|                 " ORDER BY event_stream_ordering DESC" | ||||
|             ) | ||||
|             txn.execute(sql, (last_id, upper_bound)) | ||||
|             new_event_updates.extend(txn) | ||||
| 
 | ||||
|             return new_event_updates | ||||
| 
 | ||||
|         return self.db.runInteraction( | ||||
|             "get_all_new_forward_event_rows", get_all_new_forward_event_rows | ||||
|         ) | ||||
| 
 | ||||
|     def get_all_new_backfill_event_rows(self, last_id, current_id, limit): | ||||
|         if last_id == current_id: | ||||
|             return defer.succeed([]) | ||||
| 
 | ||||
|         def get_all_new_backfill_event_rows(txn): | ||||
|             sql = ( | ||||
|                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? > stream_ordering AND stream_ordering >= ?" | ||||
|                 " ORDER BY stream_ordering ASC" | ||||
|                 " LIMIT ?" | ||||
|             ) | ||||
|             txn.execute(sql, (-last_id, -current_id, limit)) | ||||
|             new_event_updates = txn.fetchall() | ||||
| 
 | ||||
|             if len(new_event_updates) == limit: | ||||
|                 upper_bound = new_event_updates[-1][0] | ||||
|             else: | ||||
|                 upper_bound = current_id | ||||
| 
 | ||||
|             sql = ( | ||||
|                 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " INNER JOIN ex_outlier_stream USING (event_id)" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? > event_stream_ordering" | ||||
|                 " AND event_stream_ordering >= ?" | ||||
|                 " ORDER BY event_stream_ordering DESC" | ||||
|             ) | ||||
|             txn.execute(sql, (-last_id, -upper_bound)) | ||||
|             new_event_updates.extend(txn.fetchall()) | ||||
| 
 | ||||
|             return new_event_updates | ||||
| 
 | ||||
|         return self.db.runInteraction( | ||||
|             "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows | ||||
|         ) | ||||
| 
 | ||||
|     @cached(num_args=5, max_entries=10) | ||||
|     def get_all_new_events( | ||||
|         self, | ||||
|  | @ -1850,22 +1752,6 @@ class EventsStore( | |||
| 
 | ||||
|         return (int(res["topological_ordering"]), int(res["stream_ordering"])) | ||||
| 
 | ||||
|     def get_all_updated_current_state_deltas(self, from_token, to_token, limit): | ||||
|         def get_all_updated_current_state_deltas_txn(txn): | ||||
|             sql = """ | ||||
|                 SELECT stream_id, room_id, type, state_key, event_id | ||||
|                 FROM current_state_delta_stream | ||||
|                 WHERE ? < stream_id AND stream_id <= ? | ||||
|                 ORDER BY stream_id ASC LIMIT ? | ||||
|             """ | ||||
|             txn.execute(sql, (from_token, to_token, limit)) | ||||
|             return txn.fetchall() | ||||
| 
 | ||||
|         return self.db.runInteraction( | ||||
|             "get_all_updated_current_state_deltas", | ||||
|             get_all_updated_current_state_deltas_txn, | ||||
|         ) | ||||
| 
 | ||||
|     def insert_labels_for_event_txn( | ||||
|         self, txn, event_id, labels, room_id, topological_ordering | ||||
|     ): | ||||
|  |  | |||
|  | @ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore): | |||
|         complexity_v1 = round(state_events / 500, 2) | ||||
| 
 | ||||
|         return {"v1": complexity_v1} | ||||
| 
 | ||||
|     def get_current_backfill_token(self): | ||||
|         """The current minimum token that backfilled events have reached""" | ||||
|         return -self._backfill_id_gen.get_current_token() | ||||
| 
 | ||||
|     def get_current_events_token(self): | ||||
|         """The current maximum token that events have reached""" | ||||
|         return self._stream_id_gen.get_current_token() | ||||
| 
 | ||||
|     def get_all_new_forward_event_rows(self, last_id, current_id, limit): | ||||
|         if last_id == current_id: | ||||
|             return defer.succeed([]) | ||||
| 
 | ||||
|         def get_all_new_forward_event_rows(txn): | ||||
|             sql = ( | ||||
|                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? < stream_ordering AND stream_ordering <= ?" | ||||
|                 " ORDER BY stream_ordering ASC" | ||||
|                 " LIMIT ?" | ||||
|             ) | ||||
|             txn.execute(sql, (last_id, current_id, limit)) | ||||
|             new_event_updates = txn.fetchall() | ||||
| 
 | ||||
|             if len(new_event_updates) == limit: | ||||
|                 upper_bound = new_event_updates[-1][0] | ||||
|             else: | ||||
|                 upper_bound = current_id | ||||
| 
 | ||||
|             sql = ( | ||||
|                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " INNER JOIN ex_outlier_stream USING (event_id)" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? < event_stream_ordering" | ||||
|                 " AND event_stream_ordering <= ?" | ||||
|                 " ORDER BY event_stream_ordering DESC" | ||||
|             ) | ||||
|             txn.execute(sql, (last_id, upper_bound)) | ||||
|             new_event_updates.extend(txn) | ||||
| 
 | ||||
|             return new_event_updates | ||||
| 
 | ||||
|         return self.db.runInteraction( | ||||
|             "get_all_new_forward_event_rows", get_all_new_forward_event_rows | ||||
|         ) | ||||
| 
 | ||||
|     def get_all_new_backfill_event_rows(self, last_id, current_id, limit): | ||||
|         if last_id == current_id: | ||||
|             return defer.succeed([]) | ||||
| 
 | ||||
|         def get_all_new_backfill_event_rows(txn): | ||||
|             sql = ( | ||||
|                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? > stream_ordering AND stream_ordering >= ?" | ||||
|                 " ORDER BY stream_ordering ASC" | ||||
|                 " LIMIT ?" | ||||
|             ) | ||||
|             txn.execute(sql, (-last_id, -current_id, limit)) | ||||
|             new_event_updates = txn.fetchall() | ||||
| 
 | ||||
|             if len(new_event_updates) == limit: | ||||
|                 upper_bound = new_event_updates[-1][0] | ||||
|             else: | ||||
|                 upper_bound = current_id | ||||
| 
 | ||||
|             sql = ( | ||||
|                 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," | ||||
|                 " state_key, redacts, relates_to_id" | ||||
|                 " FROM events AS e" | ||||
|                 " INNER JOIN ex_outlier_stream USING (event_id)" | ||||
|                 " LEFT JOIN redactions USING (event_id)" | ||||
|                 " LEFT JOIN state_events USING (event_id)" | ||||
|                 " LEFT JOIN event_relations USING (event_id)" | ||||
|                 " WHERE ? > event_stream_ordering" | ||||
|                 " AND event_stream_ordering >= ?" | ||||
|                 " ORDER BY event_stream_ordering DESC" | ||||
|             ) | ||||
|             txn.execute(sql, (-last_id, -upper_bound)) | ||||
|             new_event_updates.extend(txn.fetchall()) | ||||
| 
 | ||||
|             return new_event_updates | ||||
| 
 | ||||
|         return self.db.runInteraction( | ||||
|             "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows | ||||
|         ) | ||||
| 
 | ||||
|     def get_all_updated_current_state_deltas(self, from_token, to_token, limit): | ||||
|         def get_all_updated_current_state_deltas_txn(txn): | ||||
|             sql = """ | ||||
|                 SELECT stream_id, room_id, type, state_key, event_id | ||||
|                 FROM current_state_delta_stream | ||||
|                 WHERE ? < stream_id AND stream_id <= ? | ||||
|                 ORDER BY stream_id ASC LIMIT ? | ||||
|             """ | ||||
|             txn.execute(sql, (from_token, to_token, limit)) | ||||
|             return txn.fetchall() | ||||
| 
 | ||||
|         return self.db.runInteraction( | ||||
|             "get_all_updated_current_state_deltas", | ||||
|             get_all_updated_current_state_deltas_txn, | ||||
|         ) | ||||
|  |  | |||
|  | @ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|         return total_media_quarantined | ||||
| 
 | ||||
|     def get_all_new_public_rooms(self, prev_id, current_id, limit): | ||||
|         def get_all_new_public_rooms(txn): | ||||
|             sql = """ | ||||
|                 SELECT stream_id, room_id, visibility, appservice_id, network_id | ||||
|                 FROM public_room_list_stream | ||||
|                 WHERE stream_id > ? AND stream_id <= ? | ||||
|                 ORDER BY stream_id ASC | ||||
|                 LIMIT ? | ||||
|             """ | ||||
| 
 | ||||
|             txn.execute(sql, (prev_id, current_id, limit)) | ||||
|             return txn.fetchall() | ||||
| 
 | ||||
|         if prev_id == current_id: | ||||
|             return defer.succeed([]) | ||||
| 
 | ||||
|         return self.db.runInteraction( | ||||
|             "get_all_new_public_rooms", get_all_new_public_rooms | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class RoomBackgroundUpdateStore(SQLBaseStore): | ||||
|     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" | ||||
|  | @ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
|     def get_current_public_room_stream_id(self): | ||||
|         return self._public_room_id_gen.get_current_token() | ||||
| 
 | ||||
|     def get_all_new_public_rooms(self, prev_id, current_id, limit): | ||||
|         def get_all_new_public_rooms(txn): | ||||
|             sql = """ | ||||
|                 SELECT stream_id, room_id, visibility, appservice_id, network_id | ||||
|                 FROM public_room_list_stream | ||||
|                 WHERE stream_id > ? AND stream_id <= ? | ||||
|                 ORDER BY stream_id ASC | ||||
|                 LIMIT ? | ||||
|             """ | ||||
| 
 | ||||
|             txn.execute(sql, (prev_id, current_id, limit)) | ||||
|             return txn.fetchall() | ||||
| 
 | ||||
|         if prev_id == current_id: | ||||
|             return defer.succeed([]) | ||||
| 
 | ||||
|         return self.db.runInteraction( | ||||
|             "get_all_new_public_rooms", get_all_new_public_rooms | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def block_room(self, room_id, user_id): | ||||
|         """Marks the room as blocked. Can be called multiple times. | ||||
|  |  | |||
|  | @ -12,6 +12,7 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from mock import Mock | ||||
| 
 | ||||
| from synapse.replication.tcp.commands import ReplicateCommand | ||||
|  | @ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): | |||
|         # build a replication server | ||||
|         server_factory = ReplicationStreamProtocolFactory(self.hs) | ||||
|         self.streamer = server_factory.streamer | ||||
|         server = server_factory.buildProtocol(None) | ||||
|         self.server = server_factory.buildProtocol(None) | ||||
| 
 | ||||
|         # build a replication client, with a dummy handler | ||||
|         handler_factory = Mock() | ||||
|         self.test_handler = TestReplicationClientHandler() | ||||
|         self.test_handler.factory = handler_factory | ||||
|         self.test_handler = Mock(wraps=TestReplicationClientHandler()) | ||||
|         self.client = ClientReplicationStreamProtocol( | ||||
|             "client", "test", clock, self.test_handler | ||||
|             hs, "client", "test", clock, self.test_handler, | ||||
|         ) | ||||
| 
 | ||||
|         # wire them together | ||||
|         self.client.makeConnection(FakeTransport(server, reactor)) | ||||
|         server.makeConnection(FakeTransport(self.client, reactor)) | ||||
|         self._client_transport = None | ||||
|         self._server_transport = None | ||||
| 
 | ||||
|     def reconnect(self): | ||||
|         if self._client_transport: | ||||
|             self.client.close() | ||||
| 
 | ||||
|         if self._server_transport: | ||||
|             self.server.close() | ||||
| 
 | ||||
|         self._client_transport = FakeTransport(self.server, self.reactor) | ||||
|         self.client.makeConnection(self._client_transport) | ||||
| 
 | ||||
|         self._server_transport = FakeTransport(self.client, self.reactor) | ||||
|         self.server.makeConnection(self._server_transport) | ||||
| 
 | ||||
|     def disconnect(self): | ||||
|         if self._client_transport: | ||||
|             self._client_transport = None | ||||
|             self.client.close() | ||||
| 
 | ||||
|         if self._server_transport: | ||||
|             self._server_transport = None | ||||
|             self.server.close() | ||||
| 
 | ||||
|     def replicate(self): | ||||
|         """Tell the master side of replication that something has happened, and then | ||||
|  | @ -50,19 +69,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): | |||
|         self.streamer.on_notifier_poke() | ||||
|         self.pump(0.1) | ||||
| 
 | ||||
|     def replicate_stream(self, stream, token="NOW"): | ||||
|     def replicate_stream(self): | ||||
|         """Make the client send a REPLICATE command to set up a subscription to a stream""" | ||||
|         self.client.send_command(ReplicateCommand(stream, token)) | ||||
|         self.client.send_command(ReplicateCommand()) | ||||
| 
 | ||||
| 
 | ||||
| class TestReplicationClientHandler(object): | ||||
|     """Drop-in for ReplicationClientHandler which just collects RDATA rows""" | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         self.received_rdata_rows = [] | ||||
|         self.streams = set() | ||||
|         self._received_rdata_rows = [] | ||||
| 
 | ||||
|     def get_streams_to_replicate(self): | ||||
|         return {} | ||||
|         positions = {s: 0 for s in self.streams} | ||||
|         for stream, token, _ in self._received_rdata_rows: | ||||
|             if stream in self.streams: | ||||
|                 positions[stream] = max(token, positions.get(stream, 0)) | ||||
|         return positions | ||||
| 
 | ||||
|     def get_currently_syncing_users(self): | ||||
|         return [] | ||||
|  | @ -73,6 +97,9 @@ class TestReplicationClientHandler(object): | |||
|     def finished_connecting(self): | ||||
|         pass | ||||
| 
 | ||||
|     async def on_position(self, stream_name, token): | ||||
|         """Called when we get new position data.""" | ||||
| 
 | ||||
|     async def on_rdata(self, stream_name, token, rows): | ||||
|         for r in rows: | ||||
|             self.received_rdata_rows.append((stream_name, token, r)) | ||||
|             self._received_rdata_rows.append((stream_name, token, r)) | ||||
|  |  | |||
|  | @ -17,30 +17,64 @@ from synapse.replication.tcp.streams._base import ReceiptsStream | |||
| from tests.replication.tcp.streams._base import BaseStreamTestCase | ||||
| 
 | ||||
| USER_ID = "@feeling:blue" | ||||
| ROOM_ID = "!room:blue" | ||||
| EVENT_ID = "$event:blue" | ||||
| 
 | ||||
| 
 | ||||
| class ReceiptsStreamTestCase(BaseStreamTestCase): | ||||
|     def test_receipt(self): | ||||
|         self.reconnect() | ||||
| 
 | ||||
|         # make the client subscribe to the receipts stream | ||||
|         self.replicate_stream("receipts", "NOW") | ||||
|         self.replicate_stream() | ||||
|         self.test_handler.streams.add("receipts") | ||||
| 
 | ||||
|         # tell the master to send a new receipt | ||||
|         self.get_success( | ||||
|             self.hs.get_datastore().insert_receipt( | ||||
|                 ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} | ||||
|                 "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} | ||||
|             ) | ||||
|         ) | ||||
|         self.replicate() | ||||
| 
 | ||||
|         # there should be one RDATA command | ||||
|         rdata_rows = self.test_handler.received_rdata_rows | ||||
|         self.test_handler.on_rdata.assert_called_once() | ||||
|         stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] | ||||
|         self.assertEqual(stream_name, "receipts") | ||||
|         self.assertEqual(1, len(rdata_rows)) | ||||
|         self.assertEqual(rdata_rows[0][0], "receipts") | ||||
|         row = rdata_rows[0][2]  # type: ReceiptsStream.ReceiptsStreamRow | ||||
|         self.assertEqual(ROOM_ID, row.room_id) | ||||
|         row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow | ||||
|         self.assertEqual("!room:blue", row.room_id) | ||||
|         self.assertEqual("m.read", row.receipt_type) | ||||
|         self.assertEqual(USER_ID, row.user_id) | ||||
|         self.assertEqual(EVENT_ID, row.event_id) | ||||
|         self.assertEqual("$event:blue", row.event_id) | ||||
|         self.assertEqual({"a": 1}, row.data) | ||||
| 
 | ||||
|         # Now let's disconnect and insert some data. | ||||
|         self.disconnect() | ||||
| 
 | ||||
|         self.test_handler.on_rdata.reset_mock() | ||||
| 
 | ||||
|         self.get_success( | ||||
|             self.hs.get_datastore().insert_receipt( | ||||
|                 "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} | ||||
|             ) | ||||
|         ) | ||||
|         self.replicate() | ||||
| 
 | ||||
|         # Nothing should have happened as we are disconnected | ||||
|         self.test_handler.on_rdata.assert_not_called() | ||||
| 
 | ||||
|         self.reconnect() | ||||
|         self.pump(0.1) | ||||
| 
 | ||||
|         # We should now have caught up and get the missing data | ||||
|         self.test_handler.on_rdata.assert_called_once() | ||||
|         stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] | ||||
|         self.assertEqual(stream_name, "receipts") | ||||
|         self.assertEqual(token, 3) | ||||
|         self.assertEqual(1, len(rdata_rows)) | ||||
| 
 | ||||
|         row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow | ||||
|         self.assertEqual("!room2:blue", row.room_id) | ||||
|         self.assertEqual("m.read", row.receipt_type) | ||||
|         self.assertEqual(USER_ID, row.user_id) | ||||
|         self.assertEqual("$event2:foo", row.event_id) | ||||
|         self.assertEqual({"a": 2}, row.data) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Erik Johnston
						Erik Johnston