Merge server command processing into ReplicationCommandHandler

pull/7187/head
Erik Johnston 2020-03-31 14:25:09 +01:00
parent e6c25e0858
commit cadb3f57dd
3 changed files with 153 additions and 238 deletions

View File

@ -19,8 +19,10 @@ from typing import Any, Callable, Dict, List, Optional, Set
from prometheus_client import Counter from prometheus_client import Counter
from synapse.metrics import LaterGauge
from synapse.replication.tcp.client import ReplicationClientFactory from synapse.replication.tcp.client import ReplicationClientFactory
from synapse.replication.tcp.commands import ( from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command, Command,
FederationAckCommand, FederationAckCommand,
InvalidateCacheCommand, InvalidateCacheCommand,
@ -28,6 +30,7 @@ from synapse.replication.tcp.commands import (
RdataCommand, RdataCommand,
RemoteServerUpCommand, RemoteServerUpCommand,
RemovePusherCommand, RemovePusherCommand,
ReplicateCommand,
SyncCommand, SyncCommand,
UserIpCommand, UserIpCommand,
UserSyncCommand, UserSyncCommand,
@ -42,6 +45,13 @@ logger = logging.getLogger(__name__)
inbound_rdata_count = Counter( inbound_rdata_count = Counter(
"synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
) )
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
invalidate_cache_counter = Counter(
"synapse_replication_tcp_resource_invalidate_cache", ""
)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
class ReplicationCommandHandler: class ReplicationCommandHandler:
@ -52,6 +62,8 @@ class ReplicationCommandHandler:
def __init__(self, hs): def __init__(self, hs):
self._replication_data_handler = hs.get_replication_data_handler() self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler() self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
self._notifier = hs.get_notifier()
# Set of streams that we've caught up with. # Set of streams that we've caught up with.
self._streams_connected = set() # type: Set[str] self._streams_connected = set() # type: Set[str]
@ -70,7 +82,25 @@ class ReplicationCommandHandler:
self._factory = None # type: Optional[ReplicationClientFactory] self._factory = None # type: Optional[ReplicationClientFactory]
# The currently connected connections. # The currently connected connections.
self._connections = [] self._connections = [] # type: List[Any]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self._connections),
)
self._is_master = hs.config.worker_app is None
self._federation_sender = None
if self._is_master and not hs.config.send_federation:
self._federation_sender = hs.get_federation_sender()
self._server_notices_sender = None
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
def start_replication(self, hs): def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server """Helper method to start a replication connection to the remote server
@ -82,6 +112,73 @@ class ReplicationCommandHandler:
port = hs.config.worker_replication_port port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory) hs.get_reactor().connectTCP(host, port, self._factory)
async def on_REPLICATE(self, cmd: ReplicateCommand):
    """Answer a REPLICATE command by sending a POSITION for every stream.

    Does nothing unless this process is the master. For each entry in
    ``self._streams`` a ``PositionCommand`` carrying the stream's current
    token is passed to ``self.send_command``.

    Args:
        cmd: the received REPLICATE command (its contents are not read).

    Raises:
        Exception: if this is the master and there are no connections.
    """
    # Positions are only announced by the writer of the streams, which is
    # currently just the master process.
    if self._is_master:
        if not self._connections:
            raise Exception("Not connected")

        for name, stream in self._streams.items():
            token = stream.current_token()
            self.send_command(PositionCommand(name, token))
async def on_USER_SYNC(self, cmd: UserSyncCommand):
    """Count a USER_SYNC command and, on the master, record it with presence.

    Args:
        cmd: the received command; its ``instance_id``, ``user_id``,
            ``is_syncing`` and ``last_sync_ms`` fields are forwarded.
    """
    user_sync_counter.inc()

    # Only the master forwards the row to the presence handler.
    if not self._is_master:
        return

    await self._presence_handler.update_external_syncs_row(
        cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
    )
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
    """On the master, ask presence to clear the syncs for ``cmd.instance_id``.

    Non-master processes ignore the command.
    """
    if not self._is_master:
        return

    instance_id = cmd.instance_id
    await self._presence_handler.update_external_syncs_clear(instance_id)
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
    """Count a FEDERATION_ACK and pass its token to the federation sender.

    The token is only forwarded when ``self._federation_sender`` is set;
    the counter is incremented regardless.
    """
    federation_ack_counter.inc()

    sender = self._federation_sender
    if sender:
        sender.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
    """Count a REMOVE_PUSHER command and, on the master, delete the pusher.

    Args:
        cmd: the received command; ``app_id``, ``push_key`` and ``user_id``
            identify the pusher to delete.
    """
    remove_pusher_counter.inc()

    if not self._is_master:
        return

    await self._store.delete_pusher_by_app_id_pushkey_user_id(
        app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
    )

    # NOTE(review): indentation was lost in this view of the diff; the poke
    # below is assumed to belong to the master-only branch -- confirm
    # against the real file.
    self._notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
    """Count an INVALIDATE_CACHE command and, on the master, act on it.

    Args:
        cmd: the received command; ``cache_func`` names the cache and
            ``keys`` is converted to a tuple before being passed on.
    """
    invalidate_cache_counter.inc()

    if not self._is_master:
        return

    # The cache is invalidated locally, and the invalidation is then also
    # streamed to other workers.
    cache_func = cmd.cache_func
    key_tuple = tuple(cmd.keys)
    await self._store.invalidate_cache_and_stream(cache_func, key_tuple)
async def on_USER_IP(self, cmd: UserIpCommand):
    """Count a USER_IP command, store the client IP and poke server notices.

    On the master the command's fields are written via
    ``self._store.insert_client_ip``. If a server notices sender is
    configured it is then told about ``cmd.user_id``.
    """
    user_ip_cache_counter.inc()

    if self._is_master:
        client_ip_row = (
            cmd.user_id,
            cmd.access_token,
            cmd.ip,
            cmd.user_agent,
            cmd.device_id,
            cmd.last_seen,
        )
        await self._store.insert_client_ip(*client_ip_row)

    notices_sender = self._server_notices_sender
    if notices_sender:
        await notices_sender.on_user_ip(cmd.user_id)
async def on_RDATA(self, cmd: RdataCommand): async def on_RDATA(self, cmd: RdataCommand):
stream_name = cmd.stream_name stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc() inbound_rdata_count.labels(stream_name).inc()
@ -174,6 +271,9 @@ class ReplicationCommandHandler:
""""Called when get a new REMOTE_SERVER_UP command.""" """"Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data) self._replication_data_handler.on_remote_server_up(cmd.data)
if self._is_master:
self._notifier.notify_remote_server_up(cmd.data)
def get_currently_syncing_users(self): def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called """Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the when a connection has been established and we need to send the
@ -261,3 +361,10 @@ class ReplicationCommandHandler:
def send_remote_server_up(self, server: str): def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server)) self.send_command(RemoteServerUpCommand(server))
def stream_update(self, stream_name: str, token: str, data: Any):
    """Send one stream update out as an RDATA command.

    Wraps the arguments in an ``RdataCommand`` and hands it to
    ``self.send_command``; no per-client filtering happens here.

    Args:
        stream_name: name of the stream the update belongs to.
        token: token to attach to the update.
        data: the row payload.
    """
    rdata = RdataCommand(stream_name, token, data)
    self.send_command(rdata)

