Improve type checking in `replication.tcp.Stream` (#7291)

The general idea here is to get rid of the type: ignore annotations on all of the current_token and update_function assignments, which would have caught #7290.

After a bit of experimentation, it seems like the least-awful way to do this is to pass the offending functions in as parameters to the Stream constructor. Unfortunately that means that the concrete implementations no longer have the same constructor signature as Stream itself, which means that it gets hard to correctly annotate STREAMS_MAP.

I've also introduced a couple of new types, to take out some duplication.
pull/7303/head
Richard van der Hoff 2020-04-17 14:49:55 +01:00 committed by GitHub
parent c07fca9e2f
commit 67ff7b8ba0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 143 additions and 122 deletions

1
changelog.d/7291.misc Normal file
View File

@ -0,0 +1 @@
Improve typing annotations in `synapse.replication.tcp.streams.Stream`.

View File

@ -25,8 +25,6 @@ Each stream is defined by the following information:
update_function: The function that returns a list of updates between two tokens
"""
from typing import Dict, Type
from synapse.replication.tcp.streams._base import (
AccountDataStream,
BackfillStream,
@ -67,8 +65,7 @@ STREAMS_MAP = {
GroupServerStream,
UserSignatureStream,
)
} # type: Dict[str, Type[Stream]]
}
__all__ = [
"STREAMS_MAP",

View File

@ -16,12 +16,11 @@
import logging
from collections import namedtuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@ -34,8 +33,32 @@ MAX_EVENTS_BEHIND = 500000
# A stream position token
Token = int
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
StreamRow = Tuple[Token, tuple]
# The type of a stream update row, after JSON deserialisation, but before
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
# just a row from a database query, though this is dependent on the stream in question.
#
StreamRow = Tuple
# The type returned by the update_function of a stream, as well as get_updates(),
# get_updates_since, etc.
#
# It consists of a triplet `(updates, new_last_token, limited)`, where:
# * `updates` is a list of `(token, row)` entries.
# * `new_last_token` is the new position in stream.
# * `limited` is whether there are more updates to fetch.
#
StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
# The type of an update_function for a stream
#
# The arguments are:
#
# * from_token: the previous stream token: the starting point for fetching the
# updates
# * to_token: the new stream token: the point to get updates up to
# * limit: the maximum number of rows to return
#
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
class Stream(object):
@ -50,7 +73,7 @@ class Stream(object):
ROW_TYPE = None # type: Any
@classmethod
def parse_row(cls, row):
def parse_row(cls, row: StreamRow):
"""Parse a row received over replication
By default, assumes that the row data is an array object and passes its contents
@ -64,7 +87,28 @@ class Stream(object):
"""
return cls.ROW_TYPE(*row)
def __init__(self, hs):
def __init__(
self,
current_token_function: Callable[[], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
current_token_function and update_function are callbacks which should be
implemented by subclasses.
current_token_function is called to get the current token of the underlying
stream.
update_function is called to get updates for this stream between a pair of
stream tokens. See the UpdateFunction type definition for more info.
Args:
current_token_function: callback to get the current token, as above
update_function: callback go get stream updates, as above
"""
self.current_token = current_token_function
self.update_function = update_function
# The token from which we last asked for updates
self.last_token = self.current_token()
@ -75,7 +119,7 @@ class Stream(object):
"""
self.last_token = self.current_token()
async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before).
@ -95,7 +139,7 @@ class Stream(object):
async def get_updates_since(
self, from_token: Token, upto_token: Token, limit: int = 100
) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should
stream updates
@ -112,33 +156,14 @@ class Stream(object):
return [], upto_token, False
updates, upto_token, limited = await self.update_function(
from_token, upto_token, limit=limit,
from_token, upto_token, limit,
)
return updates, upto_token, limited
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
by the sub classes
Returns:
int
"""
raise NotImplementedError()
def update_function(self, from_token, current_token, limit):
"""Get updates between from_token and to_token.
Returns:
Deferred(list(tuple)): the first entry in the tuple is the token for
that update, and the rest of the tuple gets used to construct
a ``ROW_TYPE`` instance
"""
raise NotImplementedError()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
@ -157,9 +182,7 @@ def db_query_to_update_function(
return update_function
def make_http_update_function(
hs, stream_name: str
) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
"""
@ -168,7 +191,7 @@ def make_http_update_function(
async def update_function(
from_token: int, upto_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
) -> StreamUpdateResult:
result = await client(
stream_name=stream_name,
from_token=from_token,
@ -202,10 +225,10 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs)
super().__init__(
store.get_current_backfill_token,
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
class PresenceStream(Stream):
@ -227,19 +250,18 @@ class PresenceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
self._is_worker = hs.config.worker_app is not None
self.current_token = store.get_current_presence_token # type: ignore
if hs.config.worker_app is None:
self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
# on the master, query the presence handler
presence_handler = hs.get_presence_handler()
update_function = db_query_to_update_function(
presence_handler.get_all_presence_updates
)
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
update_function = make_http_update_function(hs, self.NAME)
super(PresenceStream, self).__init__(hs)
super().__init__(store.get_current_presence_token, update_function)
class TypingStream(Stream):
@ -253,15 +275,16 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token # type: ignore
if hs.config.worker_app is None:
self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
# on the master, query the typing handler
update_function = db_query_to_update_function(
typing_handler.get_all_typing_updates
)
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
update_function = make_http_update_function(hs, self.NAME)
super(TypingStream, self).__init__(hs)
super().__init__(typing_handler.get_current_token, update_function)
class ReceiptsStream(Stream):
@ -281,11 +304,10 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs)
super().__init__(
store.get_max_receipt_stream_id,
db_query_to_update_function(store.get_all_updated_receipts),
)
class PushRulesStream(Stream):
@ -299,13 +321,15 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super(PushRulesStream, self).__init__(hs)
super(PushRulesStream, self).__init__(
self._current_token, self._update_function
)
def current_token(self):
def _current_token(self) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
async def update_function(self, from_token, to_token, limit):
async def _update_function(self, from_token: Token, to_token: Token, limit: int):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
limited = False
@ -331,10 +355,10 @@ class PushersStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs)
super().__init__(
store.get_pushers_stream_token,
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
class CachesStream(Stream):
@ -362,11 +386,10 @@ class CachesStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs)
super().__init__(
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
)
class PublicRoomsStream(Stream):
@ -388,11 +411,10 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs)
super().__init__(
store.get_current_public_room_stream_id,
db_query_to_update_function(store.get_all_new_public_rooms),
)
class DeviceListsStream(Stream):
@ -409,11 +431,10 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs)
super().__init__(
store.get_device_stream_token,
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
class ToDeviceStream(Stream):
@ -427,11 +448,10 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs)
super().__init__(
store.get_to_device_stream_token,
db_query_to_update_function(store.get_all_new_device_messages),
)
class TagAccountDataStream(Stream):
@ -447,11 +467,10 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs)
super().__init__(
store.get_max_account_data_stream_id,
db_query_to_update_function(store.get_all_updated_tags),
)
class AccountDataStream(Stream):
@ -467,11 +486,10 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs)
super().__init__(
self.store.get_max_account_data_stream_id,
db_query_to_update_function(self._update_function),
)
async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
@ -498,11 +516,10 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs)
super().__init__(
store.get_group_stream_token,
db_query_to_update_function(store.get_all_groups_changes),
)
class UserSignatureStream(Stream):
@ -516,8 +533,9 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs)
super().__init__(
store.get_device_stream_token,
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),
)

View File

@ -15,11 +15,11 @@
# limitations under the License.
import heapq
from typing import Tuple, Type
from typing import Iterable, Tuple, Type
import attr
from ._base import Stream, db_query_to_update_function
from ._base import Stream, Token, db_query_to_update_function
"""Handling of the 'events' replication stream
@ -116,12 +116,14 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super().__init__(
self._store.get_current_events_token,
db_query_to_update_function(self._update_function),
)
super(EventsStream, self).__init__(hs)
async def _update_function(self, from_token, current_token, limit=None):
async def _update_function(
self, from_token: Token, current_token: Token, limit: int
) -> Iterable[tuple]:
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)

View File

@ -15,8 +15,6 @@
# limitations under the License.
from collections import namedtuple
from twisted.internet import defer
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
@ -35,7 +33,6 @@ class FederationStream(Stream):
NAME = "federation"
ROW_TYPE = FederationStreamRow
_QUERY_MASTER = True
def __init__(self, hs):
# Not all synapse instances will have a federation sender instance,
@ -43,10 +40,16 @@ class FederationStream(Stream):
# so we stub the stream out when that is the case.
if hs.config.worker_app is None or hs.should_send_federation():
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
current_token = federation_sender.get_current_token
update_function = db_query_to_update_function(
federation_sender.get_replication_rows
)
else:
self.current_token = lambda: 0 # type: ignore
self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore
current_token = lambda: 0
update_function = self._stub_update_function
super(FederationStream, self).__init__(hs)
super().__init__(current_token, update_function)
@staticmethod
async def _stub_update_function(from_token, upto_token, limit):
return [], upto_token, False