Compare commits
11 Commits
5104d1673b
...
534bd868e5
Author | SHA1 | Date |
---|---|---|
Erik Johnston | 534bd868e5 | |
Erik Johnston | ca9778cedf | |
Erik Johnston | 1ebfa39a73 | |
Erik Johnston | bf99c8e87e | |
Erik Johnston | dc91879ee1 | |
Erik Johnston | cf57d56e39 | |
Erik Johnston | 8503564a77 | |
Erik Johnston | e16225ae28 | |
Erik Johnston | 0d6e7531fd | |
Erik Johnston | 23de3af9af | |
Erik Johnston | 730dbee169 |
|
@ -43,7 +43,6 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
|||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
@ -79,17 +78,6 @@ class AdminCmdServer(HomeServer):
|
|||
def start_listening(self, listeners):
|
||||
pass
|
||||
|
||||
def build_tcp_replication(self):
|
||||
return AdminCmdReplicationHandler(self)
|
||||
|
||||
|
||||
class AdminCmdReplicationHandler(ReplicationClientHandler):
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
pass
|
||||
|
||||
def get_streams_to_replicate(self):
|
||||
return {}
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def export_data_command(hs, args):
|
||||
|
|
|
@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst
|
|||
|
||||
|
||||
Structure of the module:
|
||||
* client.py - the client classes used for workers to connect to master
|
||||
* handler.py - the classes used to handle sending/receiving commands to
|
||||
replication
|
||||
* command.py - the definitions of all the valid commands
|
||||
* protocol.py - contains both the client and server protocol implementations,
|
||||
these should not be used directly
|
||||
* resource.py - the server classes that accepts and handle client connections
|
||||
* streams.py - the definitions of all the valid streams
|
||||
* protocol.py - the TCP protocol classes
|
||||
* resource.py - handles streaming stream updates to replications
|
||||
* streams/ - the definitions of all the valid streams
|
||||
|
||||
|
||||
The general interaction of the classes are:
|
||||
|
||||
+---------------------+
|
||||
| ReplicationStreamer |
|
||||
+---------------------+
|
||||
|
|
||||
v
|
||||
+---------------------------+ +----------------------+
|
||||
| ReplicationCommandHandler |---->|ReplicationDataHandler|
|
||||
+---------------------------+ +----------------------+
|
||||
| ^
|
||||
v |
|
||||
+-------------+
|
||||
| Protocols |
|
||||
| (TCP/redis) |
|
||||
+-------------+
|
||||
|
||||
Where the ReplicationDataHandler (or subclasses) handles incoming stream
|
||||
updates.
|
||||
"""
|
||||
|
|
|
@ -16,16 +16,16 @@
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
|
||||
MYPY = False
|
||||
if MYPY:
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -34,14 +34,18 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
|||
"""Factory for building connections to the master. Will reconnect if the
|
||||
connection is lost.
|
||||
|
||||
Accepts a handler that will be called when new data is available or data
|
||||
is required.
|
||||
Accepts a handler that is passed to `ClientReplicationStreamProtocol`.
|
||||
"""
|
||||
|
||||
initialDelay = 0.1
|
||||
maxDelay = 1 # Try at least once every N seconds
|
||||
|
||||
def __init__(self, hs: "HomeServer", client_name, command_handler):
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
client_name: str,
|
||||
command_handler: "ReplicationCommandHandler",
|
||||
):
|
||||
self.client_name = client_name
|
||||
self.command_handler = command_handler
|
||||
self.server_name = hs.config.server_name
|
||||
|
@ -73,7 +77,10 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
|||
|
||||
|
||||
class ReplicationDataHandler:
|
||||
"""A replication data handler that calls slave data stores.
|
||||
"""A replication data handler handles incoming stream updates from replication.
|
||||
|
||||
This instance notifies the slave data store about updates. Can be subclassed
|
||||
to handle updates in additional ways.
|
||||
"""
|
||||
|
||||
def __init__(self, store: BaseSlavedStore):
|
||||
|
@ -112,3 +119,6 @@ class ReplicationDataHandler:
|
|||
|
||||
async def on_position(self, stream_name: str, token: int):
|
||||
self.store.process_replication_rows(stream_name, token, [])
|
||||
|
||||
def on_remote_server_up(self, server: str):
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
|
|
|
@ -13,8 +13,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A replication client for use by synapse workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
@ -51,13 +49,13 @@ class ReplicationCommandHandler:
|
|||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.replication_data_handler = hs.get_replication_data_handler()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self._replication_data_handler = hs.get_replication_data_handler()
|
||||
self._presence_handler = hs.get_presence_handler()
|
||||
|
||||
# Set of streams that we're currently catching up with.
|
||||
self.streams_connecting = set() # type: Set[str]
|
||||
# Set of streams that we've caught up with.
|
||||
self._streams_connected = set() # type: Set[str]
|
||||
|
||||
self.streams = {
|
||||
self._streams = {
|
||||
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
||||
} # type: Dict[str, Stream]
|
||||
|
||||
|
@ -65,23 +63,23 @@ class ReplicationCommandHandler:
|
|||
|
||||
# Map of stream to batched updates. See RdataCommand for info on how
|
||||
# batching works.
|
||||
self.pending_batches = {} # type: Dict[str, List[Any]]
|
||||
self._pending_batches = {} # type: Dict[str, List[Any]]
|
||||
|
||||
# The factory used to create connections.
|
||||
self.factory = None # type: Optional[ReplicationClientFactory]
|
||||
self._factory = None # type: Optional[ReplicationClientFactory]
|
||||
|
||||
# The current connection. None if we are currently (re)connecting
|
||||
self.connection = None
|
||||
self._connection = None
|
||||
|
||||
def start_replication(self, hs):
|
||||
"""Helper method to start a replication connection to the remote server
|
||||
using TCP.
|
||||
"""
|
||||
client_name = hs.config.worker_name
|
||||
self.factory = ReplicationClientFactory(hs, client_name, self)
|
||||
self._factory = ReplicationClientFactory(hs, client_name, self)
|
||||
host = hs.config.worker_replication_host
|
||||
port = hs.config.worker_replication_port
|
||||
hs.get_reactor().connectTCP(host, port, self.factory)
|
||||
hs.get_reactor().connectTCP(host, port, self._factory)
|
||||
|
||||
async def on_RDATA(self, cmd: RdataCommand):
|
||||
stream_name = cmd.stream_name
|
||||
|
@ -93,13 +91,13 @@ class ReplicationCommandHandler:
|
|||
logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
|
||||
raise
|
||||
|
||||
if cmd.token is None or stream_name in self.streams_connecting:
|
||||
if cmd.token is None or stream_name not in self._streams_connected:
|
||||
# I.e. this is part of a batch of updates for this stream. Batch
|
||||
# until we get an update for the stream with a non None token
|
||||
self.pending_batches.setdefault(stream_name, []).append(row)
|
||||
self._pending_batches.setdefault(stream_name, []).append(row)
|
||||
else:
|
||||
# Check if this is the last of a batch of updates
|
||||
rows = self.pending_batches.pop(stream_name, [])
|
||||
rows = self._pending_batches.pop(stream_name, [])
|
||||
rows.append(row)
|
||||
await self.on_rdata(stream_name, cmd.token, rows)
|
||||
|
||||
|
@ -113,23 +111,26 @@ class ReplicationCommandHandler:
|
|||
Stream.parse_row.
|
||||
"""
|
||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
await self.replication_data_handler.on_rdata(stream_name, token, rows)
|
||||
await self._replication_data_handler.on_rdata(stream_name, token, rows)
|
||||
|
||||
async def on_POSITION(self, cmd: PositionCommand):
|
||||
stream = self.streams.get(cmd.stream_name)
|
||||
stream = self._streams.get(cmd.stream_name)
|
||||
if not stream:
|
||||
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
|
||||
return
|
||||
|
||||
# We're about to go and catch up with the stream, so mark as connecting
|
||||
# to stop RDATA being handled at the same time.
|
||||
self.streams_connecting.add(cmd.stream_name)
|
||||
# to stop RDATA being handled at the same time by removing stream from
|
||||
# list of connected streams. We also clear any batched up RDATA from
|
||||
# before we got the POSITION.
|
||||
self._streams_connected.discard(cmd.stream_name)
|
||||
self._pending_batches.clear()
|
||||
|
||||
# We protect catching up with a linearizer in case the replication
|
||||
# connection reconnects under us.
|
||||
with await self._position_linearizer.queue(cmd.stream_name):
|
||||
# Find where we previously streamed up to.
|
||||
current_token = self.replication_data_handler.get_streams_to_replicate().get(
|
||||
current_token = self._replication_data_handler.get_streams_to_replicate().get(
|
||||
cmd.stream_name
|
||||
)
|
||||
if current_token is None:
|
||||
|
@ -153,32 +154,62 @@ class ReplicationCommandHandler:
|
|||
)
|
||||
|
||||
# We've now caught up to position sent to us, notify handler.
|
||||
await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
|
||||
await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
|
||||
|
||||
self.streams_connecting.discard(cmd.stream_name)
|
||||
self._streams_connected.add(cmd.stream_name)
|
||||
|
||||
# Handle any RDATA that came in while we were catching up.
|
||||
rows = self.pending_batches.pop(cmd.stream_name, [])
|
||||
rows = self._pending_batches.pop(cmd.stream_name, [])
|
||||
if rows:
|
||||
await self.on_rdata(cmd.stream_name, rows[-1].token, rows)
|
||||
# We need to make sure we filter out RDATA rows with a token less
|
||||
# than what we've caught up to. This is slightly fiddly because of
|
||||
# "batched" rows which have a `None` token, indicating that they
|
||||
# have the same token as the next row with a non-None token.
|
||||
#
|
||||
# We do this by walking the list backwards, first removing any RDATA
|
||||
# rows that are part of an uncompleted batch, then taking rows while
|
||||
# their token is either None or greater than where we've caught up
|
||||
# to.
|
||||
uncompleted_batch = []
|
||||
unfinished_batch = True
|
||||
filtered_rows = []
|
||||
for row in reversed(rows):
|
||||
if row.token is not None:
|
||||
unfinished_batch = False
|
||||
if cmd.token < row.token:
|
||||
filtered_rows.append(row)
|
||||
else:
|
||||
break
|
||||
elif unfinished_batch:
|
||||
uncompleted_batch.append(row)
|
||||
else:
|
||||
filtered_rows.append(row)
|
||||
|
||||
filtered_rows.reverse()
|
||||
uncompleted_batch.reverse()
|
||||
if uncompleted_batch:
|
||||
self._pending_batches[cmd.stream_name] = uncompleted_batch
|
||||
|
||||
await self.on_rdata(cmd.stream_name, rows[-1].token, filtered_rows)
|
||||
|
||||
async def on_SYNC(self, cmd: SyncCommand):
|
||||
pass
|
||||
|
||||
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
||||
""""Called when get a new REMOTE_SERVER_UP command."""
|
||||
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users. (Overridden by the synchrotrons only)
|
||||
currently syncing users.
|
||||
"""
|
||||
return self.presence_handler.get_currently_syncing_users()
|
||||
return self._presence_handler.get_currently_syncing_users()
|
||||
|
||||
def update_connection(self, connection):
|
||||
"""Called when a connection has been established (or lost with None).
|
||||
"""
|
||||
self.connection = connection
|
||||
self._connection = connection
|
||||
|
||||
def finished_connecting(self):
|
||||
"""Called when we have successfully subscribed and caught up to all
|
||||
|
@ -189,15 +220,15 @@ class ReplicationCommandHandler:
|
|||
# We don't reset the delay any earlier as otherwise if there is a
|
||||
# problem during start up we'll end up tight looping connecting to the
|
||||
# server.
|
||||
if self.factory:
|
||||
self.factory.resetDelay()
|
||||
if self._factory:
|
||||
self._factory.resetDelay()
|
||||
|
||||
def send_command(self, cmd: Command):
|
||||
"""Send a command to master (when we get establish a connection if we
|
||||
don't have one already.)
|
||||
"""
|
||||
if self.connection:
|
||||
self.connection.send_command(cmd)
|
||||
if self._connection:
|
||||
self._connection.send_command(cmd)
|
||||
else:
|
||||
logger.warning("Dropping command as not connected: %r", cmd.NAME)
|
||||
|
||||
|
|
|
@ -46,12 +46,11 @@ indicate which side is sending, these are *not* included on the wire::
|
|||
> ERROR server stopping
|
||||
* connection closed by server *
|
||||
"""
|
||||
import abc
|
||||
import fcntl
|
||||
import logging
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from typing import Any, DefaultDict, Dict, List, Set
|
||||
from typing import TYPE_CHECKING, DefaultDict, List
|
||||
|
||||
from six import iteritems
|
||||
|
||||
|
@ -78,13 +77,12 @@ from synapse.replication.tcp.commands import (
|
|||
SyncCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
|
||||
from synapse.types import Collection
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
MYPY = False
|
||||
if MYPY:
|
||||
if TYPE_CHECKING:
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
|
@ -475,71 +473,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
self.streamer.lost_connection(self)
|
||||
|
||||
|
||||
class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
The interface for the handler that should be passed to
|
||||
ClientReplicationStreamProtocol
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def on_position(self, stream_name, token):
|
||||
"""Called when we get new position data."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_sync(self, data):
|
||||
"""Called when get a new SYNC command."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def on_remote_server_up(self, server: str):
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_streams_to_replicate(self):
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_connection(self, connection):
|
||||
"""Called when a connection has been established (or lost with None).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def finished_connecting(self):
|
||||
"""Called when we have successfully subscribed and caught up to all
|
||||
streams we're interested in.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
||||
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
||||
|
@ -550,7 +483,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
client_name: str,
|
||||
server_name: str,
|
||||
clock: Clock,
|
||||
command_handler,
|
||||
command_handler: "ReplicationCommandHandler",
|
||||
):
|
||||
BaseReplicationStreamProtocol.__init__(self, clock)
|
||||
|
||||
|
@ -560,17 +493,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
self.server_name = server_name
|
||||
self.handler = command_handler
|
||||
|
||||
self.streams = {
|
||||
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
||||
} # type: Dict[str, Stream]
|
||||
|
||||
# Set of streams that we're currently catching up with.
|
||||
self.streams_connecting = set() # type: Set[str]
|
||||
|
||||
# Map of stream to batched updates. See RdataCommand for info on how
|
||||
# batching works.
|
||||
self.pending_batches = {} # type: Dict[str, List[Any]]
|
||||
|
||||
def connectionMade(self):
|
||||
self.send_command(NameCommand(self.client_name))
|
||||
BaseReplicationStreamProtocol.connectionMade(self)
|
||||
|
@ -592,7 +514,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
async def handle_command(self, cmd: Command):
|
||||
"""Handle a command we have received over the replication stream.
|
||||
|
||||
By default delegates to on_<COMMAND>, which should return an awaitable.
|
||||
Delegates to `command_handler.on_<COMMAND>`, which must return an
|
||||
awaitable.
|
||||
|
||||
Args:
|
||||
cmd: received command
|
||||
|
|
|
@ -57,8 +57,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
|
|||
# We now do some gut wrenching so that we have a client that is based
|
||||
# off of the slave store rather than the main store.
|
||||
self.replication_handler = ReplicationCommandHandler(self.hs)
|
||||
self.replication_handler.store = self.slaved_store
|
||||
self.replication_handler.replication_data_handler = ReplicationDataHandler(
|
||||
self.replication_handler._replication_data_handler = ReplicationDataHandler(
|
||||
self.slaved_store
|
||||
)
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
"""Base class for tests of the replication streams"""
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
self.test_handler = Mock(wraps=TestReplicationClientHandler())
|
||||
self.test_handler = Mock(wraps=TestReplicationDataHandler())
|
||||
return self.setup_test_homeserver(replication_data_handler=self.test_handler)
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
|
@ -75,7 +75,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
self.pump(0.1)
|
||||
|
||||
|
||||
class TestReplicationClientHandler:
|
||||
class TestReplicationDataHandler:
|
||||
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
||||
|
||||
def __init__(self):
|
||||
self.streams = set()
|
||||
self._received_rdata_rows = []
|
||||
|
|
Loading…
Reference in New Issue