View File

@ -69,12 +69,8 @@ from synapse.replication.tcp.commands import (
ErrorCommand, ErrorCommand,
NameCommand, NameCommand,
PingCommand, PingCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
ReplicateCommand, ReplicateCommand,
ServerCommand, ServerCommand,
SyncCommand,
UserSyncCommand, UserSyncCommand,
) )
from synapse.types import Collection from synapse.types import Collection
@ -134,8 +130,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
max_line_buffer = 10000 max_line_buffer = 10000
def __init__(self, clock): def __init__(self, clock, handler):
self.clock = clock self.clock = clock
self.handler = handler
self.last_received_command = self.clock.time_msec() self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0 self.last_sent_command = 0
@ -175,6 +172,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# can time us out. # can time us out.
self.send_command(PingCommand(self.clock.time_msec())) self.send_command(PingCommand(self.clock.time_msec()))
self.handler.new_connection(self)
def send_ping(self): def send_ping(self):
"""Periodically sends a ping and checks if we should close the connection """Periodically sends a ping and checks if we should close the connection
due to the other side timing out. due to the other side timing out.
@ -248,8 +247,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
Args: Args:
cmd: received command cmd: received command
""" """
handler = getattr(self, "on_%s" % (cmd.NAME,)) handled = False
await handler(cmd)
# First call any command handlers on this instance. These are for TCP
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
if not handled:
logger.warning("Unhandled command: %r", cmd)
def close(self): def close(self):
logger.warning("[%s] Closing connection", self.id()) logger.warning("[%s] Closing connection", self.id())
@ -378,6 +392,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.state = ConnectionStates.CLOSED self.state = ConnectionStates.CLOSED
self.pending_commands = [] self.pending_commands = []
self.handler.lost_connection(self)
if self.transport: if self.transport:
self.transport.unregisterProducer() self.transport.unregisterProducer()
@ -404,74 +420,19 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
def __init__(self, server_name, clock, streamer): def __init__(self, server_name, clock, handler):
BaseReplicationStreamProtocol.__init__(self, clock) # Old style class BaseReplicationStreamProtocol.__init__(self, clock, handler) # Old style class
self.server_name = server_name self.server_name = server_name
self.streamer = streamer
def connectionMade(self): def connectionMade(self):
self.send_command(ServerCommand(self.server_name)) self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self) BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)
async def on_NAME(self, cmd): async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data) logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data self.name = cmd.data
async def on_USER_SYNC(self, cmd):
    """Forward the fields of a USER_SYNC command to the streamer."""
    streamer = self.streamer
    await streamer.on_user_sync(
        cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
    )
async def on_CLEAR_USER_SYNC(self, cmd):
    """Ask the streamer to clear the user syncs for ``cmd.instance_id``."""
    instance_id = cmd.instance_id
    await self.streamer.on_clear_user_syncs(instance_id)
async def on_REPLICATE(self, cmd):
    """Send a POSITION command for every stream the streamer knows by name.

    Args:
        cmd: the received REPLICATE command (its contents are not read).
    """
    streamer = self.streamer
    # Subscribe to all streams we're publishing to.
    for name in streamer.streams_by_name:
        token = streamer.get_stream_token(name)
        self.send_command(PositionCommand(name, token))
async def on_FEDERATION_ACK(self, cmd):
    """Pass the token of a FEDERATION_ACK command to the streamer."""
    acked_token = cmd.token
    self.streamer.federation_ack(acked_token)
async def on_REMOVE_PUSHER(self, cmd):
    """Forward a REMOVE_PUSHER command's identifying fields to the streamer."""
    streamer = self.streamer
    await streamer.on_remove_pusher(
        cmd.app_id,
        cmd.push_key,
        cmd.user_id,
    )
async def on_INVALIDATE_CACHE(self, cmd):
    """Forward the cache name and keys of an INVALIDATE_CACHE command."""
    streamer = self.streamer
    await streamer.on_invalidate_cache(
        cmd.cache_func,
        cmd.keys,
    )
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
    """Tell the streamer about the server carried in ``cmd.data``."""
    server = cmd.data
    self.streamer.on_remote_server_up(server)
