Move stream catchup to workers.

Author: Erik Johnston
Commit: 1f83255de1 (parent: ba90596687)

				|  | @ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): | |||
|         self.client_name = client_name | ||||
|         self.handler = handler | ||||
|         self.server_name = hs.config.server_name | ||||
|         self.hs = hs | ||||
|         self._clock = hs.get_clock()  # As self.clock is defined in super class | ||||
| 
 | ||||
|         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) | ||||
|  | @ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): | |||
|     def buildProtocol(self, addr): | ||||
|         logger.info("Connected to replication: %r", addr) | ||||
|         return ClientReplicationStreamProtocol( | ||||
|             self.client_name, self.server_name, self._clock, self.handler | ||||
|             self.hs, self.client_name, self.server_name, self._clock, self.handler, | ||||
|         ) | ||||
| 
 | ||||
|     def clientConnectionLost(self, connector, reason): | ||||
|  |  | |||
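The hunks above (apparently the replication client factory) now keep a reference to the homeserver and hand it to every protocol they build, so the client side can construct its own stream objects for catchup. As a rough illustration of that pattern only, not Synapse's real classes (all names below are invented for the sketch), a Twisted reconnecting factory that forwards an extra dependency into its protocols looks like this:

from twisted.internet.protocol import Protocol, ReconnectingClientFactory


class SketchClientProtocol(Protocol):
    """Stand-in for the real client protocol: it receives the dependency
    container so it can build per-connection state (e.g. stream objects)."""

    def __init__(self, deps, client_name, server_name, handler):
        self.deps = deps
        self.client_name = client_name
        self.server_name = server_name
        self.handler = handler


class SketchClientFactory(ReconnectingClientFactory):
    """Stand-in for the replication client factory: stores the dependency
    container and passes it to each protocol it builds."""

    def __init__(self, deps, client_name, server_name, handler):
        self.deps = deps
        self.client_name = client_name
        self.server_name = server_name
        self.handler = handler

    def buildProtocol(self, addr):
        self.resetDelay()  # reset the reconnect backoff once we connect
        return SketchClientProtocol(
            self.deps, self.client_name, self.server_name, self.handler
        )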
|  | @ -82,7 +82,8 @@ from synapse.replication.tcp.commands import ( | |||
|     SyncCommand, | ||||
|     UserSyncCommand, | ||||
| ) | ||||
| from synapse.replication.tcp.streams import STREAMS_MAP | ||||
| from synapse.replication.tcp.streams import STREAMS_MAP, Stream | ||||
| from synapse.server import HomeServer | ||||
| from synapse.types import Collection | ||||
| from synapse.util import Clock | ||||
| from synapse.util.stringutils import random_string | ||||
|  | @ -414,9 +415,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         # The streams the client has subscribed to and is up to date with | ||||
|         self.replication_streams = set()  # type: Set[str] | ||||
| 
 | ||||
|         # The streams the client is currently subscribing to. | ||||
|         self.connecting_streams = set()  # type:  Set[str] | ||||
| 
 | ||||
|         # Map from stream name to list of updates to send once we've finished | ||||
|         # subscribing the client to the stream. | ||||
|         self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]] | ||||
|  | @ -482,67 +480,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         are queued and sent once we've sent down any missed updates. | ||||
|         """ | ||||
|         self.replication_streams.discard(stream_name) | ||||
|         self.connecting_streams.add(stream_name) | ||||
| 
 | ||||
|         try: | ||||
|             limited = True | ||||
|             while limited: | ||||
|                 # Get missing updates | ||||
|                 ( | ||||
|                     updates, | ||||
|                     current_token, | ||||
|                     limited, | ||||
|                 ) = await self.streamer.get_stream_updates(stream_name, token) | ||||
| 
 | ||||
|                 # Send all the missing updates | ||||
|                 for update in updates: | ||||
|                     token, row = update[0], update[1] | ||||
|                     self.send_command(RdataCommand(stream_name, token, row)) | ||||
|             # Get current stream position. | ||||
|             current_token = self.streamer.get_stream_token(stream_name) | ||||
| 
 | ||||
|             # We send a POSITION command to ensure that they have an up to | ||||
|             # date token (especially useful if we didn't send any updates | ||||
|             # above) | ||||
|             self.send_command(PositionCommand(stream_name, current_token)) | ||||
| 
 | ||||
|             # Now we can send any updates that came in while we were subscribing | ||||
|             pending_rdata = self.pending_rdata.pop(stream_name, []) | ||||
|             updates = [] | ||||
|             for token, update in pending_rdata: | ||||
|                 # If the token is null, it is part of a batch update. Batches | ||||
|                 # are multiple updates that share a single token. To denote | ||||
|                 # this, the token is set to None for all tokens in the batch | ||||
|                 # except for the last. If we find a None token, we keep looking | ||||
|                 # through tokens until we find one that is not None and then | ||||
|                 # process all previous updates in the batch as if they had the | ||||
|                 # final token. | ||||
|                 if token is None: | ||||
|                     # Store this update as part of a batch | ||||
|                     updates.append(update) | ||||
|                     continue | ||||
| 
 | ||||
|                 if token <= current_token: | ||||
|                     # This update or batch of updates is older than | ||||
|                     # current_token, dismiss it | ||||
|                     updates = [] | ||||
|                     continue | ||||
| 
 | ||||
|                 updates.append(update) | ||||
| 
 | ||||
|                 # Send all updates that are part of this batch with the | ||||
|                 # found token | ||||
|                 for update in updates: | ||||
|                     self.send_command(RdataCommand(stream_name, token, update)) | ||||
| 
 | ||||
|                 # Clear stored updates | ||||
|                 updates = [] | ||||
| 
 | ||||
|             # They're now fully subscribed | ||||
|             self.replication_streams.add(stream_name) | ||||
|         except Exception as e: | ||||
|             logger.exception("[%s] Failed to handle REPLICATE command", self.id()) | ||||
|             self.send_error("failed to handle replicate: %r", e) | ||||
|         finally: | ||||
|             self.connecting_streams.discard(stream_name) | ||||
| 
 | ||||
|     def stream_update(self, stream_name, token, data): | ||||
|         """Called when a new update is available to stream to clients. | ||||
|  | @ -552,10 +504,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         if stream_name in self.replication_streams: | ||||
|             # The client is subscribed to the stream | ||||
|             self.send_command(RdataCommand(stream_name, token, data)) | ||||
|         elif stream_name in self.connecting_streams: | ||||
|             # The client is being subscribed to the stream | ||||
|             logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) | ||||
|             self.pending_rdata.setdefault(stream_name, []).append((token, data)) | ||||
|         else: | ||||
|             # The client isn't subscribed | ||||
|             logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) | ||||
|  | @ -642,6 +590,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         hs: HomeServer, | ||||
|         client_name: str, | ||||
|         server_name: str, | ||||
|         clock: Clock, | ||||
|  | @ -653,6 +602,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         self.server_name = server_name | ||||
|         self.handler = handler | ||||
| 
 | ||||
|         self.streams = { | ||||
|             stream.NAME: stream(hs) for stream in STREAMS_MAP.values() | ||||
|         }  # type: Dict[str, Stream] | ||||
| 
 | ||||
|         # Set of stream names that have been subscribed to, but haven't yet | ||||
|         # caught up with. This is used to track when the client has been fully | ||||
|         # connected to the remote. | ||||
|  | @ -660,7 +613,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
| 
 | ||||
|         # Map of stream to batched updates. See RdataCommand for info on how | ||||
|         # batching works. | ||||
|         self.pending_batches = {}  # type: Dict[str, Any] | ||||
|         self.pending_batches = {}  # type: Dict[str, List[Any]] | ||||
| 
 | ||||
|     def connectionMade(self): | ||||
|         self.send_command(NameCommand(self.client_name)) | ||||
|  | @ -701,7 +654,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|             ) | ||||
|             raise | ||||
| 
 | ||||
|         if cmd.token is None: | ||||
|         if cmd.token is None or stream_name in self.streams_connecting: | ||||
|             # I.e. this is part of a batch of updates for this stream. Batch | ||||
|             # until we get an update for the stream with a non None token | ||||
|             self.pending_batches.setdefault(stream_name, []).append(row) | ||||
|  | @ -711,14 +664,46 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|             rows.append(row) | ||||
|             await self.handler.on_rdata(stream_name, cmd.token, rows) | ||||
| 
 | ||||
|     async def on_POSITION(self, cmd): | ||||
|     async def on_POSITION(self, cmd: PositionCommand): | ||||
|         stream = self.streams.get(cmd.stream_name) | ||||
|         if not stream: | ||||
|             logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) | ||||
|             return | ||||
| 
 | ||||
|         # Find where we previously streamed up to. | ||||
|         current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) | ||||
|         if current_token is None: | ||||
|             logger.warning( | ||||
|                 "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name | ||||
|             ) | ||||
|             return | ||||
| 
 | ||||
|         # Fetch all updates between then and now. | ||||
|         limited = True | ||||
|         while limited: | ||||
|             updates, current_token, limited = await stream.get_updates_since( | ||||
|                 current_token, cmd.token | ||||
|             ) | ||||
|             if updates: | ||||
|                 await self.handler.on_rdata( | ||||
|                     cmd.stream_name, | ||||
|                     current_token, | ||||
|                     [stream.parse_row(update[1]) for update in updates], | ||||
|                 ) | ||||
| 
 | ||||
|         # We've now caught up to position sent to us, notify handler. | ||||
|         await self.handler.on_position(cmd.stream_name, cmd.token) | ||||
| 
 | ||||
|         # When we get a `POSITION` command it means we've finished getting | ||||
|         # missing updates for the given stream, and are now up to date. | ||||
|         self.streams_connecting.discard(cmd.stream_name) | ||||
|         if not self.streams_connecting: | ||||
|             self.handler.finished_connecting() | ||||
| 
 | ||||
|         await self.handler.on_position(cmd.stream_name, cmd.token) | ||||
|         # Handle any RDATA that came in while we were catching up. | ||||
|         rows = self.pending_batches.pop(cmd.stream_name, []) | ||||
|         if rows: | ||||
|             await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) | ||||
| 
 | ||||
|     async def on_SYNC(self, cmd): | ||||
|         self.handler.on_sync(cmd.data) | ||||
|  |  | |||
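The core of this change is the new client-side on_POSITION: rather than the server replaying missed rows when a client subscribes, the client looks up the token it last processed, pages through the stream's updates until the result is no longer limited, tells the handler it has reached the advertised position, and only then flushes any RDATA that was batched while it was catching up. The following is a self-contained sketch of that loop under simplifying assumptions; FakeStream and the page size of three are inventions for the example, not Synapse's Stream API.

import asyncio
from typing import Any, Dict, List, Tuple


class FakeStream:
    """Invented stand-in for a replication stream: it holds (token, row)
    pairs and pages through them three at a time, reporting limited=True
    while more pages remain."""

    def __init__(self, name: str, updates: List[Tuple[int, Any]]):
        self.NAME = name
        self._updates = updates

    async def get_updates_since(self, from_token: int, upto_token: int):
        pending = [
            (t, row) for t, row in self._updates if from_token < t <= upto_token
        ]
        page, rest = pending[:3], pending[3:]
        if not page:
            return [], upto_token, False
        return page, page[-1][0], bool(rest)


async def catch_up(stream, from_token, upto_token, pending_batches: Dict[str, List[Any]]):
    """Client-side catchup: replay everything between from_token and the
    POSITION the server advertised, then hand over any queued rows."""
    current_token = from_token
    limited = True
    while limited:
        updates, current_token, limited = await stream.get_updates_since(
            current_token, upto_token
        )
        for token, row in updates:
            print("RDATA", stream.NAME, token, row)

    print("caught up with", stream.NAME, "at token", upto_token)

    # Rows that arrived over the wire while we were replaying history.
    for row in pending_batches.pop(stream.NAME, []):
        print("queued RDATA", stream.NAME, row)


updates = [(i, {"seq": i}) for i in range(1, 8)]
asyncio.run(catch_up(FakeStream("receipts", updates), 2, 7, {"receipts": [{"late": True}]}))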
|  | @ -227,8 +227,7 @@ class ReplicationStreamer(object): | |||
|             self.pending_updates = False | ||||
|             self.is_looping = False | ||||
| 
 | ||||
|     @measure_func("repl.get_stream_updates") | ||||
|     async def get_stream_updates(self, stream_name, token): | ||||
|     def get_stream_token(self, stream_name): | ||||
|         """Get the current position of the given stream. This is called when | ||||
|         a client first subscribes to a stream, so it can be sent a POSITION. | ||||
|         """ | ||||
|  | @ -236,7 +235,7 @@ class ReplicationStreamer(object): | |||
|         if not stream: | ||||
|             raise Exception("unknown stream %s", stream_name) | ||||
| 
 | ||||
|         return await stream.get_updates_since(token, stream.current_token()) | ||||
|         return stream.current_token() | ||||
| 
 | ||||
|     @measure_func("repl.federation_ack") | ||||
|     def federation_ack(self, token): | ||||
|  |  | |||
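On the server side the streamer no longer replays history at subscription time; it only reports where a stream currently is, which is what the POSITION command carries. A minimal sketch of that lookup under stated assumptions (the registry class below is hypothetical, not the real ReplicationStreamer):

class CurrentTokenRegistry:
    """Hypothetical stand-in for the streamer: maps stream name to a stream
    object and exposes only that stream's current position."""

    def __init__(self, streams):
        self._streams_by_name = {stream.NAME: stream for stream in streams}

    def get_stream_token(self, stream_name: str) -> int:
        stream = self._streams_by_name.get(stream_name)
        if stream is None:
            raise Exception("unknown stream %s" % (stream_name,))
        return stream.current_token()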
|  | @ -27,7 +27,8 @@ Each stream is defined by the following information: | |||
| 
 | ||||
| from typing import Dict, Type | ||||
| 
 | ||||
| from . import _base, events, federation | ||||
| from synapse.replication.tcp.streams import _base, events, federation | ||||
| from synapse.replication.tcp.streams._base import Stream | ||||
| 
 | ||||
| STREAMS_MAP = { | ||||
|     stream.NAME: stream | ||||
|  | @ -50,3 +51,6 @@ STREAMS_MAP = { | |||
|         _base.UserSignatureStream, | ||||
|     ) | ||||
| }  # type: Dict[str, Type[_base.Stream]] | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["Stream", "STREAMS_MAP"] | ||||
|  |  | |||
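STREAMS_MAP stays a mapping of stream name to stream class; what changes is that the client protocol now instantiates those classes itself (one set of instances per homeserver) so it can call get_updates_since locally. A toy version of that registry-plus-instantiation pattern, with invented stream classes standing in for the real ones:

from typing import Dict, Type


class Stream:
    """Minimal stand-in base class: real streams also know how to fetch
    updates and parse rows."""

    NAME = "base"

    def __init__(self, hs):
        self.hs = hs


class ReceiptsStream(Stream):
    NAME = "receipts"


class TypingStream(Stream):
    NAME = "typing"


# Name -> class registry, mirroring the shape of STREAMS_MAP.
STREAMS = {
    stream.NAME: stream for stream in (ReceiptsStream, TypingStream)
}  # type: Dict[str, Type[Stream]]

# Each process builds its own instances, as the client protocol now does.
hs = object()  # placeholder for the homeserver
streams = {name: cls(hs) for name, cls in STREAMS.items()}  # type: Dict[str, Stream]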
|  | @ -12,6 +12,7 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from mock import Mock | ||||
| 
 | ||||
| from synapse.replication.tcp.commands import ReplicateCommand | ||||
|  | @ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): | |||
|         # build a replication server | ||||
|         server_factory = ReplicationStreamProtocolFactory(self.hs) | ||||
|         self.streamer = server_factory.streamer | ||||
|         server = server_factory.buildProtocol(None) | ||||
|         self.server = server_factory.buildProtocol(None) | ||||
| 
 | ||||
|         # build a replication client, with a dummy handler | ||||
|         handler_factory = Mock() | ||||
|         self.test_handler = TestReplicationClientHandler() | ||||
|         self.test_handler.factory = handler_factory | ||||
|         self.test_handler = Mock(wraps=TestReplicationClientHandler()) | ||||
|         self.client = ClientReplicationStreamProtocol( | ||||
|             "client", "test", clock, self.test_handler | ||||
|             hs, "client", "test", clock, self.test_handler, | ||||
|         ) | ||||
| 
 | ||||
|         # wire them together | ||||
|         self.client.makeConnection(FakeTransport(server, reactor)) | ||||
|         server.makeConnection(FakeTransport(self.client, reactor)) | ||||
|         self._client_transport = None | ||||
|         self._server_transport = None | ||||
| 
 | ||||
|     def reconnect(self): | ||||
|         if self._client_transport: | ||||
|             self.client.close() | ||||
| 
 | ||||
|         if self._server_transport: | ||||
|             self.server.close() | ||||
| 
 | ||||
|         self._client_transport = FakeTransport(self.server, self.reactor) | ||||
|         self.client.makeConnection(self._client_transport) | ||||
| 
 | ||||
|         self._server_transport = FakeTransport(self.client, self.reactor) | ||||
|         self.server.makeConnection(self._server_transport) | ||||
| 
 | ||||
|     def disconnect(self): | ||||
|         if self._client_transport: | ||||
|             self._client_transport = None | ||||
|             self.client.close() | ||||
| 
 | ||||
|         if self._server_transport: | ||||
|             self._server_transport = None | ||||
|             self.server.close() | ||||
| 
 | ||||
|     def replicate(self): | ||||
|         """Tell the master side of replication that something has happened, and then | ||||
|  | @ -59,10 +78,15 @@ class TestReplicationClientHandler(object): | |||
|     """Drop-in for ReplicationClientHandler which just collects RDATA rows""" | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         self.received_rdata_rows = [] | ||||
|         self.streams = set() | ||||
|         self._received_rdata_rows = [] | ||||
| 
 | ||||
|     def get_streams_to_replicate(self): | ||||
|         return {} | ||||
|         positions = {s: 0 for s in self.streams} | ||||
|         for stream, token, _ in self._received_rdata_rows: | ||||
|             if stream in self.streams: | ||||
|                 positions[stream] = max(token, positions.get(stream, 0)) | ||||
|         return positions | ||||
| 
 | ||||
|     def get_currently_syncing_users(self): | ||||
|         return [] | ||||
|  | @ -73,6 +97,9 @@ class TestReplicationClientHandler(object): | |||
|     def finished_connecting(self): | ||||
|         pass | ||||
| 
 | ||||
|     async def on_position(self, stream_name, token): | ||||
|         """Called when we get new position data.""" | ||||
| 
 | ||||
|     async def on_rdata(self, stream_name, token, rows): | ||||
|         for r in rows: | ||||
|             self.received_rdata_rows.append((stream_name, token, r)) | ||||
|             self._received_rdata_rows.append((stream_name, token, r)) | ||||
|  |  | |||
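The test-harness change worth calling out is Mock(wraps=...): the handler is a real object whose methods still run, but every call is also recorded so the tests can assert on it. A small standalone example of that pattern (the handler class here is invented for illustration):

from unittest.mock import Mock


class RecordingHandler:
    """Invented handler: stores the rows it is given."""

    def __init__(self):
        self.rows = []

    def on_rdata(self, stream_name, token, rows):
        self.rows.extend(rows)


real_handler = RecordingHandler()
handler = Mock(wraps=real_handler)

handler.on_rdata("receipts", 3, [{"a": 1}])

# The call was recorded by the mock...
handler.on_rdata.assert_called_once_with("receipts", 3, [{"a": 1}])
# ...and also passed through to the real object.
assert real_handler.rows == [{"a": 1}]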
|  | @ -17,30 +17,64 @@ from synapse.replication.tcp.streams._base import ReceiptsStreamRow | |||
| from tests.replication.tcp.streams._base import BaseStreamTestCase | ||||
| 
 | ||||
| USER_ID = "@feeling:blue" | ||||
| ROOM_ID = "!room:blue" | ||||
| EVENT_ID = "$event:blue" | ||||
| 
 | ||||
| 
 | ||||
| class ReceiptsStreamTestCase(BaseStreamTestCase): | ||||
|     def test_receipt(self): | ||||
|         self.reconnect() | ||||
| 
 | ||||
|         # make the client subscribe to the receipts stream | ||||
|         self.replicate_stream("receipts", "NOW") | ||||
|         self.test_handler.streams.add("receipts") | ||||
| 
 | ||||
|         # tell the master to send a new receipt | ||||
|         self.get_success( | ||||
|             self.hs.get_datastore().insert_receipt( | ||||
|                 ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} | ||||
|                 "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} | ||||
|             ) | ||||
|         ) | ||||
|         self.replicate() | ||||
| 
 | ||||
|         # there should be one RDATA command | ||||
|         rdata_rows = self.test_handler.received_rdata_rows | ||||
|         self.test_handler.on_rdata.assert_called_once() | ||||
|         stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] | ||||
|         self.assertEqual(stream_name, "receipts") | ||||
|         self.assertEqual(1, len(rdata_rows)) | ||||
|         self.assertEqual(rdata_rows[0][0], "receipts") | ||||
|         row = rdata_rows[0][2]  # type: ReceiptsStreamRow | ||||
|         self.assertEqual(ROOM_ID, row.room_id) | ||||
|         row = rdata_rows[0]  # type: ReceiptsStreamRow | ||||
|         self.assertEqual("!room:blue", row.room_id) | ||||
|         self.assertEqual("m.read", row.receipt_type) | ||||
|         self.assertEqual(USER_ID, row.user_id) | ||||
|         self.assertEqual(EVENT_ID, row.event_id) | ||||
|         self.assertEqual("$event:blue", row.event_id) | ||||
|         self.assertEqual({"a": 1}, row.data) | ||||
| 
 | ||||
|         # Now let's disconnect and insert some data. | ||||
|         self.disconnect() | ||||
| 
 | ||||
|         self.test_handler.on_rdata.reset_mock() | ||||
| 
 | ||||
|         self.get_success( | ||||
|             self.hs.get_datastore().insert_receipt( | ||||
|                 "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} | ||||
|             ) | ||||
|         ) | ||||
|         self.replicate() | ||||
| 
 | ||||
|         # Nothing should have happened as we are disconnected | ||||
|         self.test_handler.on_rdata.assert_not_called() | ||||
| 
 | ||||
|         self.reconnect() | ||||
|         self.pump(0.1) | ||||
| 
 | ||||
|         # We should now have caught up and get the missing data | ||||
|         self.test_handler.on_rdata.assert_called_once() | ||||
|         stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] | ||||
|         self.assertEqual(stream_name, "receipts") | ||||
|         self.assertEqual(token, 3) | ||||
|         self.assertEqual(1, len(rdata_rows)) | ||||
| 
 | ||||
|         row = rdata_rows[0]  # type: ReceiptsStreamRow | ||||
|         self.assertEqual("!room2:blue", row.room_id) | ||||
|         self.assertEqual("m.read", row.receipt_type) | ||||
|         self.assertEqual(USER_ID, row.user_id) | ||||
|         self.assertEqual("$event2:foo", row.event_id) | ||||
|         self.assertEqual({"a": 2}, row.data) | ||||
|  |  | |||
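The assertions above unpack on_rdata.call_args[0] to get the positional arguments of the most recent call. For reference, a stripped-down example of that unpacking pattern outside the Synapse harness:

from unittest.mock import Mock

handler = Mock()
handler.on_rdata("receipts", 3, ["row-a"])

# call_args[0] is the tuple of positional arguments from the last call,
# which is how the test unpacks (stream_name, token, rdata_rows).
stream_name, token, rows = handler.on_rdata.call_args[0]
assert (stream_name, token, rows) == ("receipts", 3, ["row-a"])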