diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b3c33370a0..841c8591bf 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -13,8 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""A replication client for use by synapse workers. -""" import logging from typing import Any, Callable, Dict, List, Optional, Set @@ -51,13 +49,13 @@ class ReplicationCommandHandler: """ def __init__(self, hs): - self.replication_data_handler = hs.get_replication_data_handler() - self.presence_handler = hs.get_presence_handler() + self._replication_data_handler = hs.get_replication_data_handler() + self._presence_handler = hs.get_presence_handler() # Set of streams that we're currently catching up with. - self.streams_connecting = set() # type: Set[str] + self._streams_connecting = set() # type: Set[str] - self.streams = { + self._streams = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] @@ -65,23 +63,23 @@ class ReplicationCommandHandler: # Map of stream to batched updates. See RdataCommand for info on how # batching works. - self.pending_batches = {} # type: Dict[str, List[Any]] + self._pending_batches = {} # type: Dict[str, List[Any]] # The factory used to create connections. - self.factory = None # type: Optional[ReplicationClientFactory] + self._factory = None # type: Optional[ReplicationClientFactory] # The current connection. None if we are currently (re)connecting - self.connection = None + self._connection = None def start_replication(self, hs): """Helper method to start a replication connection to the remote server using TCP. """ client_name = hs.config.worker_name - self.factory = ReplicationClientFactory(hs, client_name, self) + self._factory = ReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self.factory) + hs.get_reactor().connectTCP(host, port, self._factory) async def on_RDATA(self, cmd: RdataCommand): stream_name = cmd.stream_name @@ -93,13 +91,13 @@ class ReplicationCommandHandler: logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) raise - if cmd.token is None or stream_name in self.streams_connecting: + if cmd.token is None or stream_name in self._streams_connecting: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token - self.pending_batches.setdefault(stream_name, []).append(row) + self._pending_batches.setdefault(stream_name, []).append(row) else: # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(stream_name, []) + rows = self._pending_batches.pop(stream_name, []) rows.append(row) await self.on_rdata(stream_name, cmd.token, rows) @@ -113,23 +111,23 @@ class ReplicationCommandHandler: Stream.parse_row. """ logger.debug("Received rdata %s -> %s", stream_name, token) - await self.replication_data_handler.on_rdata(stream_name, token, rows) + await self._replication_data_handler.on_rdata(stream_name, token, rows) async def on_POSITION(self, cmd: PositionCommand): - stream = self.streams.get(cmd.stream_name) + stream = self._streams.get(cmd.stream_name) if not stream: logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) return # We're about to go and catch up with the stream, so mark as connecting # to stop RDATA being handled at the same time. - self.streams_connecting.add(cmd.stream_name) + self._streams_connecting.add(cmd.stream_name) # We protect catching up with a linearizer in case the replicaiton # connection reconnects under us. with await self._position_linearizer.queue(cmd.stream_name): # Find where we previously streamed up to. - current_token = self.replication_data_handler.get_streams_to_replicate().get( + current_token = self._replication_data_handler.get_streams_to_replicate().get( cmd.stream_name ) if current_token is None: @@ -153,12 +151,12 @@ class ReplicationCommandHandler: ) # We've now caught up to position sent to us, notify handler. - await self.replication_data_handler.on_position(cmd.stream_name, cmd.token) + await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) - self.streams_connecting.discard(cmd.stream_name) + self._streams_connecting.discard(cmd.stream_name) # Handle any RDATA that came in while we were catching up. - rows = self.pending_batches.pop(cmd.stream_name, []) + rows = self._pending_batches.pop(cmd.stream_name, []) if rows: await self.on_rdata(cmd.stream_name, rows[-1].token, rows) @@ -171,14 +169,14 @@ class ReplicationCommandHandler: def get_currently_syncing_users(self): """Get the list of currently syncing users (if any). This is called when a connection has been established and we need to send the - currently syncing users. (Overriden by the synchrotron's only) + currently syncing users. """ - return self.presence_handler.get_currently_syncing_users() + return self._presence_handler.get_currently_syncing_users() def update_connection(self, connection): """Called when a connection has been established (or lost with None). """ - self.connection = connection + self._connection = connection def finished_connecting(self): """Called when we have successfully subscribed and caught up to all @@ -189,15 +187,15 @@ class ReplicationCommandHandler: # We don't reset the delay any earlier as otherwise if there is a # problem during start up we'll end up tight looping connecting to the # server. - if self.factory: - self.factory.resetDelay() + if self._factory: + self._factory.resetDelay() def send_command(self, cmd: Command): """Send a command to master (when we get establish a connection if we don't have one already.) """ - if self.connection: - self.connection.send_command(cmd) + if self._connection: + self._connection.send_command(cmd) else: logger.warning("Dropping command as not connected: %r", cmd.NAME) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 14be64b3fd..8902a5ab69 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -57,8 +57,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): # We now do some gut wrenching so that we have a client that is based # off of the slave store rather than the main store. self.replication_handler = ReplicationCommandHandler(self.hs) - self.replication_handler.store = self.slaved_store - self.replication_handler.replication_data_handler = ReplicationDataHandler( + self.replication_handler._replication_data_handler = ReplicationDataHandler( self.slaved_store )