async def on_USER_IP(self, cmd):
    """Forward the fields of a USER_IP command to the streamer.

    ``streamer.on_user_ip`` is a coroutine function, so it has to be
    awaited: calling it without ``await`` only creates a coroutine object
    that never runs, and the client IP would silently not be recorded.

    Args:
        cmd: the received command; ``user_id``, ``access_token``, ``ip``,
            ``user_agent``, ``device_id`` and ``last_seen`` are forwarded
            in that order.
    """
    await self.streamer.on_user_ip(
        cmd.user_id,
        cmd.access_token,
        cmd.ip,
        cmd.user_agent,
        cmd.device_id,
        cmd.last_seen,
    )
def stream_update(self, stream_name, token, data):
    """Send one stream update to this connection as an RDATA command.

    The arguments are wrapped in an ``RdataCommand`` and passed to
    ``self.send_command``; no interest check is performed here.
    """
    rdata = RdataCommand(stream_name, token, data)
    self.send_command(rdata)
def send_sync(self, data):
    """Send a SYNC command carrying ``data`` over this connection."""
    sync_cmd = SyncCommand(data)
    self.send_command(sync_cmd)
def send_remote_server_up(self, server: str):
    """Send a REMOTE_SERVER_UP command naming ``server`` over this connection."""
    up_cmd = RemoteServerUpCommand(server)
    self.send_command(up_cmd)
def on_connection_closed(self):
    """Run the base-class close handling, then deregister from the streamer."""
    # Base-class handling runs first; only afterwards is the streamer told
    # that this connection has gone away.
    BaseReplicationStreamProtocol.on_connection_closed(self)

    streamer = self.streamer
    streamer.lost_connection(self)
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
@ -485,13 +446,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
clock: Clock, clock: Clock,
command_handler: "ReplicationCommandHandler", command_handler: "ReplicationCommandHandler",
): ):
BaseReplicationStreamProtocol.__init__(self, clock) BaseReplicationStreamProtocol.__init__(self, clock, command_handler)
self.instance_id = hs.get_instance_id() self.instance_id = hs.get_instance_id()
self.client_name = client_name self.client_name = client_name
self.server_name = server_name self.server_name = server_name
self.handler = command_handler
def connectionMade(self): def connectionMade(self):
self.send_command(NameCommand(self.client_name)) self.send_command(NameCommand(self.client_name))
@ -507,36 +467,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
for user_id in currently_syncing: for user_id in currently_syncing:
self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
# We've now finished connecting, so inform the client handler
self.handler.new_connection(self)
async def handle_command(self, cmd: Command):
    """Dispatch a command received over the replication stream.

    Looks up ``on_<NAME>`` first on this protocol instance (TCP specific
    handling) and then on ``self.handler``. Every handler found is awaited,
    in that order. A warning is logged when neither object has one.

    Args:
        cmd: received command
    """
    method_name = "on_%s" % (cmd.NAME,)

    # Handlers defined on the protocol itself run first.
    own_func = getattr(self, method_name, None)
    if own_func:
        await own_func(cmd)

    # The external command handler is consulted afterwards.
    delegate_func = getattr(self.handler, method_name, None)
    if delegate_func:
        await delegate_func(cmd)

    if not (own_func or delegate_func):
        logger.warning("Unhandled command: %r", cmd)
async def on_SERVER(self, cmd): async def on_SERVER(self, cmd):
if cmd.data != self.server_name: if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@ -549,10 +479,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ReplicateCommand()) self.send_command(ReplicateCommand())
def on_connection_closed(self):
    """Run the base-class close handling, then deregister from the handler."""
    # Base-class handling runs first; only afterwards is the command
    # handler told that this connection has gone away.
    BaseReplicationStreamProtocol.on_connection_closed(self)

    command_handler = self.handler
    command_handler.lost_connection(self)
# The following simply registers metrics for the replication connections # The following simply registers metrics for the replication connections

View File

