Merge replication command and client handlers
parent
a0063c9a15
commit
7e2593bc4d
|
@ -16,27 +16,12 @@
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.replication.tcp.protocol import (
|
||||
AbstractReplicationClientHandler,
|
||||
ClientReplicationStreamProtocol,
|
||||
)
|
||||
|
||||
from .commands import (
|
||||
Command,
|
||||
FederationAckCommand,
|
||||
InvalidateCacheCommand,
|
||||
RemoteServerUpCommand,
|
||||
RemovePusherCommand,
|
||||
UserIpCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
|
||||
MYPY = False
|
||||
if MYPY:
|
||||
|
@ -56,9 +41,9 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
|||
initialDelay = 0.1
|
||||
maxDelay = 1 # Try at least once every N seconds
|
||||
|
||||
def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
|
||||
def __init__(self, hs: "HomeServer", client_name, command_handler):
|
||||
self.client_name = client_name
|
||||
self.command_handler = ReplicationCommandHandler(hs, handler)
|
||||
self.command_handler = command_handler
|
||||
self.server_name = hs.config.server_name
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock() # As self.clock is defined in super class
|
||||
|
@ -87,169 +72,6 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
|||
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
|
||||
|
||||
|
||||
class ReplicationClientHandler(AbstractReplicationClientHandler):
|
||||
"""A base handler that can be passed to the ReplicationClientFactory.
|
||||
|
||||
By default proxies incoming replication data to the SlaveStore.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.data_handler = hs.get_replication_data_handler()
|
||||
|
||||
# The current connection. None if we are currently (re)connecting
|
||||
self.connection = None
|
||||
|
||||
# Any pending commands to be sent once a new connection has been
|
||||
# established
|
||||
self.pending_commands = [] # type: List[Command]
|
||||
|
||||
# Map from string -> deferred, to wake up when receiveing a SYNC with
|
||||
# the given string.
|
||||
# Used for tests.
|
||||
self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
|
||||
|
||||
# The factory used to create connections.
|
||||
self.factory = None # type: Optional[ReplicationClientFactory]
|
||||
|
||||
def start_replication(self, hs):
|
||||
"""Helper method to start a replication connection to the remote server
|
||||
using TCP.
|
||||
"""
|
||||
client_name = hs.config.worker_name
|
||||
self.factory = ReplicationClientFactory(hs, client_name, self)
|
||||
host = hs.config.worker_replication_host
|
||||
port = hs.config.worker_replication_port
|
||||
hs.get_reactor().connectTCP(host, port, self.factory)
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||
handle more.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
await self.data_handler.on_rdata(stream_name, token, rows)
|
||||
|
||||
async def on_position(self, stream_name, token):
|
||||
"""Called when we get new position data. By default this just pokes
|
||||
the slave store.
|
||||
|
||||
Can be overriden in subclasses to handle more.
|
||||
"""
|
||||
await self.data_handler.on_position(stream_name, token)
|
||||
|
||||
def on_sync(self, data):
|
||||
"""When we received a SYNC we wake up any deferreds that were waiting
|
||||
for the sync with the given data.
|
||||
|
||||
Used by tests.
|
||||
"""
|
||||
d = self.awaiting_syncs.pop(data, None)
|
||||
if d:
|
||||
d.callback(data)
|
||||
|
||||
def on_remote_server_up(self, server: str):
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
|
||||
def get_streams_to_replicate(self) -> Dict[str, int]:
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
|
||||
return self.data_handler.get_streams_to_replicate()
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users. (Overriden by the synchrotron's only)
|
||||
"""
|
||||
return self.presence_handler.get_currently_syncing_users()
|
||||
|
||||
def send_command(self, cmd):
|
||||
"""Send a command to master (when we get establish a connection if we
|
||||
don't have one already.)
|
||||
"""
|
||||
if self.connection:
|
||||
self.connection.send_command(cmd)
|
||||
else:
|
||||
logger.warning("Queuing command as not connected: %r", cmd.NAME)
|
||||
self.pending_commands.append(cmd)
|
||||
|
||||
def send_federation_ack(self, token):
|
||||
"""Ack data for the federation stream. This allows the master to drop
|
||||
data stored purely in memory.
|
||||
"""
|
||||
self.send_command(FederationAckCommand(token))
|
||||
|
||||
def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
|
||||
"""Poke the master that a user has started/stopped syncing.
|
||||
"""
|
||||
self.send_command(
|
||||
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
|
||||
)
|
||||
|
||||
def send_remove_pusher(self, app_id, push_key, user_id):
|
||||
"""Poke the master to remove a pusher for a user
|
||||
"""
|
||||
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_invalidate_cache(self, cache_func, keys):
|
||||
"""Poke the master to invalidate a cache.
|
||||
"""
|
||||
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
||||
"""Tell the master that the user made a request.
|
||||
"""
|
||||
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
self.send_command(RemoteServerUpCommand(server))
|
||||
|
||||
def await_sync(self, data):
|
||||
"""Returns a deferred that is resolved when we receive a SYNC command
|
||||
with given data.
|
||||
|
||||
[Not currently] used by tests.
|
||||
"""
|
||||
return self.awaiting_syncs.setdefault(data, defer.Deferred())
|
||||
|
||||
def update_connection(self, connection):
|
||||
"""Called when a connection has been established (or lost with None).
|
||||
"""
|
||||
self.connection = connection
|
||||
if connection:
|
||||
for cmd in self.pending_commands:
|
||||
connection.send_command(cmd)
|
||||
self.pending_commands = []
|
||||
|
||||
def finished_connecting(self):
|
||||
"""Called when we have successfully subscribed and caught up to all
|
||||
streams we're interested in.
|
||||
"""
|
||||
logger.info("Finished connecting to server")
|
||||
|
||||
# We don't reset the delay any earlier as otherwise if there is a
|
||||
# problem during start up we'll end up tight looping connecting to the
|
||||
# server.
|
||||
if self.factory:
|
||||
self.factory.resetDelay()
|
||||
|
||||
|
||||
class ReplicationDataHandler:
|
||||
"""A replication data handler that calls slave data stores.
|
||||
"""
|
||||
|
|
|
@ -17,15 +17,22 @@
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Set
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse.replication.tcp.client import ReplicationClientFactory
|
||||
from synapse.replication.tcp.commands import (
|
||||
Command,
|
||||
FederationAckCommand,
|
||||
InvalidateCacheCommand,
|
||||
PositionCommand,
|
||||
RdataCommand,
|
||||
RemoteServerUpCommand,
|
||||
RemovePusherCommand,
|
||||
SyncCommand,
|
||||
UserIpCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
|
||||
|
||||
|
@ -42,8 +49,9 @@ class ReplicationCommandHandler:
|
|||
"""Handles incoming commands from replication.
|
||||
"""
|
||||
|
||||
def __init__(self, hs, handler):
|
||||
self.handler = handler
|
||||
def __init__(self, hs):
|
||||
self.replication_data_handler = hs.get_replication_data_handler()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
|
||||
# Set of streams that we're currently catching up with.
|
||||
self.streams_connecting = set() # type: Set[str]
|
||||
|
@ -56,6 +64,22 @@ class ReplicationCommandHandler:
|
|||
# batching works.
|
||||
self.pending_batches = {} # type: Dict[str, List[Any]]
|
||||
|
||||
# The factory used to create connections.
|
||||
self.factory = None # type: Optional[ReplicationClientFactory]
|
||||
|
||||
# The current connection. None if we are currently (re)connecting
|
||||
self.connection = None
|
||||
|
||||
def start_replication(self, hs):
|
||||
"""Helper method to start a replication connection to the remote server
|
||||
using TCP.
|
||||
"""
|
||||
client_name = hs.config.worker_name
|
||||
self.factory = ReplicationClientFactory(hs, client_name, self)
|
||||
host = hs.config.worker_replication_host
|
||||
port = hs.config.worker_replication_port
|
||||
hs.get_reactor().connectTCP(host, port, self.factory)
|
||||
|
||||
async def on_RDATA(self, cmd: RdataCommand):
|
||||
stream_name = cmd.stream_name
|
||||
inbound_rdata_count.labels(stream_name).inc()
|
||||
|
@ -74,7 +98,19 @@ class ReplicationCommandHandler:
|
|||
# Check if this is the last of a batch of updates
|
||||
rows = self.pending_batches.pop(stream_name, [])
|
||||
rows.append(row)
|
||||
await self.handler.on_rdata(stream_name, cmd.token, rows)
|
||||
await self.on_rdata(stream_name, cmd.token, rows)
|
||||
|
||||
async def on_rdata(self, stream_name: str, token: int, rows: list):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Args:
|
||||
stream_name: name of the replication stream for this batch of rows
|
||||
token: stream token for this batch of rows
|
||||
rows: a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
await self.replication_data_handler.on_rdata(stream_name, token, rows)
|
||||
|
||||
async def on_POSITION(self, cmd: PositionCommand):
|
||||
stream = self.streams.get(cmd.stream_name)
|
||||
|
@ -87,7 +123,9 @@ class ReplicationCommandHandler:
|
|||
self.streams_connecting.add(cmd.stream_name)
|
||||
|
||||
# Find where we previously streamed up to.
|
||||
current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
|
||||
current_token = self.replication_data_handler.get_streams_to_replicate().get(
|
||||
cmd.stream_name
|
||||
)
|
||||
if current_token is None:
|
||||
logger.warning(
|
||||
"Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
|
||||
|
@ -101,39 +139,101 @@ class ReplicationCommandHandler:
|
|||
current_token, cmd.token
|
||||
)
|
||||
if updates:
|
||||
await self.handler.on_rdata(
|
||||
await self.on_rdata(
|
||||
cmd.stream_name,
|
||||
current_token,
|
||||
[stream.parse_row(update[1]) for update in updates],
|
||||
)
|
||||
|
||||
# We've now caught up to position sent to us, notify handler.
|
||||
await self.handler.on_position(cmd.stream_name, cmd.token)
|
||||
await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
|
||||
|
||||
self.streams_connecting.discard(cmd.stream_name)
|
||||
|
||||
# Handle any RDATA that came in while we were catching up.
|
||||
rows = self.pending_batches.pop(cmd.stream_name, [])
|
||||
if rows:
|
||||
await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
|
||||
await self.on_rdata(cmd.stream_name, rows[-1].token, rows)
|
||||
|
||||
async def on_SYNC(self, cmd: SyncCommand):
|
||||
self.handler.on_sync(cmd.data)
|
||||
pass
|
||||
|
||||
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
||||
self.handler.on_remote_server_up(cmd.data)
|
||||
""""Called when get a new REMOTE_SERVER_UP command."""
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users. (Overriden by the synchrotron's only)
|
||||
"""
|
||||
return self.handler.get_currently_syncing_users()
|
||||
return self.presence_handler.get_currently_syncing_users()
|
||||
|
||||
def update_connection(self, connection):
|
||||
"""Called when a connection has been established (or lost with None).
|
||||
"""
|
||||
return self.handler.update_connection(connection)
|
||||
self.connection = connection
|
||||
|
||||
def finished_connecting(self):
|
||||
return self.handler.finished_connecting()
|
||||
"""Called when we have successfully subscribed and caught up to all
|
||||
streams we're interested in.
|
||||
"""
|
||||
logger.info("Finished connecting to server")
|
||||
|
||||
# We don't reset the delay any earlier as otherwise if there is a
|
||||
# problem during start up we'll end up tight looping connecting to the
|
||||
# server.
|
||||
if self.factory:
|
||||
self.factory.resetDelay()
|
||||
|
||||
def send_command(self, cmd: Command):
|
||||
"""Send a command to master (when we get establish a connection if we
|
||||
don't have one already.)
|
||||
"""
|
||||
if self.connection:
|
||||
self.connection.send_command(cmd)
|
||||
else:
|
||||
logger.warning("Dropping command as not connected: %r", cmd.NAME)
|
||||
|
||||
def send_federation_ack(self, token: int):
|
||||
"""Ack data for the federation stream. This allows the master to drop
|
||||
data stored purely in memory.
|
||||
"""
|
||||
self.send_command(FederationAckCommand(token))
|
||||
|
||||
def send_user_sync(
|
||||
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
|
||||
):
|
||||
"""Poke the master that a user has started/stopped syncing.
|
||||
"""
|
||||
self.send_command(
|
||||
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
|
||||
)
|
||||
|
||||
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
|
||||
"""Poke the master to remove a pusher for a user
|
||||
"""
|
||||
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
|
||||
"""Poke the master to invalidate a cache.
|
||||
"""
|
||||
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_user_ip(
|
||||
self,
|
||||
user_id: str,
|
||||
access_token: str,
|
||||
ip: str,
|
||||
user_agent: str,
|
||||
device_id: str,
|
||||
last_seen: int,
|
||||
):
|
||||
"""Tell the master that the user made a request.
|
||||
"""
|
||||
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
self.send_command(RemoteServerUpCommand(server))
|
||||
|
|
|
@ -87,10 +87,8 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
|||
from synapse.notifier import Notifier
|
||||
from synapse.push.action_generator import ActionGenerator
|
||||
from synapse.push.pusherpool import PusherPool
|
||||
from synapse.replication.tcp.client import (
|
||||
ReplicationClientHandler,
|
||||
ReplicationDataHandler,
|
||||
)
|
||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.replication.tcp.resource import ReplicationStreamer
|
||||
from synapse.rest.media.v1.media_repository import (
|
||||
MediaRepository,
|
||||
|
@ -473,7 +471,7 @@ class HomeServer(object):
|
|||
return ReadMarkerHandler(self)
|
||||
|
||||
def build_tcp_replication(self):
|
||||
return ReplicationClientHandler(self)
|
||||
return ReplicationCommandHandler(self)
|
||||
|
||||
def build_action_generator(self):
|
||||
return ActionGenerator(self)
|
||||
|
|
|
@ -19,6 +19,7 @@ import synapse.handlers.set_password
|
|||
import synapse.http.client
|
||||
import synapse.notifier
|
||||
import synapse.replication.tcp.client
|
||||
import synapse.replication.tcp.handler
|
||||
import synapse.rest.media.v1.media_repository
|
||||
import synapse.server_notices.server_notices_manager
|
||||
import synapse.server_notices.server_notices_sender
|
||||
|
@ -106,7 +107,7 @@ class HomeServer(object):
|
|||
pass
|
||||
def get_tcp_replication(
|
||||
self,
|
||||
) -> synapse.replication.tcp.client.ReplicationClientHandler:
|
||||
) -> synapse.replication.tcp.handler.ReplicationCommandHandler:
|
||||
pass
|
||||
def get_replication_data_handler(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue