Merge replication command and client handlers

pull/7185/head
Erik Johnston 2020-03-31 13:13:48 +01:00
parent a0063c9a15
commit 7e2593bc4d
4 changed files with 122 additions and 201 deletions

View File

@ -16,27 +16,12 @@
"""
import logging
from typing import Dict, List, Optional
from typing import Dict
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import (
AbstractReplicationClientHandler,
ClientReplicationStreamProtocol,
)
from .commands import (
Command,
FederationAckCommand,
InvalidateCacheCommand,
RemoteServerUpCommand,
RemovePusherCommand,
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
MYPY = False
if MYPY:
@ -56,9 +41,9 @@ class ReplicationClientFactory(ReconnectingClientFactory):
initialDelay = 0.1
maxDelay = 1 # Try at least once every N seconds
def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
def __init__(self, hs: "HomeServer", client_name, command_handler):
self.client_name = client_name
self.command_handler = ReplicationCommandHandler(hs, handler)
self.command_handler = command_handler
self.server_name = hs.config.server_name
self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
@ -87,169 +72,6 @@ class ReplicationClientFactory(ReconnectingClientFactory):
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
def __init__(self, hs: "HomeServer"):
self.presence_handler = hs.get_presence_handler()
self.data_handler = hs.get_replication_data_handler()
# The current connection. None if we are currently (re)connecting
self.connection = None
# Any pending commands to be sent once a new connection has been
# established
self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiveing a SYNC with
# the given string.
# Used for tests.
self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections.
self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
client_name = hs.config.worker_name
self.factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
await self.data_handler.on_rdata(stream_name, token, rows)
async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.
Can be overriden in subclasses to handle more.
"""
await self.data_handler.on_position(stream_name, token)
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
for the sync with the given data.
Used by tests.
"""
d = self.awaiting_syncs.pop(data, None)
if d:
d.callback(data)
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
return self.data_handler.get_streams_to_replicate()
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users. (Overriden by the synchrotron's only)
"""
return self.presence_handler.get_currently_syncing_users()
def send_command(self, cmd):
"""Send a command to master (when we get establish a connection if we
don't have one already.)
"""
if self.connection:
self.connection.send_command(cmd)
else:
logger.warning("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd)
def send_federation_ack(self, token):
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
self.send_command(FederationAckCommand(token))
def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
"""Poke the master that a user has started/stopped syncing.
"""
self.send_command(
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
)
def send_remove_pusher(self, app_id, push_key, user_id):
"""Poke the master to remove a pusher for a user
"""
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func, keys):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
"""Tell the master that the user made a request.
"""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def await_sync(self, data):
"""Returns a deferred that is resolved when we receive a SYNC command
with given data.
[Not currently] used by tests.
"""
return self.awaiting_syncs.setdefault(data, defer.Deferred())
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
self.connection = connection
if connection:
for cmd in self.pending_commands:
connection.send_command(cmd)
self.pending_commands = []
def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
logger.info("Finished connecting to server")
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
if self.factory:
self.factory.resetDelay()
class ReplicationDataHandler:
"""A replication data handler that calls slave data stores.
"""

View File