@ -17,7 +17,7 @@
import logging import logging
import random import random
from typing import Any, Dict, List from typing import Dict
from six import itervalues from six import itervalues
@ -25,24 +25,14 @@ from prometheus_client import Counter
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
from .protocol import ServerReplicationStreamProtocol from synapse.util.metrics import Measure
from .streams import STREAMS_MAP, Stream
from .streams.federation import FederationStream
stream_updates_counter = Counter( stream_updates_counter = Counter(
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
) )
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
invalidate_cache_counter = Counter(
"synapse_replication_tcp_resource_invalidate_cache", ""
)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,13 +42,18 @@ class ReplicationStreamProtocolFactory(Factory):
""" """
def __init__(self, hs): def __init__(self, hs):
self.streamer = hs.get_replication_streamer() self.handler = hs.get_tcp_replication()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
self.hs = hs
# Ensure the replication streamer is started if we register a
# replication server endpoint.
hs.get_replication_streamer()
def buildProtocol(self, addr): def buildProtocol(self, addr):
return ServerReplicationStreamProtocol( return ServerReplicationStreamProtocol(
self.server_name, self.clock, self.streamer self.server_name, self.clock, self.handler
) )
@ -78,16 +73,6 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
self.connections = [] # type: List[ServerReplicationStreamProtocol]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self.connections),
)
# List of streams that clients can subscribe to. # List of streams that clients can subscribe to.
# We only support federation stream if federation sending has been # We only support federation stream if federation sending has been
# disabled on the master. # disabled on the master.
@ -104,18 +89,12 @@ class ReplicationStreamer(object):
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke) self.notifier.add_replication_callback(self.on_notifier_poke)
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates # Keeps track of whether we are currently checking for updates
self.is_looping = False self.is_looping = False
self.pending_updates = False self.pending_updates = False
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) self.client = hs.get_tcp_replication()
def on_shutdown(self):
    """Send a "server shutting down" error to every tracked connection."""
    shutdown_reason = "server shutting down"
    # close all connections on shutdown
    for connection in self.connections:
        connection.send_error(shutdown_reason)
def get_streams(self) -> Dict[str, Stream]: def get_streams(self) -> Dict[str, Stream]:
"""Get a mapp from stream name to stream instance. """Get a mapp from stream name to stream instance.
@ -129,7 +108,7 @@ class ReplicationStreamer(object):
This should get called each time new data is available, even if it This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed is currently being executed, so that nothing gets missed
""" """
if not self.connections: if not self.client.connected():
# Don't bother if nothing is listening. We still need to advance # Don't bother if nothing is listening. We still need to advance
# the stream tokens otherwise they'll fall behind forever # the stream tokens otherwise they'll fall behind forever
for stream in self.streams: for stream in self.streams:
@ -186,9 +165,7 @@ class ReplicationStreamer(object):
raise raise
logger.debug( logger.debug(
"Sending %d updates to %d connections", "Sending %d updates", len(updates),
len(updates),
len(self.connections),
) )
if updates: if updates:
@ -204,112 +181,17 @@ class ReplicationStreamer(object):
# token. See RdataCommand for more details. # token. See RdataCommand for more details.
batched_updates = _batch_updates(updates) batched_updates = _batch_updates(updates)
for conn in self.connections: for token, row in batched_updates:
for token, row in batched_updates: try:
try: self.client.stream_update(stream.NAME, token, row)
conn.stream_update(stream.NAME, token, row) except Exception:
except Exception: logger.exception("Failed to replicate")
logger.exception("Failed to replicate")
logger.debug("No more pending updates, breaking poke loop") logger.debug("No more pending updates, breaking poke loop")
finally: finally:
self.pending_updates = False self.pending_updates = False
self.is_looping = False self.is_looping = False
def get_stream_token(self, stream_name):
    """Return the current token of the named stream.

    This is called when a client first subscribes to a stream.

    Args:
        stream_name: key to look up in ``self.streams_by_name``.

    Returns:
        Whatever the stream's ``current_token()`` returns.

    Raises:
        Exception: if no stream with that name is registered.
    """
    stream = self.streams_by_name.get(stream_name, None)
    if not stream:
        # Exception() does not apply %-formatting to its arguments, so the
        # message has to be interpolated before it is raised.
        raise Exception("unknown stream %s" % (stream_name,))

    return stream.current_token()
