Compare commits
11 Commits
5104d1673b
...
534bd868e5
Author | SHA1 | Date |
---|---|---|
Erik Johnston | 534bd868e5 | |
Erik Johnston | ca9778cedf | |
Erik Johnston | 1ebfa39a73 | |
Erik Johnston | bf99c8e87e | |
Erik Johnston | dc91879ee1 | |
Erik Johnston | cf57d56e39 | |
Erik Johnston | 8503564a77 | |
Erik Johnston | e16225ae28 | |
Erik Johnston | 0d6e7531fd | |
Erik Johnston | 23de3af9af | |
Erik Johnston | 730dbee169 |
|
@ -43,7 +43,6 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
||||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
from synapse.replication.slave.storage.room import RoomStore
|
from synapse.replication.slave.storage.room import RoomStore
|
||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
@ -79,17 +78,6 @@ class AdminCmdServer(HomeServer):
|
||||||
def start_listening(self, listeners):
|
def start_listening(self, listeners):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def build_tcp_replication(self):
|
|
||||||
return AdminCmdReplicationHandler(self)
|
|
||||||
|
|
||||||
|
|
||||||
class AdminCmdReplicationHandler(ReplicationClientHandler):
|
|
||||||
async def on_rdata(self, stream_name, token, rows):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_streams_to_replicate(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def export_data_command(hs, args):
|
def export_data_command(hs, args):
|
||||||
|
|
|
@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst
|
||||||
|
|
||||||
|
|
||||||
Structure of the module:
|
Structure of the module:
|
||||||
* client.py - the client classes used for workers to connect to master
|
* handler.py - the classes used to handle sending/receiving commands to
|
||||||
|
replication
|
||||||
* command.py - the definitions of all the valid commands
|
* command.py - the definitions of all the valid commands
|
||||||
* protocol.py - contains bot the client and server protocol implementations,
|
* protocol.py - the TCP protocol classes
|
||||||
these should not be used directly
|
* resource.py - handles streaming stream updates to replications
|
||||||
* resource.py - the server classes that accepts and handle client connections
|
* streams/ - the definitons of all the valid streams
|
||||||
* streams.py - the definitons of all the valid streams
|
|
||||||
|
|
||||||
|
|
||||||
|
The general interaction of the classes are:
|
||||||
|
|
||||||
|
+---------------------+
|
||||||
|
| ReplicationStreamer |
|
||||||
|
+---------------------+
|
||||||
|
|
|
||||||
|
v
|
||||||
|
+---------------------------+ +----------------------+
|
||||||
|
| ReplicationCommandHandler |---->|ReplicationDataHandler|
|
||||||
|
+---------------------------+ +----------------------+
|
||||||
|
| ^
|
||||||
|
v |
|
||||||
|
+-------------+
|
||||||
|
| Protocols |
|
||||||
|
| (TCP/redis) |
|
||||||
|
+-------------+
|
||||||
|
|
||||||
|
Where the ReplicationDataHandler (or subclasses) handles incoming stream
|
||||||
|
updates.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -16,16 +16,16 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict
|
from typing import TYPE_CHECKING, Dict
|
||||||
|
|
||||||
from twisted.internet.protocol import ReconnectingClientFactory
|
from twisted.internet.protocol import ReconnectingClientFactory
|
||||||
|
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||||
|
|
||||||
MYPY = False
|
if TYPE_CHECKING:
|
||||||
if MYPY:
|
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -34,14 +34,18 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||||
"""Factory for building connections to the master. Will reconnect if the
|
"""Factory for building connections to the master. Will reconnect if the
|
||||||
connection is lost.
|
connection is lost.
|
||||||
|
|
||||||
Accepts a handler that will be called when new data is available or data
|
Accepts a handler that is passed to `ClientReplicationStreamProtocol`.
|
||||||
is required.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
initialDelay = 0.1
|
initialDelay = 0.1
|
||||||
maxDelay = 1 # Try at least once every N seconds
|
maxDelay = 1 # Try at least once every N seconds
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer", client_name, command_handler):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
client_name: str,
|
||||||
|
command_handler: "ReplicationCommandHandler",
|
||||||
|
):
|
||||||
self.client_name = client_name
|
self.client_name = client_name
|
||||||
self.command_handler = command_handler
|
self.command_handler = command_handler
|
||||||
self.server_name = hs.config.server_name
|
self.server_name = hs.config.server_name
|
||||||
|
@ -73,7 +77,10 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||||
|
|
||||||
|
|
||||||
class ReplicationDataHandler:
|
class ReplicationDataHandler:
|
||||||
"""A replication data handler that calls slave data stores.
|
"""A replication data handler handles incoming stream updates from replication.
|
||||||
|
|
||||||
|
This instance notifies the slave data store about updates. Can be subclassed
|
||||||
|
to handle updates in additional ways.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, store: BaseSlavedStore):
|
def __init__(self, store: BaseSlavedStore):
|
||||||
|
@ -112,3 +119,6 @@ class ReplicationDataHandler:
|
||||||
|
|
||||||
async def on_position(self, stream_name: str, token: int):
|
async def on_position(self, stream_name: str, token: int):
|
||||||
self.store.process_replication_rows(stream_name, token, [])
|
self.store.process_replication_rows(stream_name, token, [])
|
||||||
|
|
||||||
|
def on_remote_server_up(self, server: str):
|
||||||
|
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||||
|
|
|
@ -13,8 +13,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""A replication client for use by synapse workers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set
|
from typing import Any, Callable, Dict, List, Optional, Set
|
||||||
|
@ -51,13 +49,13 @@ class ReplicationCommandHandler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.replication_data_handler = hs.get_replication_data_handler()
|
self._replication_data_handler = hs.get_replication_data_handler()
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self._presence_handler = hs.get_presence_handler()
|
||||||
|
|
||||||
# Set of streams that we're currently catching up with.
|
# Set of streams that we've caught up with.
|
||||||
self.streams_connecting = set() # type: Set[str]
|
self._streams_connected = set() # type: Set[str]
|
||||||
|
|
||||||
self.streams = {
|
self._streams = {
|
||||||
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
||||||
} # type: Dict[str, Stream]
|
} # type: Dict[str, Stream]
|
||||||
|
|
||||||
|
@ -65,23 +63,23 @@ class ReplicationCommandHandler:
|
||||||
|
|
||||||
# Map of stream to batched updates. See RdataCommand for info on how
|
# Map of stream to batched updates. See RdataCommand for info on how
|
||||||
# batching works.
|
# batching works.
|
||||||
self.pending_batches = {} # type: Dict[str, List[Any]]
|
self._pending_batches = {} # type: Dict[str, List[Any]]
|
||||||
|
|
||||||
# The factory used to create connections.
|
# The factory used to create connections.
|
||||||
self.factory = None # type: Optional[ReplicationClientFactory]
|
self._factory = None # type: Optional[ReplicationClientFactory]
|
||||||
|
|
||||||
# The current connection. None if we are currently (re)connecting
|
# The current connection. None if we are currently (re)connecting
|
||||||
self.connection = None
|
self._connection = None
|
||||||
|
|
||||||
def start_replication(self, hs):
|
def start_replication(self, hs):
|
||||||
"""Helper method to start a replication connection to the remote server
|
"""Helper method to start a replication connection to the remote server
|
||||||
using TCP.
|
using TCP.
|
||||||
"""
|
"""
|
||||||
client_name = hs.config.worker_name
|
client_name = hs.config.worker_name
|
||||||
self.factory = ReplicationClientFactory(hs, client_name, self)
|
self._factory = ReplicationClientFactory(hs, client_name, self)
|
||||||
host = hs.config.worker_replication_host
|
host = hs.config.worker_replication_host
|
||||||
port = hs.config.worker_replication_port
|
port = hs.config.worker_replication_port
|
||||||
hs.get_reactor().connectTCP(host, port, self.factory)
|
hs.get_reactor().connectTCP(host, port, self._factory)
|
||||||
|
|
||||||
async def on_RDATA(self, cmd: RdataCommand):
|
async def on_RDATA(self, cmd: RdataCommand):
|
||||||
stream_name = cmd.stream_name
|
stream_name = cmd.stream_name
|
||||||
|
@ -93,13 +91,13 @@ class ReplicationCommandHandler:
|
||||||
logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
|
logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if cmd.token is None or stream_name in self.streams_connecting:
|
if cmd.token is None or stream_name not in self._streams_connected:
|
||||||
# I.e. this is part of a batch of updates for this stream. Batch
|
# I.e. this is part of a batch of updates for this stream. Batch
|
||||||
# until we get an update for the stream with a non None token
|
# until we get an update for the stream with a non None token
|
||||||
self.pending_batches.setdefault(stream_name, []).append(row)
|
self._pending_batches.setdefault(stream_name, []).append(row)
|
||||||
else:
|
else:
|
||||||
# Check if this is the last of a batch of updates
|
# Check if this is the last of a batch of updates
|
||||||
rows = self.pending_batches.pop(stream_name, [])
|
rows = self._pending_batches.pop(stream_name, [])
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
await self.on_rdata(stream_name, cmd.token, rows)
|
await self.on_rdata(stream_name, cmd.token, rows)
|
||||||
|
|
||||||
|
@ -113,23 +111,26 @@ class ReplicationCommandHandler:
|
||||||
Stream.parse_row.
|
Stream.parse_row.
|
||||||
"""
|
"""
|
||||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||||
await self.replication_data_handler.on_rdata(stream_name, token, rows)
|
await self._replication_data_handler.on_rdata(stream_name, token, rows)
|
||||||
|
|
||||||
async def on_POSITION(self, cmd: PositionCommand):
|
async def on_POSITION(self, cmd: PositionCommand):
|
||||||
stream = self.streams.get(cmd.stream_name)
|
stream = self._streams.get(cmd.stream_name)
|
||||||
if not stream:
|
if not stream:
|
||||||
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
|
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
|
||||||
return
|
return
|
||||||
|
|
||||||
# We're about to go and catch up with the stream, so mark as connecting
|
# We're about to go and catch up with the stream, so mark as connecting
|
||||||
# to stop RDATA being handled at the same time.
|
# to stop RDATA being handled at the same time by removing stream from
|
||||||
self.streams_connecting.add(cmd.stream_name)
|
# list of connected streams. We also clear any batched up RDATA from
|
||||||
|
# before we got the POSITION.
|
||||||
|
self._streams_connected.discard(cmd.stream_name)
|
||||||
|
self._pending_batches.clear()
|
||||||
|
|
||||||
# We protect catching up with a linearizer in case the replicaiton
|
# We protect catching up with a linearizer in case the replicaiton
|
||||||
# connection reconnects under us.
|
# connection reconnects under us.
|
||||||
with await self._position_linearizer.queue(cmd.stream_name):
|
with await self._position_linearizer.queue(cmd.stream_name):
|
||||||
# Find where we previously streamed up to.
|
# Find where we previously streamed up to.
|
||||||
current_token = self.replication_data_handler.get_streams_to_replicate().get(
|
current_token = self._replication_data_handler.get_streams_to_replicate().get(
|
||||||
cmd.stream_name
|
cmd.stream_name
|
||||||
)
|
)
|
||||||
if current_token is None:
|
if current_token is None:
|
||||||
|
@ -153,32 +154,62 @@ class ReplicationCommandHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
# We've now caught up to position sent to us, notify handler.
|
# We've now caught up to position sent to us, notify handler.
|
||||||
await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
|
await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
|
||||||
|
|
||||||
self.streams_connecting.discard(cmd.stream_name)
|
self._streams_connected.add(cmd.stream_name)
|
||||||
|
|
||||||
# Handle any RDATA that came in while we were catching up.
|
# Handle any RDATA that came in while we were catching up.
|
||||||
rows = self.pending_batches.pop(cmd.stream_name, [])
|
rows = self._pending_batches.pop(cmd.stream_name, [])
|
||||||
if rows:
|
if rows:
|
||||||
await self.on_rdata(cmd.stream_name, rows[-1].token, rows)
|
# We need to make sure we filter out RDATA rows with a token less
|
||||||
|
# than what we've caught up to. This is slightly fiddly because of
|
||||||
|
# "batched" rows which have a `None` token, indicating that they
|
||||||
|
# have the same token as the next row with a non-None token.
|
||||||
|
#
|
||||||
|
# We do this by walking the list backwards, first removing any RDATA
|
||||||
|
# rows that are part of an uncompeted batch, then taking rows while
|
||||||
|
# their token is either None or greater than where we've caught up
|
||||||
|
# to.
|
||||||
|
uncompleted_batch = []
|
||||||
|
unfinished_batch = True
|
||||||
|
filtered_rows = []
|
||||||
|
for row in reversed(rows):
|
||||||
|
if row.token is not None:
|
||||||
|
unfinished_batch = False
|
||||||
|
if cmd.token < row.token:
|
||||||
|
filtered_rows.append(row)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
elif unfinished_batch:
|
||||||
|
uncompleted_batch.append(row)
|
||||||
|
else:
|
||||||
|
filtered_rows.append(row)
|
||||||
|
|
||||||
|
filtered_rows.reverse()
|
||||||
|
uncompleted_batch.reverse()
|
||||||
|
if uncompleted_batch:
|
||||||
|
self._pending_batches[cmd.stream_name] = uncompleted_batch
|
||||||
|
|
||||||
|
await self.on_rdata(cmd.stream_name, rows[-1].token, filtered_rows)
|
||||||
|
|
||||||
async def on_SYNC(self, cmd: SyncCommand):
|
async def on_SYNC(self, cmd: SyncCommand):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
||||||
""""Called when get a new REMOTE_SERVER_UP command."""
|
""""Called when get a new REMOTE_SERVER_UP command."""
|
||||||
|
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||||
|
|
||||||
def get_currently_syncing_users(self):
|
def get_currently_syncing_users(self):
|
||||||
"""Get the list of currently syncing users (if any). This is called
|
"""Get the list of currently syncing users (if any). This is called
|
||||||
when a connection has been established and we need to send the
|
when a connection has been established and we need to send the
|
||||||
currently syncing users. (Overriden by the synchrotron's only)
|
currently syncing users.
|
||||||
"""
|
"""
|
||||||
return self.presence_handler.get_currently_syncing_users()
|
return self._presence_handler.get_currently_syncing_users()
|
||||||
|
|
||||||
def update_connection(self, connection):
|
def update_connection(self, connection):
|
||||||
"""Called when a connection has been established (or lost with None).
|
"""Called when a connection has been established (or lost with None).
|
||||||
"""
|
"""
|
||||||
self.connection = connection
|
self._connection = connection
|
||||||
|
|
||||||
def finished_connecting(self):
|
def finished_connecting(self):
|
||||||
"""Called when we have successfully subscribed and caught up to all
|
"""Called when we have successfully subscribed and caught up to all
|
||||||
|
@ -189,15 +220,15 @@ class ReplicationCommandHandler:
|
||||||
# We don't reset the delay any earlier as otherwise if there is a
|
# We don't reset the delay any earlier as otherwise if there is a
|
||||||
# problem during start up we'll end up tight looping connecting to the
|
# problem during start up we'll end up tight looping connecting to the
|
||||||
# server.
|
# server.
|
||||||
if self.factory:
|
if self._factory:
|
||||||
self.factory.resetDelay()
|
self._factory.resetDelay()
|
||||||
|
|
||||||
def send_command(self, cmd: Command):
|
def send_command(self, cmd: Command):
|
||||||
"""Send a command to master (when we get establish a connection if we
|
"""Send a command to master (when we get establish a connection if we
|
||||||
don't have one already.)
|
don't have one already.)
|
||||||
"""
|
"""
|
||||||
if self.connection:
|
if self._connection:
|
||||||
self.connection.send_command(cmd)
|
self._connection.send_command(cmd)
|
||||||
else:
|
else:
|
||||||
logger.warning("Dropping command as not connected: %r", cmd.NAME)
|
logger.warning("Dropping command as not connected: %r", cmd.NAME)
|
||||||
|
|
||||||
|
|
|
@ -46,12 +46,11 @@ indicate which side is sending, these are *not* included on the wire::
|
||||||
> ERROR server stopping
|
> ERROR server stopping
|
||||||
* connection closed by server *
|
* connection closed by server *
|
||||||
"""
|
"""
|
||||||
import abc
|
|
||||||
import fcntl
|
import fcntl
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, DefaultDict, Dict, List, Set
|
from typing import TYPE_CHECKING, DefaultDict, List
|
||||||
|
|
||||||
from six import iteritems
|
from six import iteritems
|
||||||
|
|
||||||
|
@ -78,13 +77,12 @@ from synapse.replication.tcp.commands import (
|
||||||
SyncCommand,
|
SyncCommand,
|
||||||
UserSyncCommand,
|
UserSyncCommand,
|
||||||
)
|
)
|
||||||
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
|
|
||||||
from synapse.types import Collection
|
from synapse.types import Collection
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
MYPY = False
|
if TYPE_CHECKING:
|
||||||
if MYPY:
|
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
|
||||||
|
@ -475,71 +473,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||||
self.streamer.lost_connection(self)
|
self.streamer.lost_connection(self)
|
||||||
|
|
||||||
|
|
||||||
class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
|
|
||||||
"""
|
|
||||||
The interface for the handler that should be passed to
|
|
||||||
ClientReplicationStreamProtocol
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def on_rdata(self, stream_name, token, rows):
|
|
||||||
"""Called to handle a batch of replication data with a given stream token.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_name (str): name of the replication stream for this batch of rows
|
|
||||||
token (int): stream token for this batch of rows
|
|
||||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
|
||||||
Stream.parse_row.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def on_position(self, stream_name, token):
|
|
||||||
"""Called when we get new position data."""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def on_sync(self, data):
|
|
||||||
"""Called when get a new SYNC command."""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def on_remote_server_up(self, server: str):
|
|
||||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_streams_to_replicate(self):
|
|
||||||
"""Called when a new connection has been established and we need to
|
|
||||||
subscribe to streams.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
map from stream name to the most recent update we have for
|
|
||||||
that stream (ie, the point we want to start replicating from)
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_currently_syncing_users(self):
|
|
||||||
"""Get the list of currently syncing users (if any). This is called
|
|
||||||
when a connection has been established and we need to send the
|
|
||||||
currently syncing users."""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def update_connection(self, connection):
|
|
||||||
"""Called when a connection has been established (or lost with None).
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def finished_connecting(self):
|
|
||||||
"""Called when we have successfully subscribed and caught up to all
|
|
||||||
streams we're interested in.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||||
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
||||||
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
||||||
|
@ -550,7 +483,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||||
client_name: str,
|
client_name: str,
|
||||||
server_name: str,
|
server_name: str,
|
||||||
clock: Clock,
|
clock: Clock,
|
||||||
command_handler,
|
command_handler: "ReplicationCommandHandler",
|
||||||
):
|
):
|
||||||
BaseReplicationStreamProtocol.__init__(self, clock)
|
BaseReplicationStreamProtocol.__init__(self, clock)
|
||||||
|
|
||||||
|
@ -560,17 +493,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||||
self.server_name = server_name
|
self.server_name = server_name
|
||||||
self.handler = command_handler
|
self.handler = command_handler
|
||||||
|
|
||||||
self.streams = {
|
|
||||||
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
|
||||||
} # type: Dict[str, Stream]
|
|
||||||
|
|
||||||
# Set of streams that we're currently catching up with.
|
|
||||||
self.streams_connecting = set() # type: Set[str]
|
|
||||||
|
|
||||||
# Map of stream to batched updates. See RdataCommand for info on how
|
|
||||||
# batching works.
|
|
||||||
self.pending_batches = {} # type: Dict[str, List[Any]]
|
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self):
|
||||||
self.send_command(NameCommand(self.client_name))
|
self.send_command(NameCommand(self.client_name))
|
||||||
BaseReplicationStreamProtocol.connectionMade(self)
|
BaseReplicationStreamProtocol.connectionMade(self)
|
||||||
|
@ -592,7 +514,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||||
async def handle_command(self, cmd: Command):
|
async def handle_command(self, cmd: Command):
|
||||||
"""Handle a command we have received over the replication stream.
|
"""Handle a command we have received over the replication stream.
|
||||||
|
|
||||||
By default delegates to on_<COMMAND>, which should return an awaitable.
|
Delegates to `command_handler.on_<COMMAND>`, which must return an
|
||||||
|
awaitable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cmd: received command
|
cmd: received command
|
||||||
|
|
|
@ -57,8 +57,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
|
||||||
# We now do some gut wrenching so that we have a client that is based
|
# We now do some gut wrenching so that we have a client that is based
|
||||||
# off of the slave store rather than the main store.
|
# off of the slave store rather than the main store.
|
||||||
self.replication_handler = ReplicationCommandHandler(self.hs)
|
self.replication_handler = ReplicationCommandHandler(self.hs)
|
||||||
self.replication_handler.store = self.slaved_store
|
self.replication_handler._replication_data_handler = ReplicationDataHandler(
|
||||||
self.replication_handler.replication_data_handler = ReplicationDataHandler(
|
|
||||||
self.slaved_store
|
self.slaved_store
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
"""Base class for tests of the replication streams"""
|
"""Base class for tests of the replication streams"""
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
self.test_handler = Mock(wraps=TestReplicationClientHandler())
|
self.test_handler = Mock(wraps=TestReplicationDataHandler())
|
||||||
return self.setup_test_homeserver(replication_data_handler=self.test_handler)
|
return self.setup_test_homeserver(replication_data_handler=self.test_handler)
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
|
@ -75,7 +75,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
self.pump(0.1)
|
self.pump(0.1)
|
||||||
|
|
||||||
|
|
||||||
class TestReplicationClientHandler:
|
class TestReplicationDataHandler:
|
||||||
|
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.streams = set()
|
self.streams = set()
|
||||||
self._received_rdata_rows = []
|
self._received_rdata_rows = []
|
||||||
|
|
Loading…
Reference in New Issue