@ -17,15 +17,22 @@
"""
import logging
from typing import Any, Dict, List, Set
from typing import Any, Callable, Dict, List, Optional, Set
from prometheus_client import Counter
from synapse.replication.tcp.client import ReplicationClientFactory
from synapse.replication.tcp.commands import (
Command,
FederationAckCommand,
InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
RemovePusherCommand,
SyncCommand,
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@ -42,8 +49,9 @@ class ReplicationCommandHandler:
"""Handles incoming commands from replication.
"""
def __init__(self, hs, handler):
self.handler = handler
def __init__(self, hs):
self.replication_data_handler = hs.get_replication_data_handler()
self.presence_handler = hs.get_presence_handler()
# Set of streams that we're currently catching up with.
self.streams_connecting = set() # type: Set[str]
@ -56,6 +64,22 @@ class ReplicationCommandHandler:
# batching works.
self.pending_batches = {} # type: Dict[str, List[Any]]
# The factory used to create connections.
self.factory = None # type: Optional[ReplicationClientFactory]
# The current connection. None if we are currently (re)connecting
self.connection = None
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
client_name = hs.config.worker_name
self.factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)
async def on_RDATA(self, cmd: RdataCommand):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@ -74,7 +98,19 @@ class ReplicationCommandHandler:
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
await self.on_rdata(stream_name, cmd.token, rows)
async def on_rdata(self, stream_name: str, token: int, rows: list):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name: name of the replication stream for this batch of rows
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
await self.replication_data_handler.on_rdata(stream_name, token, rows)
async def on_POSITION(self, cmd: PositionCommand):
stream = self.streams.get(cmd.stream_name)
@ -87,7 +123,9 @@ class ReplicationCommandHandler:
self.streams_connecting.add(cmd.stream_name)
# Find where we previously streamed up to.
current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
current_token = self.replication_data_handler.get_streams_to_replicate().get(
cmd.stream_name
)
if current_token is None:
logger.warning(
"Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
@ -101,39 +139,101 @@ class ReplicationCommandHandler:
current_token, cmd.token
)
if updates:
await self.handler.on_rdata(
await self.on_rdata(
cmd.stream_name,
current_token,
[stream.parse_row(update[1]) for update in updates],
)
# We've now caught up to position sent to us, notify handler.
await self.handler.on_position(cmd.stream_name, cmd.token)
await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
self.streams_connecting.discard(cmd.stream_name)
# Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, [])
if rows:
await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
await self.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_SYNC(self, cmd: SyncCommand):
self.handler.on_sync(cmd.data)
pass
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data)
""""Called when get a new REMOTE_SERVER_UP command."""
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users. (Overriden by the synchrotron's only)
"""
return self.handler.get_currently_syncing_users()
return self.presence_handler.get_currently_syncing_users()
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
return self.handler.update_connection(connection)
self.connection = connection
def finished_connecting(self):
return self.handler.finished_connecting()
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
logger.info("Finished connecting to server")
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
if self.factory:
self.factory.resetDelay()
def send_command(self, cmd: Command):
"""Send a command to master (when we get establish a connection if we
don't have one already.)
"""
if self.connection:
self.connection.send_command(cmd)
else:
logger.warning("Dropping command as not connected: %r", cmd.NAME)
def send_federation_ack(self, token: int):
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
self.send_command(FederationAckCommand(token))
def send_user_sync(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
):
"""Poke the master that a user has started/stopped syncing.
"""
self.send_command(
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
)
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
"""Poke the master to remove a pusher for a user
"""
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip(
self,
user_id: str,
access_token: str,
ip: str,
user_agent: str,
device_id: str,
last_seen: int,
):
"""Tell the master that the user made a request.
"""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))

View File

@ -87,10 +87,8 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import (
ReplicationClientHandler,
ReplicationDataHandler,
)
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.rest.media.v1.media_repository import (
MediaRepository,
@ -473,7 +471,7 @@ class HomeServer(object):
return ReadMarkerHandler(self)
def build_tcp_replication(self):
return ReplicationClientHandler(self)
return ReplicationCommandHandler(self)
def build_action_generator(self):
return ActionGenerator(self)

View File

@ -19,6 +19,7 @@ import synapse.handlers.set_password
import synapse.http.client
import synapse.notifier
import synapse.replication.tcp.client
import synapse.replication.tcp.handler
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
@ -106,7 +107,7 @@ class HomeServer(object):
pass
def get_tcp_replication(
self,
) -> synapse.replication.tcp.client.ReplicationClientHandler:
) -> synapse.replication.tcp.handler.ReplicationCommandHandler:
pass
def get_replication_data_handler(
self,