Add ReplicationCommandHandler.

The intention is to move all command processing logic to there, out of
protocol, client and resource modules. Currently we only pull out client
command processing from protocol and delegate to
ReplicationClientHandler.
pull/7185/head
Erik Johnston 2020-03-31 11:16:13 +01:00
parent 8f1a87886f
commit a0063c9a15
3 changed files with 175 additions and 78 deletions

View File

@ -20,7 +20,9 @@ from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import (
AbstractReplicationClientHandler,
ClientReplicationStreamProtocol,
@ -56,7 +58,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
self.client_name = client_name
self.handler = handler
self.command_handler = ReplicationCommandHandler(hs, handler)
self.server_name = hs.config.server_name
self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
@ -69,7 +71,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
self.hs, self.client_name, self.server_name, self._clock, self.handler,
self.hs,
self.client_name,
self.server_name,
self._clock,
self.command_handler,
)
def clientConnectionLost(self, connector, reason):

View File

@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A replication client for use by synapse workers.
"""
import logging
from typing import Any, Dict, List, Set
from prometheus_client import Counter
from synapse.replication.tcp.commands import (
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
SyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
logger = logging.getLogger(__name__)
# number of updates received for each RDATA stream
inbound_rdata_count = Counter(
"synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
)
class ReplicationCommandHandler:
"""Handles incoming commands from replication.
"""
def __init__(self, hs, handler):
self.handler = handler
# Set of streams that we're currently catching up with.
self.streams_connecting = set() # type: Set[str]
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {} # type: Dict[str, List[Any]]
async def on_RDATA(self, cmd: RdataCommand):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
raise
if cmd.token is None or stream_name in self.streams_connecting:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
async def on_POSITION(self, cmd: PositionCommand):
stream = self.streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
return
# We're about to go and catch up with the stream, so mark as connecting
# to stop RDATA being handled at the same time.
self.streams_connecting.add(cmd.stream_name)
# Find where we previously streamed up to.
current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
if current_token is None:
logger.warning(
"Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
)
return
# Fetch all updates between then and now.
limited = True
while limited:
updates, current_token, limited = await stream.get_updates_since(
current_token, cmd.token
)
if updates:
await self.handler.on_rdata(
cmd.stream_name,
current_token,
[stream.parse_row(update[1]) for update in updates],
)
# We've now caught up to position sent to us, notify handler.
await self.handler.on_position(cmd.stream_name, cmd.token)
self.streams_connecting.discard(cmd.stream_name)
# Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, [])
if rows:
await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_SYNC(self, cmd: SyncCommand):
self.handler.on_sync(cmd.data)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data)
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users. (Overriden by the synchrotron's only)
"""
return self.handler.get_currently_syncing_users()
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
return self.handler.update_connection(connection)
def finished_connecting(self):
return self.handler.finished_connecting()

View File

@ -550,7 +550,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
client_name: str,
server_name: str,
clock: Clock,
handler: AbstractReplicationClientHandler,
command_handler,
):
BaseReplicationStreamProtocol.__init__(self, clock)
@ -558,7 +558,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.client_name = client_name
self.server_name = server_name
self.handler = handler
self.handler = command_handler
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
@ -589,80 +589,37 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.handler.update_connection(self)
self.handler.finished_connecting()
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
By default delegates to on_<COMMAND>, which should return an awaitable.
Args:
cmd: received command
"""
handled = False
# First call any command handlers on this instance. These are for TCP
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
if not handled:
logger.warning("Unhandled command: %r", cmd)
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception(
"[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
)
raise
if cmd.token is None or stream_name in self.streams_connecting:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
async def on_POSITION(self, cmd: PositionCommand):
stream = self.streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
return
# We're about to go and catch up with the stream, so mark as connecting
# to stop RDATA being handled at the same time.
self.streams_connecting.add(cmd.stream_name)
# Find where we previously streamed up to.
current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
if current_token is None:
logger.warning(
"Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
)
return
# Fetch all updates between then and now.
limited = True
while limited:
updates, current_token, limited = await stream.get_updates_since(
current_token, cmd.token
)
if updates:
await self.handler.on_rdata(
cmd.stream_name,
current_token,
[stream.parse_row(update[1]) for update in updates],
)
# We've now caught up to position sent to us, notify handler.
await self.handler.on_position(cmd.stream_name, cmd.token)
self.streams_connecting.discard(cmd.stream_name)
# Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, [])
if rows:
await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data)
def replicate(self):
"""Send the subscription request to the server
"""
@ -758,8 +715,3 @@ tcp_outbound_commands = LaterGauge(
for k, count in iteritems(p.outbound_commands_counter)
},
)
# number of updates received for each RDATA stream
inbound_rdata_count = Counter(
"synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
)