Move calling http replication out of base stream

pull/7024/head
Erik Johnston 2020-03-24 16:20:05 +00:00
parent e4c5b1d9d6
commit 309aee4636
3 changed files with 60 additions and 47 deletions

View File

@ -16,7 +16,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Any, List, Optional, Tuple, Union from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union
import attr import attr
@ -40,10 +40,6 @@ class Stream(object):
# The type of the row. Used by the default impl of parse_row. # The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any ROW_TYPE = None # type: Any
# Whether the update function is only available on master. If True then
# calls to get updates are proxied to the master via a HTTP call.
_QUERY_MASTER = False
@classmethod @classmethod
def parse_row(cls, row): def parse_row(cls, row):
"""Parse a row received over replication """Parse a row received over replication
@ -60,10 +56,6 @@ class Stream(object):
return cls.ROW_TYPE(*row) return cls.ROW_TYPE(*row)
def __init__(self, hs): def __init__(self, hs):
self._is_worker = hs.config.worker_app is not None
if self._QUERY_MASTER and self._is_worker:
self._replication_client = ReplicationGetStreamUpdates.make_client(hs)
# The token from which we last asked for updates # The token from which we last asked for updates
self.last_token = self.current_token() self.last_token = self.current_token()
@ -110,23 +102,10 @@ class Stream(object):
if from_token == upto_token: if from_token == upto_token:
return [], upto_token, False return [], upto_token, False
if self._is_worker and self._QUERY_MASTER: updates, upto_token, limited = await self.update_function(
result = await self._replication_client( from_token, upto_token, limit=limit,
stream_name=self.NAME, )
from_token=from_token, return updates, upto_token, limited
upto_token=upto_token,
limit=limit,
)
return result["updates"], result["upto_token"], result["limited"]
else:
limited = False
rows = await self.update_function(from_token, upto_token, limit=limit)
updates = [(row[0], row[1:]) for row in rows]
if len(updates) == limit:
upto_token = rows[-1][0]
limited = True
return updates, upto_token, limited
def current_token(self): def current_token(self):
"""Gets the current token of the underlying streams. Should be provided """Gets the current token of the underlying streams. Should be provided
@ -148,6 +127,26 @@ class Stream(object):
raise NotImplementedError() raise NotImplementedError()
def db_query_to_update_function(
query_function: Callable[[int, int, int], Awaitable[List[tuple]]]
) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) == limit:
upto_token = rows[-1][0]
limited = True
return updates, upto_token, limited
return update_function
class BackfillStream(Stream): class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before """We fetched some old events and either we had never seen that event before
or it went from being an outlier to not. or it went from being an outlier to not.
@ -171,7 +170,7 @@ class BackfillStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore self.current_token = store.get_current_backfill_token # type: ignore
self.update_function = store.get_all_new_backfill_event_rows # type: ignore self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs) super(BackfillStream, self).__init__(hs)
@ -192,16 +191,20 @@ class PresenceStream(Stream):
NAME = "presence" NAME = "presence"
ROW_TYPE = PresenceStreamRow ROW_TYPE = PresenceStreamRow
_QUERY_MASTER = True
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
presence_handler = hs.get_presence_handler() presence_handler = hs.get_presence_handler()
self._is_worker = hs.config.worker_app is not None
self.current_token = store.get_current_presence_token # type: ignore self.current_token = store.get_current_presence_token # type: ignore
if hs.config.worker_app is None: if hs.config.worker_app is None:
self.update_function = presence_handler.get_all_presence_updates # type: ignore self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
else:
# Query master process
self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore
super(PresenceStream, self).__init__(hs) super(PresenceStream, self).__init__(hs)
@ -213,7 +216,6 @@ class TypingStream(Stream):
NAME = "typing" NAME = "typing"
ROW_TYPE = TypingStreamRow ROW_TYPE = TypingStreamRow
_QUERY_MASTER = True
def __init__(self, hs): def __init__(self, hs):
typing_handler = hs.get_typing_handler() typing_handler = hs.get_typing_handler()
@ -221,7 +223,10 @@ class TypingStream(Stream):
self.current_token = typing_handler.get_current_token # type: ignore self.current_token = typing_handler.get_current_token # type: ignore
if hs.config.worker_app is None: if hs.config.worker_app is None:
self.update_function = typing_handler.get_all_typing_updates # type: ignore self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
else:
# Query master process
self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore
super(TypingStream, self).__init__(hs) super(TypingStream, self).__init__(hs)
@ -245,7 +250,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore self.current_token = store.get_max_receipt_stream_id # type: ignore
self.update_function = store.get_all_updated_receipts # type: ignore self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs) super(ReceiptsStream, self).__init__(hs)
@ -269,7 +274,13 @@ class PushRulesStream(Stream):
async def update_function(self, from_token, to_token, limit): async def update_function(self, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
return [(row[0], row[2]) for row in rows]
limited = False
if len(rows) == limit:
to_token = rows[-1][0]
limited = True
return [(row[0], row[2]) for row in rows], to_token, limited
class PushersStream(Stream): class PushersStream(Stream):
@ -288,7 +299,7 @@ class PushersStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore self.current_token = store.get_pushers_stream_token # type: ignore
self.update_function = store.get_all_updated_pushers_rows # type: ignore self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs) super(PushersStream, self).__init__(hs)
@ -320,7 +331,7 @@ class CachesStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore self.current_token = store.get_cache_stream_token # type: ignore
self.update_function = store.get_all_updated_caches # type: ignore self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs) super(CachesStream, self).__init__(hs)
@ -346,7 +357,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore self.current_token = store.get_current_public_room_stream_id # type: ignore
self.update_function = store.get_all_new_public_rooms # type: ignore self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs) super(PublicRoomsStream, self).__init__(hs)
@ -367,7 +378,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs) super(DeviceListsStream, self).__init__(hs)
@ -385,7 +396,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore self.current_token = store.get_to_device_stream_token # type: ignore
self.update_function = store.get_all_new_device_messages # type: ignore self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs) super(ToDeviceStream, self).__init__(hs)
@ -405,7 +416,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore self.current_token = store.get_max_account_data_stream_id # type: ignore
self.update_function = store.get_all_updated_tags # type: ignore self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs) super(TagAccountDataStream, self).__init__(hs)
@ -425,10 +436,11 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore self.current_token = self.store.get_max_account_data_stream_id # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs) super(AccountDataStream, self).__init__(hs)
async def update_function(self, from_token, to_token, limit): async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data( global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit from_token, from_token, to_token, limit
) )
@ -455,7 +467,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore self.current_token = store.get_group_stream_token # type: ignore
self.update_function = store.get_all_groups_changes # type: ignore self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs) super(GroupServerStream, self).__init__(hs)
@ -473,6 +485,6 @@ class UserSignatureStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs) super(UserSignatureStream, self).__init__(hs)

View File

@ -19,7 +19,7 @@ from typing import Tuple, Type
import attr import attr
from ._base import Stream from ._base import Stream, db_query_to_update_function
"""Handling of the 'events' replication stream """Handling of the 'events' replication stream
@ -117,10 +117,11 @@ class EventsStream(Stream):
def __init__(self, hs): def __init__(self, hs):
self._store = hs.get_datastore() self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token # type: ignore self.current_token = self._store.get_current_events_token # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(EventsStream, self).__init__(hs) super(EventsStream, self).__init__(hs)
async def update_function(self, from_token, current_token, limit=None): async def _update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows( event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit from_token, current_token, limit
) )

View File

@ -17,7 +17,7 @@ from collections import namedtuple
from twisted.internet import defer from twisted.internet import defer
from synapse.replication.tcp.streams._base import Stream from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
class FederationStream(Stream): class FederationStream(Stream):
@ -44,9 +44,9 @@ class FederationStream(Stream):
if hs.config.worker_app is None or hs.should_send_federation(): if hs.config.worker_app is None or hs.should_send_federation():
federation_sender = hs.get_federation_sender() federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token # type: ignore self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = federation_sender.get_replication_rows # type: ignore self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
else: else:
self.current_token = lambda: 0 # type: ignore self.current_token = lambda: 0 # type: ignore
self.update_function = lambda *args, **kwargs: defer.succeed([]) # type: ignore self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore
super(FederationStream, self).__init__(hs) super(FederationStream, self).__init__(hs)