Thread through instance name to replication client. (#7369)
For in-memory streams, when fetching updates on workers we need to query the source of the stream, which is currently hard-coded to be master. This PR threads the source instance we received via `POSITION` through to the update function in each stream, which can then pass it on to the replication client for in-memory streams.
pull/7394/head
parent
3085cde577
commit
0e719f2398
|
@ -0,0 +1 @@
|
|||
Thread through instance name to replication client.
|
|
@ -646,13 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
|
|||
else:
|
||||
self.send_handler = None
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
await super(GenericWorkerReplicationHandler, self).on_rdata(
|
||||
stream_name, token, rows
|
||||
)
|
||||
await self.process_and_notify(stream_name, token, rows)
|
||||
async def on_rdata(self, stream_name, instance_name, token, rows):
|
||||
await super().on_rdata(stream_name, instance_name, token, rows)
|
||||
await self._process_and_notify(stream_name, instance_name, token, rows)
|
||||
|
||||
async def process_and_notify(self, stream_name, token, rows):
|
||||
async def _process_and_notify(self, stream_name, instance_name, token, rows):
|
||||
try:
|
||||
if self.send_handler:
|
||||
await self.send_handler.process_replication_rows(
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import abc
|
||||
import logging
|
||||
import re
|
||||
from inspect import signature
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from six import raise_from
|
||||
|
@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
|
|||
must call `register` to register the path with the HTTP server.
|
||||
|
||||
Requests can be sent by calling the client returned by `make_client`.
|
||||
Requests are sent to master process by default, but can be sent to other
|
||||
named processes by specifying an `instance_name` keyword argument.
|
||||
|
||||
Attributes:
|
||||
NAME (str): A name for the endpoint, added to the path as well as used
|
||||
|
@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
|
|||
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
|
||||
)
|
||||
|
||||
# We reserve `instance_name` as a parameter to sending requests, so we
|
||||
# assert here that sub classes don't try and use the name.
|
||||
assert (
|
||||
"instance_name" not in self.PATH_ARGS
|
||||
), "`instance_name` is a reserved paramater name"
|
||||
assert (
|
||||
"instance_name"
|
||||
not in signature(self.__class__._serialize_payload).parameters
|
||||
), "`instance_name` is a reserved paramater name"
|
||||
|
||||
assert self.METHOD in ("PUT", "POST", "GET")
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -135,7 +148,11 @@ class ReplicationEndpoint(object):
|
|||
|
||||
@trace(opname="outgoing_replication_request")
|
||||
@defer.inlineCallbacks
|
||||
def send_request(**kwargs):
|
||||
def send_request(instance_name="master", **kwargs):
|
||||
# Currently we only support sending requests to master process.
|
||||
if instance_name != "master":
|
||||
raise Exception("Unknown instance")
|
||||
|
||||
data = yield cls._serialize_payload(**kwargs)
|
||||
|
||||
url_args = [
|
||||
|
|
|
@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
|||
def __init__(self, hs):
|
||||
super().__init__(hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
# We pull the streams from the replication streamer (if we try and make
|
||||
# them ourselves we end up in an import loop).
|
||||
self.streams = hs.get_replication_streamer().get_streams()
|
||||
|
@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
|||
upto_token = parse_integer(request, "upto_token", required=True)
|
||||
|
||||
updates, upto_token, limited = await stream.get_updates_since(
|
||||
from_token, upto_token
|
||||
self._instance_name, from_token, upto_token
|
||||
)
|
||||
|
||||
return (
|
||||
|
|
|
@ -86,17 +86,19 @@ class ReplicationDataHandler:
|
|||
def __init__(self, store: BaseSlavedStore):
|
||||
self.store = store
|
||||
|
||||
async def on_rdata(self, stream_name: str, token: int, rows: list):
|
||||
async def on_rdata(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||
handle more.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
stream_name: name of the replication stream for this batch of rows
|
||||
instance_name: the instance that wrote the rows.
|
||||
token: stream token for this batch of rows
|
||||
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
|
||||
"""
|
||||
self.store.process_replication_rows(stream_name, token, rows)
|
||||
|
||||
|
|
|
@ -278,19 +278,24 @@ class ReplicationCommandHandler:
|
|||
# Check if this is the last of a batch of updates
|
||||
rows = self._pending_batches.pop(stream_name, [])
|
||||
rows.append(row)
|
||||
await self.on_rdata(stream_name, cmd.token, rows)
|
||||
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
|
||||
|
||||
async def on_rdata(self, stream_name: str, token: int, rows: list):
|
||||
async def on_rdata(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Args:
|
||||
stream_name: name of the replication stream for this batch of rows
|
||||
instance_name: the instance that wrote the rows.
|
||||
token: stream token for this batch of rows
|
||||
rows: a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
await self._replication_data_handler.on_rdata(stream_name, token, rows)
|
||||
await self._replication_data_handler.on_rdata(
|
||||
stream_name, instance_name, token, rows
|
||||
)
|
||||
|
||||
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
|
@ -325,7 +330,9 @@ class ReplicationCommandHandler:
|
|||
updates,
|
||||
current_token,
|
||||
missing_updates,
|
||||
) = await stream.get_updates_since(current_token, cmd.token)
|
||||
) = await stream.get_updates_since(
|
||||
cmd.instance_name, current_token, cmd.token
|
||||
)
|
||||
|
||||
# TODO: add some tests for this
|
||||
|
||||
|
@ -334,7 +341,10 @@ class ReplicationCommandHandler:
|
|||
|
||||
for token, rows in _batch_updates(updates):
|
||||
await self.on_rdata(
|
||||
cmd.stream_name, token, [stream.parse_row(row) for row in rows],
|
||||
cmd.stream_name,
|
||||
cmd.instance_name,
|
||||
token,
|
||||
[stream.parse_row(row) for row in rows],
|
||||
)
|
||||
|
||||
# We've now caught up to position sent to us, notify handler.
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
|
|||
#
|
||||
# The arguments are:
|
||||
#
|
||||
# * instance_name: the writer of the stream
|
||||
# * from_token: the previous stream token: the starting point for fetching the
|
||||
# updates
|
||||
# * to_token: the new stream token: the point to get updates up to
|
||||
|
@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
|
|||
# If there are more updates available, it should set `limited` in the result, and
|
||||
# it will be called again to get the next batch.
|
||||
#
|
||||
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
|
||||
UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
|
||||
|
||||
|
||||
class Stream(object):
|
||||
|
@ -93,6 +94,7 @@ class Stream(object):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
local_instance_name: str,
|
||||
current_token_function: Callable[[], Token],
|
||||
update_function: UpdateFunction,
|
||||
):
|
||||
|
@ -108,9 +110,11 @@ class Stream(object):
|
|||
stream tokens. See the UpdateFunction type definition for more info.
|
||||
|
||||
Args:
|
||||
local_instance_name: The instance name of the current process
|
||||
current_token_function: callback to get the current token, as above
|
||||
update_function: callback to get stream updates, as above
|
||||
"""
|
||||
self.local_instance_name = local_instance_name
|
||||
self.current_token = current_token_function
|
||||
self.update_function = update_function
|
||||
|
||||
|
@ -135,14 +139,14 @@ class Stream(object):
|
|||
"""
|
||||
current_token = self.current_token()
|
||||
updates, current_token, limited = await self.get_updates_since(
|
||||
self.last_token, current_token
|
||||
self.local_instance_name, self.last_token, current_token
|
||||
)
|
||||
self.last_token = current_token
|
||||
|
||||
return updates, current_token, limited
|
||||
|
||||
async def get_updates_since(
|
||||
self, from_token: Token, upto_token: Token
|
||||
self, instance_name: str, from_token: Token, upto_token: Token
|
||||
) -> StreamUpdateResult:
|
||||
"""Like get_updates except allows specifying from when we should
|
||||
stream updates
|
||||
|
@ -160,19 +164,19 @@ class Stream(object):
|
|||
return [], upto_token, False
|
||||
|
||||
updates, upto_token, limited = await self.update_function(
|
||||
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
|
||||
instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
|
||||
)
|
||||
return updates, upto_token, limited
|
||||
|
||||
|
||||
def db_query_to_update_function(
|
||||
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
|
||||
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
|
||||
) -> UpdateFunction:
|
||||
"""Wraps a db query function which returns a list of rows to make it
|
||||
suitable for use as an `update_function` for the Stream class
|
||||
"""
|
||||
|
||||
async def update_function(from_token, upto_token, limit):
|
||||
async def update_function(instance_name, from_token, upto_token, limit):
|
||||
rows = await query_function(from_token, upto_token, limit)
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
limited = False
|
||||
|
@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
|
|||
client = ReplicationGetStreamUpdates.make_client(hs)
|
||||
|
||||
async def update_function(
|
||||
from_token: int, upto_token: int, limit: int
|
||||
instance_name: str, from_token: int, upto_token: int, limit: int
|
||||
) -> StreamUpdateResult:
|
||||
result = await client(
|
||||
stream_name=stream_name, from_token=from_token, upto_token=upto_token,
|
||||
instance_name=instance_name,
|
||||
stream_name=stream_name,
|
||||
from_token=from_token,
|
||||
upto_token=upto_token,
|
||||
)
|
||||
return result["updates"], result["upto_token"], result["limited"]
|
||||
|
||||
|
@ -226,6 +233,7 @@ class BackfillStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_current_backfill_token,
|
||||
db_query_to_update_function(store.get_all_new_backfill_event_rows),
|
||||
)
|
||||
|
@ -261,7 +269,9 @@ class PresenceStream(Stream):
|
|||
# Query master process
|
||||
update_function = make_http_update_function(hs, self.NAME)
|
||||
|
||||
super().__init__(store.get_current_presence_token, update_function)
|
||||
super().__init__(
|
||||
hs.get_instance_name(), store.get_current_presence_token, update_function
|
||||
)
|
||||
|
||||
|
||||
class TypingStream(Stream):
|
||||
|
@ -284,7 +294,9 @@ class TypingStream(Stream):
|
|||
# Query master process
|
||||
update_function = make_http_update_function(hs, self.NAME)
|
||||
|
||||
super().__init__(typing_handler.get_current_token, update_function)
|
||||
super().__init__(
|
||||
hs.get_instance_name(), typing_handler.get_current_token, update_function
|
||||
)
|
||||
|
||||
|
||||
class ReceiptsStream(Stream):
|
||||
|
@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_max_receipt_stream_id,
|
||||
db_query_to_update_function(store.get_all_updated_receipts),
|
||||
)
|
||||
|
@ -322,14 +335,16 @@ class PushRulesStream(Stream):
|
|||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
super(PushRulesStream, self).__init__(
|
||||
self._current_token, self._update_function
|
||||
hs.get_instance_name(), self._current_token, self._update_function
|
||||
)
|
||||
|
||||
def _current_token(self) -> int:
|
||||
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
||||
return push_rules_token
|
||||
|
||||
async def _update_function(self, from_token: Token, to_token: Token, limit: int):
|
||||
async def _update_function(
|
||||
self, instance_name: str, from_token: Token, to_token: Token, limit: int
|
||||
):
|
||||
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
|
||||
|
||||
limited = False
|
||||
|
@ -356,6 +371,7 @@ class PushersStream(Stream):
|
|||
store = hs.get_datastore()
|
||||
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_pushers_stream_token,
|
||||
db_query_to_update_function(store.get_all_updated_pushers_rows),
|
||||
)
|
||||
|
@ -387,6 +403,7 @@ class CachesStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_cache_stream_token,
|
||||
db_query_to_update_function(store.get_all_updated_caches),
|
||||
)
|
||||
|
@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_current_public_room_stream_id,
|
||||
db_query_to_update_function(store.get_all_new_public_rooms),
|
||||
)
|
||||
|
@ -432,6 +450,7 @@ class DeviceListsStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_device_stream_token,
|
||||
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
|
||||
)
|
||||
|
@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_to_device_stream_token,
|
||||
db_query_to_update_function(store.get_all_new_device_messages),
|
||||
)
|
||||
|
@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_max_account_data_stream_id,
|
||||
db_query_to_update_function(store.get_all_updated_tags),
|
||||
)
|
||||
|
@ -487,6 +508,7 @@ class AccountDataStream(Stream):
|
|||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
self.store.get_max_account_data_stream_id,
|
||||
db_query_to_update_function(self._update_function),
|
||||
)
|
||||
|
@ -517,6 +539,7 @@ class GroupServerStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_group_stream_token,
|
||||
db_query_to_update_function(store.get_all_groups_changes),
|
||||
)
|
||||
|
@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_device_stream_token,
|
||||
db_query_to_update_function(
|
||||
store.get_all_user_signature_changes_for_remotes
|
||||
|
|
|
@ -118,11 +118,17 @@ class EventsStream(Stream):
|
|||
def __init__(self, hs):
|
||||
self._store = hs.get_datastore()
|
||||
super().__init__(
|
||||
self._store.get_current_events_token, self._update_function,
|
||||
hs.get_instance_name(),
|
||||
self._store.get_current_events_token,
|
||||
self._update_function,
|
||||
)
|
||||
|
||||
async def _update_function(
|
||||
self, from_token: Token, current_token: Token, target_row_count: int
|
||||
self,
|
||||
instance_name: str,
|
||||
from_token: Token,
|
||||
current_token: Token,
|
||||
target_row_count: int,
|
||||
) -> StreamUpdateResult:
|
||||
|
||||
# the events stream merges together three separate sources:
|
||||
|
|
|
@ -48,8 +48,8 @@ class FederationStream(Stream):
|
|||
current_token = lambda: 0
|
||||
update_function = self._stub_update_function
|
||||
|
||||
super().__init__(current_token, update_function)
|
||||
super().__init__(hs.get_instance_name(), current_token, update_function)
|
||||
|
||||
@staticmethod
|
||||
async def _stub_update_function(from_token, upto_token, limit):
|
||||
async def _stub_update_function(instance_name, from_token, upto_token, limit):
|
||||
return [], upto_token, False
|
||||
|
|
|
@ -183,8 +183,8 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
|
|||
# list of received (stream_name, token, row) tuples
|
||||
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
await super().on_rdata(stream_name, token, rows)
|
||||
async def on_rdata(self, stream_name, instance_name, token, rows):
|
||||
await super().on_rdata(stream_name, instance_name, token, rows)
|
||||
for r in rows:
|
||||
self.received_rdata_rows.append((stream_name, token, r))
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
|
|||
|
||||
# there should be one RDATA command
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "receipts")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
|
||||
|
@ -71,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
|
|||
|
||||
# We should now have caught up and get the missing data
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "receipts")
|
||||
self.assertEqual(token, 3)
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
|
|
|
@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
||||
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "typing")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
|
||||
|
@ -74,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||
self.assertEqual(int(request.args[b"from_token"][0]), token)
|
||||
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "typing")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row = rdata_rows[0]
|
||||
|
|
Loading…
Reference in New Issue