diff --git a/changelog.d/9638.misc b/changelog.d/9638.misc new file mode 100644 index 0000000000..35338cd332 --- /dev/null +++ b/changelog.d/9638.misc @@ -0,0 +1 @@ +Add additional type hints to the Homeserver object. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index f45e7a8c89..7e8e64d61c 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -33,7 +33,7 @@ import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates if TYPE_CHECKING: - import synapse.server + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -299,20 +299,23 @@ class TypingStream(Stream): NAME = "typing" ROW_TYPE = TypingStreamRow - def __init__(self, hs): - typing_handler = hs.get_typing_handler() - + def __init__(self, hs: "HomeServer"): writer_instance = hs.config.worker.writers.typing if writer_instance == hs.get_instance_name(): # On the writer, query the typing handler - update_function = typing_handler.get_all_typing_updates + typing_writer_handler = hs.get_typing_writer_handler() + update_function = ( + typing_writer_handler.get_all_typing_updates + ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] + current_token_function = typing_writer_handler.get_current_token else: # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) + current_token_function = hs.get_typing_handler().get_current_token super().__init__( hs.get_instance_name(), - current_token_without_instance(typing_handler.get_current_token), + current_token_without_instance(current_token_function), update_function, ) @@ -509,7 +512,7 @@ class AccountDataStream(Stream): NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 5884daea6d..e7a8207eb1 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -49,7 +49,7 @@ from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: - import synapse.server + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -846,10 +846,10 @@ class RoomTypingRestServlet(RestServlet): "/rooms/(?P[^/]*)/typing/(?P[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() + self.hs = hs self.presence_handler = hs.get_presence_handler() - self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() # If we're not on the typing writer instance we should scream if we get @@ -874,16 +874,19 @@ class RoomTypingRestServlet(RestServlet): # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) + # Defer getting the typing handler since it will raise on workers. + typing_handler = self.hs.get_typing_writer_handler() + try: if content["typing"]: - await self.typing_handler.started_typing( + await typing_handler.started_typing( target_user=target_user, requester=requester, room_id=room_id, timeout=timeout, ) else: - await self.typing_handler.stopped_typing( + await typing_handler.stopped_typing( target_user=target_user, requester=requester, room_id=room_id ) except ShadowBanError: @@ -901,7 +904,7 @@ class RoomAliasListServlet(RestServlet): ), ] - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.directory_handler = hs.get_directory_handler() diff --git a/synapse/server.py b/synapse/server.py index dd4ee7dd3c..d11d08c573 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -417,9 +417,18 @@ class HomeServer(metaclass=abc.ABCMeta): return PresenceHandler(self) @cache_in_self - def get_typing_handler(self): + def get_typing_writer_handler(self) -> TypingWriterHandler: if self.config.worker.writers.typing == self.get_instance_name(): return TypingWriterHandler(self) + else: + raise Exception("Workers cannot write typing") + + @cache_in_self + def get_typing_handler(self) -> FollowerTypingHandler: + if self.config.worker.writers.typing == self.get_instance_name(): + # Use get_typing_writer_handler to ensure that we use the same + # cached version. + return self.get_typing_writer_handler() else: return FollowerTypingHandler(self)