@measure_func("repl.federation_ack")
def federation_ack(self, token):
    """Count a federation stream ack and relay it to the federation sender.

    The token is only relayed when ``self.federation_sender`` is set.
    """
    federation_ack_counter.inc()

    sender = self.federation_sender
    if sender:
        sender.federation_ack(token)
@measure_func("repl.on_user_sync")
async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
    """Count a user sync report and hand it to the presence handler.

    All four arguments are forwarded unchanged to
    ``presence_handler.update_external_syncs_row``.
    """
    user_sync_counter.inc()

    presence = self.presence_handler
    await presence.update_external_syncs_row(
        instance_id, user_id, is_syncing, last_sync_ms
    )
async def on_clear_user_syncs(self, instance_id):
    """Ask the presence handler to clear the user syncs for ``instance_id``."""
    presence = self.presence_handler
    await presence.update_external_syncs_clear(instance_id)
@measure_func("repl.on_remove_pusher")
async def on_remove_pusher(self, app_id, push_key, user_id):
    """Count the request, delete the pusher, then poke the notifier.

    Args:
        app_id: passed to the store as ``app_id``.
        push_key: passed to the store as ``pushkey``.
        user_id: passed to the store as ``user_id``.
    """
    remove_pusher_counter.inc()

    pusher_key = {"app_id": app_id, "pushkey": push_key, "user_id": user_id}
    await self.store.delete_pusher_by_app_id_pushkey_user_id(**pusher_key)

    self.notifier.on_new_replication_data()
@measure_func("repl.on_invalidate_cache")
async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
    """Count the request and invalidate the named cache for ``keys``.

    Args:
        cache_func: name of the cache to invalidate.
        keys: cache key components; converted to a tuple before use.
    """
    invalidate_cache_counter.inc()

    # The cache is invalidated locally, and the invalidation is then also
    # streamed to other workers.
    key_tuple = tuple(keys)
    await self.store.invalidate_cache_and_stream(cache_func, key_tuple)
@measure_func("repl.on_user_ip")
async def on_user_ip(
    self, user_id, access_token, ip, user_agent, device_id, last_seen
):
    """Count the request, store the client IP and poke server notices.

    The six arguments are forwarded unchanged to
    ``store.insert_client_ip``; afterwards the server notices sender is
    told about ``user_id``.
    """
    user_ip_cache_counter.inc()

    client_ip_row = (user_id, access_token, ip, user_agent, device_id, last_seen)
    await self.store.insert_client_ip(*client_ip_row)

    await self._server_notices_sender.on_user_ip(user_id)
@measure_func("repl.on_remote_server_up")
def on_remote_server_up(self, server: str):
    """Pass ``server`` on to the notifier's ``notify_remote_server_up``."""
    notifier = self.notifier
    notifier.notify_remote_server_up(server)
def send_remote_server_up(self, server: str):
    """Call ``send_remote_server_up(server)`` on every tracked connection."""
    for connection in self.connections:
        connection.send_remote_server_up(server)
def send_sync_to_all_connections(self, data):
    """Call ``send_sync(data)`` on every tracked connection.

    Used in tests.
    """
    for connection in self.connections:
        connection.send_sync(data)
def new_connection(self, connection):
    """Start tracking ``connection`` by appending it to ``self.connections``."""
    tracked = self.connections
    tracked.append(connection)
def lost_connection(self, connection):
    """Stop tracking ``connection``; a connection that is not tracked is ignored."""
    tracked = self.connections
    # The membership test replaces catching ValueError from list.remove();
    # removing an unknown connection remains a silent no-op.
    if connection in tracked:
        tracked.remove(connection)
def _batch_updates(updates): def _batch_updates(updates):
"""Takes a list of updates of form [(token, row)] and sets the token to """Takes a list of updates of form [(token, row)] and sets the token to