Fixup protocol.py

pull/7185/head
Erik Johnston 2020-04-01 17:15:03 +01:00
parent e16225ae28
commit 8503564a77
1 changed files with 6 additions and 83 deletions

View File

@ -46,12 +46,11 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping > ERROR server stopping
* connection closed by server * * connection closed by server *
""" """
import abc
import fcntl import fcntl
import logging import logging
import struct import struct
from collections import defaultdict from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set from typing import TYPE_CHECKING, DefaultDict, List
from six import iteritems from six import iteritems
@ -78,13 +77,12 @@ from synapse.replication.tcp.commands import (
SyncCommand, SyncCommand,
UserSyncCommand, UserSyncCommand,
) )
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.types import Collection from synapse.types import Collection
from synapse.util import Clock from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
MYPY = False if TYPE_CHECKING:
if MYPY: from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.server import HomeServer from synapse.server import HomeServer
@ -475,71 +473,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer.lost_connection(self) self.streamer.lost_connection(self)
class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
"""
The interface for the handler that should be passed to
ClientReplicationStreamProtocol
"""
@abc.abstractmethod
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
raise NotImplementedError()
@abc.abstractmethod
async def on_position(self, stream_name, token):
"""Called when we get new position data."""
raise NotImplementedError()
@abc.abstractmethod
def on_sync(self, data):
"""Called when get a new SYNC command."""
raise NotImplementedError()
@abc.abstractmethod
async def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
raise NotImplementedError()
@abc.abstractmethod
def get_streams_to_replicate(self):
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
raise NotImplementedError()
@abc.abstractmethod
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users."""
raise NotImplementedError()
@abc.abstractmethod
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
raise NotImplementedError()
@abc.abstractmethod
def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
raise NotImplementedError()
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
@ -550,7 +483,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
client_name: str, client_name: str,
server_name: str, server_name: str,
clock: Clock, clock: Clock,
command_handler, command_handler: "ReplicationCommandHandler",
): ):
BaseReplicationStreamProtocol.__init__(self, clock) BaseReplicationStreamProtocol.__init__(self, clock)
@ -560,17 +493,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name self.server_name = server_name
self.handler = command_handler self.handler = command_handler
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
# Set of streams that we're currently catching up with.
self.streams_connecting = set() # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {} # type: Dict[str, List[Any]]
def connectionMade(self): def connectionMade(self):
self.send_command(NameCommand(self.client_name)) self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self) BaseReplicationStreamProtocol.connectionMade(self)
@ -592,7 +514,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def handle_command(self, cmd: Command): async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream. """Handle a command we have received over the replication stream.
By default delegates to on_<COMMAND>, which should return an awaitable. Delegates to `command_handler.on_<COMMAND>`, which must return an
awaitable.
Args: Args:
cmd: received command cmd: received command