Use `stream.current_token()` and remove `stream_positions()` (#7172)
We move the processing of typing and federation replication traffic into their handlers so that `Stream.current_token()` points to a valid token. This allows us to remove `get_streams_to_replicate()` and `stream_positions()`.pull/7404/head
parent
6b22921b19
commit
3085cde577
|
@ -0,0 +1 @@
|
|||
Use `stream.current_token()` and remove `stream_positions()`.
|
|
@ -413,12 +413,6 @@ class GenericWorkerTyping(object):
|
|||
# map room IDs to sets of users currently typing
|
||||
self._room_typing = {}
|
||||
|
||||
def stream_positions(self):
|
||||
# We must update this typing token from the response of the previous
|
||||
# sync. In particular, the stream id may "reset" back to zero/a low
|
||||
# value which we *must* use for the next replication request.
|
||||
return {"typing": self._latest_room_serial}
|
||||
|
||||
def process_replication_rows(self, token, rows):
|
||||
if self._latest_room_serial > token:
|
||||
# The master has gone backwards. To prevent inconsistent data, just
|
||||
|
@ -658,13 +652,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
|
|||
)
|
||||
await self.process_and_notify(stream_name, token, rows)
|
||||
|
||||
def get_streams_to_replicate(self):
|
||||
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
|
||||
args.update(self.typing_handler.stream_positions())
|
||||
if self.send_handler:
|
||||
args.update(self.send_handler.stream_positions())
|
||||
return args
|
||||
|
||||
async def process_and_notify(self, stream_name, token, rows):
|
||||
try:
|
||||
if self.send_handler:
|
||||
|
@ -799,9 +786,6 @@ class FederationSenderHandler(object):
|
|||
def wake_destination(self, server: str):
|
||||
self.federation_sender.wake_destination(server)
|
||||
|
||||
def stream_positions(self):
|
||||
return {"federation": self.federation_position}
|
||||
|
||||
async def process_replication_rows(self, stream_name, token, rows):
|
||||
# The federation stream contains things that we want to send out, e.g.
|
||||
# presence, typing, etc.
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import six
|
||||
|
||||
|
@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
|
|||
|
||||
self.hs = hs
|
||||
|
||||
def stream_positions(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the current positions of all the streams this store wants to subscribe to
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
pos = {}
|
||||
if self._cache_id_gen:
|
||||
pos["caches"] = self._cache_id_gen.get_current_token()
|
||||
return pos
|
||||
|
||||
def get_cache_stream_token(self):
|
||||
if self._cache_id_gen:
|
||||
return self._cache_id_gen.get_current_token()
|
||||
|
|
|
@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
|
|||
def get_max_account_data_stream_id(self):
|
||||
return self._account_data_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedAccountDataStore, self).stream_positions()
|
||||
position = self._account_data_id_gen.get_current_token()
|
||||
result["user_account_data"] = position
|
||||
result["room_account_data"] = position
|
||||
result["tag_account_data"] = position
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "tag_account_data":
|
||||
self._account_data_id_gen.advance(token)
|
||||
|
|
|
@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
|
|||
expiry_ms=30 * 60 * 1000,
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedDeviceInboxStore, self).stream_positions()
|
||||
result["to_device"] = self._device_inbox_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "to_device":
|
||||
self._device_inbox_id_gen.advance(token)
|
||||
|
|
|
@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
|||
"DeviceListFederationStreamChangeCache", device_list_max
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedDeviceStore, self).stream_positions()
|
||||
# The user signature stream uses the same stream ID generator as the
|
||||
# device list stream, so set them both to the device list ID
|
||||
# generator's current token.
|
||||
current_token = self._device_list_id_gen.get_current_token()
|
||||
result[DeviceListsStream.NAME] = current_token
|
||||
result[UserSignatureStream.NAME] = current_token
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._device_list_id_gen.advance(token)
|
||||
|
|
|
@ -93,12 +93,6 @@ class SlavedEventStore(
|
|||
def get_room_min_stream_ordering(self):
|
||||
return self._backfill_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedEventStore, self).stream_positions()
|
||||
result["events"] = self._stream_id_gen.get_current_token()
|
||||
result["backfill"] = -self._backfill_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "events":
|
||||
self._stream_id_gen.advance(token)
|
||||
|
|
|
@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
|
|||
def get_group_stream_token(self):
|
||||
return self._group_updates_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedGroupServerStore, self).stream_positions()
|
||||
result["groups"] = self._group_updates_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "groups":
|
||||
self._group_updates_id_gen.advance(token)
|
||||
|
|
|
@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
|
|||
def get_current_presence_token(self):
|
||||
return self._presence_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPresenceStore, self).stream_positions()
|
||||
|
||||
if self.hs.config.use_presence:
|
||||
position = self._presence_id_gen.get_current_token()
|
||||
result["presence"] = position
|
||||
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "presence":
|
||||
self._presence_id_gen.advance(token)
|
||||
|
|
|
@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
|||
def get_max_push_rules_stream_id(self):
|
||||
return self._push_rules_stream_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPushRuleStore, self).stream_positions()
|
||||
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "push_rules":
|
||||
self._push_rules_stream_id_gen.advance(token)
|
||||
|
|
|
@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
|||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPusherStore, self).stream_positions()
|
||||
result["pushers"] = self._pushers_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def get_pushers_stream_token(self):
|
||||
return self._pushers_id_gen.get_current_token()
|
||||
|
||||
|
|
|
@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
|||
def get_max_receipt_stream_id(self):
|
||||
return self._receipts_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedReceiptsStore, self).stream_positions()
|
||||
result["receipts"] = self._receipts_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
||||
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
||||
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
|
||||
|
|
|
@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
|
|||
def get_current_public_room_stream_id(self):
|
||||
return self._public_room_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(RoomStore, self).stream_positions()
|
||||
result["public_rooms"] = self._public_room_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "public_rooms":
|
||||
self._public_room_id_gen.advance(token)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
|
||||
|
@ -100,23 +100,6 @@ class ReplicationDataHandler:
|
|||
"""
|
||||
self.store.process_replication_rows(stream_name, token, rows)
|
||||
|
||||
def get_streams_to_replicate(self) -> Dict[str, int]:
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
args = self.store.stream_positions()
|
||||
user_account_data = args.pop("user_account_data", None)
|
||||
room_account_data = args.pop("room_account_data", None)
|
||||
if user_account_data:
|
||||
args["account_data"] = user_account_data
|
||||
elif room_account_data:
|
||||
args["account_data"] = room_account_data
|
||||
return args
|
||||
|
||||
async def on_position(self, stream_name: str, token: int):
|
||||
self.store.process_replication_rows(stream_name, token, [])
|
||||
|
||||
|
|
|
@ -314,15 +314,7 @@ class ReplicationCommandHandler:
|
|||
self._pending_batches.pop(cmd.stream_name, [])
|
||||
|
||||
# Find where we previously streamed up to.
|
||||
current_token = self._replication_data_handler.get_streams_to_replicate().get(
|
||||
cmd.stream_name
|
||||
)
|
||||
if current_token is None:
|
||||
logger.warning(
|
||||
"Got POSITION for stream we're not subscribed to: %s",
|
||||
cmd.stream_name,
|
||||
)
|
||||
return
|
||||
current_token = stream.current_token()
|
||||
|
||||
# If the position token matches our current token then we're up to
|
||||
# date and there's nothing to do. Otherwise, fetch all updates
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
|||
from twisted.internet.task import LoopingCall
|
||||
from twisted.web.http import HTTPChannel
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.app.generic_worker import (
|
||||
GenericWorkerReplicationHandler,
|
||||
GenericWorkerServer,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
self._server_transport = None
|
||||
|
||||
def _build_replication_data_handler(self):
|
||||
return TestReplicationDataHandler(self.worker_hs.get_datastore())
|
||||
return TestReplicationDataHandler(self.worker_hs)
|
||||
|
||||
def reconnect(self):
|
||||
if self._client_transport:
|
||||
|
@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(request.method, b"GET")
|
||||
|
||||
|
||||
class TestReplicationDataHandler(ReplicationDataHandler):
|
||||
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
|
||||
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
||||
|
||||
def __init__(self, store: BaseSlavedStore):
|
||||
super().__init__(store)
|
||||
|
||||
# streams to subscribe to: map from stream id to position
|
||||
self.stream_positions = {} # type: Dict[str, int]
|
||||
def __init__(self, hs: HomeServer):
|
||||
super().__init__(hs)
|
||||
|
||||
# list of received (stream_name, token, row) tuples
|
||||
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
|
||||
|
||||
def get_streams_to_replicate(self):
|
||||
return self.stream_positions
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
await super().on_rdata(stream_name, token, rows)
|
||||
for r in rows:
|
||||
self.received_rdata_rows.append((stream_name, token, r))
|
||||
|
||||
if (
|
||||
stream_name in self.stream_positions
|
||||
and token > self.stream_positions[stream_name]
|
||||
):
|
||||
self.stream_positions[stream_name] = token
|
||||
|
||||
|
||||
@attr.s()
|
||||
class OneShotRequestFactory:
|
||||
|
|
|
@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
self.user_tok = self.login("u1", "pass")
|
||||
|
||||
self.reconnect()
|
||||
self.test_handler.stream_positions["events"] = 0
|
||||
|
||||
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
||||
self.test_handler.received_rdata_rows.clear()
|
||||
|
@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# we should have received all the expected rows in the right order
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
# we should have received all the expected rows in the right order (as
|
||||
# well as various cache invalidation updates which we ignore)
|
||||
received_rows = [
|
||||
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||
]
|
||||
|
||||
for event in events:
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
|
@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# now we should have received all the expected rows in the right order.
|
||||
# we should have received all the expected rows in the right order (as
|
||||
# well as various cache invalidation updates which we ignore)
|
||||
#
|
||||
# we expect:
|
||||
#
|
||||
|
@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
# of the states that got reverted.
|
||||
# - two rows for state2
|
||||
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
received_rows = [
|
||||
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||
]
|
||||
|
||||
# first check the first two rows, which should be state1
|
||||
|
||||
|
@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# we should have received all the expected rows in the right order
|
||||
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
# we should have received all the expected rows in the right order (as
|
||||
# well as various cache invalidation updates which we ignore)
|
||||
received_rows = [
|
||||
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||
]
|
||||
self.assertGreaterEqual(len(received_rows), len(events))
|
||||
for i in range(NUM_USERS):
|
||||
# for each user, we expect the PL event row, followed by state rows for
|
||||
|
|
|
@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
|
|||
def test_receipt(self):
|
||||
self.reconnect()
|
||||
|
||||
# make the client subscribe to the receipts stream
|
||||
self.test_handler.stream_positions.update({"receipts": 0})
|
||||
|
||||
# tell the master to send a new receipt
|
||||
self.get_success(
|
||||
self.hs.get_datastore().insert_receipt(
|
||||
|
|
|
@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||
|
||||
self.reconnect()
|
||||
|
||||
# make the client subscribe to the typing stream
|
||||
self.test_handler.stream_positions.update({"typing": 0})
|
||||
|
||||
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
|
Loading…
Reference in New Issue