Fixup handler

pull/7185/head
Erik Johnston 2020-04-01 17:05:02 +01:00
parent 0d6e7531fd
commit e16225ae28
2 changed files with 27 additions and 30 deletions

View File

@ -13,8 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A replication client for use by synapse workers.
"""
import logging
from typing import Any, Callable, Dict, List, Optional, Set
@ -51,13 +49,13 @@ class ReplicationCommandHandler:
"""
def __init__(self, hs):
self.replication_data_handler = hs.get_replication_data_handler()
self.presence_handler = hs.get_presence_handler()
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
# Set of streams that we're currently catching up with.
self.streams_connecting = set() # type: Set[str]
self._streams_connecting = set() # type: Set[str]
self.streams = {
self._streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
@ -65,23 +63,23 @@ class ReplicationCommandHandler:
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {} # type: Dict[str, List[Any]]
self._pending_batches = {} # type: Dict[str, List[Any]]
# The factory used to create connections.
self.factory = None # type: Optional[ReplicationClientFactory]
self._factory = None # type: Optional[ReplicationClientFactory]
# The current connection. None if we are currently (re)connecting
self.connection = None
self._connection = None
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
client_name = hs.config.worker_name
self.factory = ReplicationClientFactory(hs, client_name, self)
self._factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)
hs.get_reactor().connectTCP(host, port, self._factory)
async def on_RDATA(self, cmd: RdataCommand):
stream_name = cmd.stream_name
@ -93,13 +91,13 @@ class ReplicationCommandHandler:
logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
raise
if cmd.token is None or stream_name in self.streams_connecting:
if cmd.token is None or stream_name in self._streams_connecting:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
self._pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
await self.on_rdata(stream_name, cmd.token, rows)
@ -113,23 +111,23 @@ class ReplicationCommandHandler:
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
await self.replication_data_handler.on_rdata(stream_name, token, rows)
await self._replication_data_handler.on_rdata(stream_name, token, rows)
async def on_POSITION(self, cmd: PositionCommand):
stream = self.streams.get(cmd.stream_name)
stream = self._streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
return
# We're about to go and catch up with the stream, so mark as connecting
# to stop RDATA being handled at the same time.
self.streams_connecting.add(cmd.stream_name)
self._streams_connecting.add(cmd.stream_name)
# We protect catching up with a linearizer in case the replicaiton
# connection reconnects under us.
with await self._position_linearizer.queue(cmd.stream_name):
# Find where we previously streamed up to.
current_token = self.replication_data_handler.get_streams_to_replicate().get(
current_token = self._replication_data_handler.get_streams_to_replicate().get(
cmd.stream_name
)
if current_token is None:
@ -153,12 +151,12 @@ class ReplicationCommandHandler:
)
# We've now caught up to position sent to us, notify handler.
await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
self.streams_connecting.discard(cmd.stream_name)
self._streams_connecting.discard(cmd.stream_name)
# Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, [])
rows = self._pending_batches.pop(cmd.stream_name, [])
if rows:
await self.on_rdata(cmd.stream_name, rows[-1].token, rows)
@ -171,14 +169,14 @@ class ReplicationCommandHandler:
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users. (Overriden by the synchrotron's only)
currently syncing users.
"""
return self.presence_handler.get_currently_syncing_users()
return self._presence_handler.get_currently_syncing_users()
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
self.connection = connection
self._connection = connection
def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
@ -189,15 +187,15 @@ class ReplicationCommandHandler:
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
if self.factory:
self.factory.resetDelay()
if self._factory:
self._factory.resetDelay()
def send_command(self, cmd: Command):
"""Send a command to master (when we get establish a connection if we
don't have one already.)
"""
if self.connection:
self.connection.send_command(cmd)
if self._connection:
self._connection.send_command(cmd)
else:
logger.warning("Dropping command as not connected: %r", cmd.NAME)

View File

@ -57,8 +57,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
# We now do some gut wrenching so that we have a client that is based
# off of the slave store rather than the main store.
self.replication_handler = ReplicationCommandHandler(self.hs)
self.replication_handler.store = self.slaved_store
self.replication_handler.replication_data_handler = ReplicationDataHandler(
self.replication_handler._replication_data_handler = ReplicationDataHandler(
self.slaved_store
)