Optimise queueing of inbound replication commands (#7861)

When we get behind on replication, we tend to stack up background processes
behind a linearizer. Bg processes are heavy (particularly with respect to
prometheus metrics) and linearizers aren't terribly efficient once the queue
gets long either.

A better approach is to maintain a queue of requests to be processed, and
nominate a single process to work its way through the queue.

Fixes: #7444
pull/7803/head
Richard van der Hoff 2020-07-16 15:49:37 +01:00 committed by GitHub
parent 346476df21
commit e5300063ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 212 additions and 112 deletions

1
changelog.d/7861.misc Normal file
View File

@ -0,0 +1 @@
Optimise queueing of inbound replication commands.

View File

@ -14,9 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
from prometheus_client import Counter
from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory
@ -44,7 +56,6 @@ from synapse.replication.tcp.streams import (
Stream,
TypingStream,
)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@ -62,6 +73,12 @@ invalidate_cache_counter = Counter(
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
]
class ReplicationCommandHandler:
"""Handles incoming commands from replication as well as sending commands
back out to connections.
@ -116,10 +133,6 @@ class ReplicationCommandHandler:
self._streams_to_replicate.append(stream)
self._position_linearizer = Linearizer(
"replication_position", clock=self._clock
)
# Map of stream name to batched updates. See RdataCommand for info on
# how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]]
@ -131,10 +144,6 @@ class ReplicationCommandHandler:
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
# For each connection, the incoming stream names that are coming from
# that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
@ -142,6 +151,32 @@ class ReplicationCommandHandler:
lambda: len(self._connections),
)
# When POSITION or RDATA commands arrive, we stick them in a queue and process
# them in order in a separate background process.
# the streams which are currently being processed by _unsafe_process_stream
self._processing_streams = set() # type: Set[str]
# for each stream, a queue of commands that are awaiting processing, and the
# connection that they arrived on.
self._command_queues_by_stream = {
stream_name: _StreamCommandQueue() for stream_name in self._streams
}
# For each connection, the incoming stream names that have received a POSITION
# from that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_command_queue",
"Number of inbound RDATA/POSITION commands queued for processing",
["stream_name"],
lambda: {
(stream_name,): len(queue)
for stream_name, queue in self._command_queues_by_stream.items()
},
)
self._is_master = hs.config.worker_app is None
self._federation_sender = None
@ -152,6 +187,64 @@ class ReplicationCommandHandler:
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
async def _add_command_to_stream_queue(
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
Adds the given command to the per-stream queue, and processes the queue if
necessary
"""
stream_name = cmd.stream_name
queue = self._command_queues_by_stream.get(stream_name)
if queue is None:
logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
return
# if we're already processing this stream, stick the new command in the
# queue, and we're done.
if stream_name in self._processing_streams:
queue.append((cmd, conn))
return
# otherwise, process the new command.
# arguably we should start off a new background process here, but nothing
# will be too upset if we don't return for ages, so let's save the overhead
# and use the existing logcontext.
self._processing_streams.add(stream_name)
try:
# might as well skip the queue for this one, since it must be empty
assert not queue
await self._process_command(cmd, conn, stream_name)
# now process any other commands that have built up while we were
# dealing with that one.
while queue:
cmd, conn = queue.popleft()
try:
await self._process_command(cmd, conn, stream_name)
except Exception:
logger.exception("Failed to handle command %s", cmd)
finally:
self._processing_streams.discard(stream_name)
async def _process_command(
self,
cmd: Union[PositionCommand, RdataCommand],
conn: AbstractConnection,
stream_name: str,
) -> None:
if isinstance(cmd, PositionCommand):
await self._process_position(stream_name, conn, cmd)
elif isinstance(cmd, RdataCommand):
await self._process_rdata(stream_name, conn, cmd)
else:
# This shouldn't be possible
raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
@ -285,63 +378,71 @@ class ReplicationCommandHandler:
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
raise
# We linearize here for two reasons:
# We put the received command into a queue here for two reasons:
# 1. so we don't try and concurrently handle multiple rows for the
# same stream, and
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
with await self._position_linearizer.queue(cmd.stream_name):
# make sure that we've processed a POSITION for this stream *on this
# connection*. (A POSITION on another connection is no good, as there
# is no guarantee that we have seen all the intermediate updates.)
sbc = self._streams_by_connection.get(conn)
if not sbc or stream_name not in sbc:
# Let's drop the row for now, on the assumption we'll receive a
# `POSITION` soon and we'll catch up correctly then.
logger.debug(
"Discarding RDATA for unconnected stream %s -> %s",
stream_name,
cmd.token,
)
return
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream (in
# which case batch until we get an update for the stream with a non
# None token).
self._pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
await self._add_command_to_stream_queue(conn, cmd)
stream = self._streams.get(stream_name)
if not stream:
logger.error("Got RDATA for unknown stream: %s", stream_name)
return
async def _process_rdata(
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
) -> None:
"""Process an RDATA command
# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)
Called after the command has been popped off the queue of inbound commands
"""
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception as e:
raise Exception(
"Failed to parse RDATA: %r %r" % (stream_name, cmd.row)
) from e
# Discard this data if this token is earlier than the current
# position. Note that streams can be reset (in which case you
# expect an earlier token), but that must be preceded by a
# POSITION command.
if cmd.token <= current_token:
logger.debug(
"Discarding RDATA from stream %s at position %s before previous position %s",
stream_name,
cmd.token,
current_token,
)
else:
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
# make sure that we've processed a POSITION for this stream *on this
# connection*. (A POSITION on another connection is no good, as there
# is no guarantee that we have seen all the intermediate updates.)
sbc = self._streams_by_connection.get(conn)
if not sbc or stream_name not in sbc:
# Let's drop the row for now, on the assumption we'll receive a
# `POSITION` soon and we'll catch up correctly then.
logger.debug(
"Discarding RDATA for unconnected stream %s -> %s",
stream_name,
cmd.token,
)
return
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream (in
# which case batch until we get an update for the stream with a non
# None token).
self._pending_batches.setdefault(stream_name, []).append(row)
return
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
stream = self._streams[stream_name]
# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)
# Discard this data if this token is earlier than the current
# position. Note that streams can be reset (in which case you
# expect an earlier token), but that must be preceded by a
# POSITION command.
if cmd.token <= current_token:
logger.debug(
"Discarding RDATA from stream %s at position %s before previous position %s",
stream_name,
cmd.token,
current_token,
)
else:
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@ -367,67 +468,65 @@ class ReplicationCommandHandler:
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
stream_name = cmd.stream_name
stream = self._streams.get(stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", stream_name)
return
await self._add_command_to_stream_queue(conn, cmd)
# We protect catching up with a linearizer in case the replication
# connection reconnects under us.
with await self._position_linearizer.queue(stream_name):
# We're about to go and catch up with the stream, so remove from set
# of connected streams.
for streams in self._streams_by_connection.values():
streams.discard(stream_name)
async def _process_position(
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
) -> None:
"""Process a POSITION command
# We clear the pending batches for the stream as the fetching of the
# missing updates below will fetch all rows in the batch.
self._pending_batches.pop(stream_name, [])
Called after the command has been popped off the queue of inbound commands
"""
stream = self._streams[stream_name]
# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)
# We're about to go and catch up with the stream, so remove from set
# of connected streams.
for streams in self._streams_by_connection.values():
streams.discard(stream_name)
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
missing_updates = cmd.token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
cmd.token,
)
(
updates,
current_token,
missing_updates,
) = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)
# We clear the pending batches for the stream as the fetching of the
# missing updates below will fetch all rows in the batch.
self._pending_batches.pop(stream_name, [])
# TODO: add some tests for this
# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)
# Some streams return multiple rows with the same stream IDs,
# which need to be processed in batches.
for token, rows in _batch_updates(updates):
await self.on_rdata(
stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
cmd.stream_name, cmd.instance_name, cmd.token
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
missing_updates = cmd.token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
cmd.token,
)
(updates, current_token, missing_updates) = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
# TODO: add some tests for this
# Some streams return multiple rows with the same stream IDs,
# which need to be processed in batches.
for token, rows in _batch_updates(updates):
await self.on_rdata(
stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
cmd.stream_name, cmd.instance_name, cmd.token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand