Handle replication commands synchronously where possible (#7876)

Most of the stuff we do for replication commands can be done synchronously. There's no point spinning up background processes if we're not going to need them.
pull/7967/head
Richard van der Hoff 2020-07-27 18:54:43 +01:00 committed by GitHub
parent 7c2e2c2077
commit f57b99af22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 86 deletions

1
changelog.d/7876.bugfix Normal file
View File

@ -0,0 +1 @@
Fix an `AssertionError` exception introduced in v1.18.0rc1.

1
changelog.d/7876.misc Normal file
View File

@ -0,0 +1 @@
Further optimise queueing of inbound replication commands.

View File

@ -16,6 +16,7 @@
import logging
from typing import (
Any,
Awaitable,
Dict,
Iterable,
Iterator,
@ -33,6 +34,7 @@ from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
@ -152,7 +154,7 @@ class ReplicationCommandHandler:
# When POSITION or RDATA commands arrive, we stick them in a queue and process
# them in order in a separate background process.
# the streams which are currently being processed by _unsafe_process_stream
# the streams which are currently being processed by _unsafe_process_queue
self._processing_streams = set() # type: Set[str]
# for each stream, a queue of commands that are awaiting processing, and the
@ -185,7 +187,7 @@ class ReplicationCommandHandler:
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
async def _add_command_to_stream_queue(
def _add_command_to_stream_queue(
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
@ -199,33 +201,34 @@ class ReplicationCommandHandler:
logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
return
# if we're already processing this stream, stick the new command in the
# queue, and we're done.
queue.append((cmd, conn))
# if we're already processing this stream, there's nothing more to do:
# the new entry on the queue will get picked up in due course
if stream_name in self._processing_streams:
queue.append((cmd, conn))
return
# otherwise, process the new command.
# fire off a background process to start processing the queue.
run_as_background_process(
"process-replication-data", self._unsafe_process_queue, stream_name
)
# arguably we should start off a new background process here, but nothing
# will be too upset if we don't return for ages, so let's save the overhead
# and use the existing logcontext.
async def _unsafe_process_queue(self, stream_name: str):
"""Processes the command queue for the given stream, until it is empty
Does not check if there is already a thread processing the queue, hence "unsafe"
"""
assert stream_name not in self._processing_streams
self._processing_streams.add(stream_name)
try:
# might as well skip the queue for this one, since it must be empty
assert not queue
await self._process_command(cmd, conn, stream_name)
# now process any other commands that have built up while we were
# dealing with that one.
queue = self._command_queues_by_stream.get(stream_name)
while queue:
cmd, conn = queue.popleft()
try:
await self._process_command(cmd, conn, stream_name)
except Exception:
logger.exception("Failed to handle command %s", cmd)
finally:
self._processing_streams.discard(stream_name)
@ -299,7 +302,7 @@ class ReplicationCommandHandler:
"""
return self._streams_to_replicate
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
def send_positions_to_connection(self, conn: AbstractConnection):
@ -318,57 +321,73 @@ class ReplicationCommandHandler:
)
)
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
def on_USER_SYNC(
self, conn: AbstractConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
if self._is_master:
await self._presence_handler.update_external_syncs_row(
return self._presence_handler.update_external_syncs_row(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
else:
return None
async def on_CLEAR_USER_SYNC(
def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
):
) -> Optional[Awaitable[None]]:
if self._is_master:
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else:
return None
async def on_FEDERATION_ACK(
self, conn: AbstractConnection, cmd: FederationAckCommand
):
def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
async def on_REMOVE_PUSHER(
def on_REMOVE_PUSHER(
self, conn: AbstractConnection, cmd: RemovePusherCommand
):
) -> Optional[Awaitable[None]]:
remove_pusher_counter.inc()
if self._is_master:
await self._store.delete_pusher_by_app_id_pushkey_user_id(
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
)
return self._handle_remove_pusher(cmd)
else:
return None
self._notifier.on_new_replication_data()
async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
await self._store.delete_pusher_by_app_id_pushkey_user_id(
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
)
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
self._notifier.on_new_replication_data()
def on_USER_IP(
self, conn: AbstractConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
if self._is_master:
await self._store.insert_client_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
return self._handle_user_ip(cmd)
else:
return None
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
async def _handle_user_ip(self, cmd: UserIpCommand):
await self._store.insert_client_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@ -382,7 +401,7 @@ class ReplicationCommandHandler:
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
await self._add_command_to_stream_queue(conn, cmd)
self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata(
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
@ -459,14 +478,14 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
await self._add_command_to_stream_queue(conn, cmd)
self._add_command_to_stream_queue(conn, cmd)
async def _process_position(
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
@ -526,9 +545,7 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
):
def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)

View File

@ -50,6 +50,7 @@ import abc
import fcntl
import logging
import struct
from inspect import isawaitable
from typing import TYPE_CHECKING, List
from prometheus_client import Counter
@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
`ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
if so, that will get run as a background process.
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
@ -166,9 +169,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
self._logging_context = BackgroundProcessLoggingContext(
"replication_command_handler-%s" % self.conn_id
)
ctx_name = "replication-conn-%s" % self.conn_id
self._logging_context = BackgroundProcessLoggingContext(ctx_name)
self._logging_context.request = ctx_name
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@ -246,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
# Now lets try and call on_<CMD_NAME> function
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
self.handle_command(cmd)
async def handle_command(self, cmd: Command):
def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
First calls `self.on_<COMMAND>` if it exists, then calls
`self.command_handler.on_<COMMAND>` if it exists. This allows for
protocol level handling of commands (e.g. PINGs), before delegating to
the handler.
`self.command_handler.on_<COMMAND>` if it exists (which can optionally
return an Awaitable).
This allows for protocol level handling of commands (e.g. PINGs), before
delegating to the handler.
Args:
cmd: received command
@ -268,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(self, cmd)
res = cmd_func(self, cmd)
# the handler might be a coroutine: fire it off as a background process
# if so.
if isawaitable(res):
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), lambda: res
)
handled = True
if not handled:
@ -350,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
async def on_PING(self, line):
def on_PING(self, line):
self.received_ping = True
async def on_ERROR(self, cmd):
def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@ -448,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ServerCommand(self.server_name))
super().connectionMade()
async def on_NAME(self, cmd):
def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
@ -477,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Once we've connected subscribe to the necessary streams
self.replicate()
async def on_SERVER(self, cmd):
def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")

View File

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from inspect import isawaitable
from typing import TYPE_CHECKING
import txredisapi
@ -124,36 +125,32 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# remote instances.
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
# Now lets try and call on_<CMD_NAME> function
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
self.handle_command(cmd)
async def handle_command(self, cmd: Command):
def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
By default delegates to on_<COMMAND>, which should return an awaitable.
Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
Awaitable).
Args:
cmd: received command
"""
handled = False
# First call any command handlers on this instance. These are for redis
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(self, cmd)
handled = True
if not handled:
if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
return
res = cmd_func(self, cmd)
# the handler might be a coroutine: fire it off as a background process
# if so.
if isawaitable(res):
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), lambda: res
)
def connectionLost(self, reason):
logger.info("Lost connection to redis")