Add most of the missing type hints to `synapse.federation`. (#11483)
This skips a few methods which are difficult to type.pull/11497/head
							parent
							
								
									b50e39df57
								
							
						
					
					
						commit
						d2279f471b
					
				|  | @ -0,0 +1 @@ | |||
| Add missing type hints to `synapse.federation`. | ||||
							
								
								
									
										6
									
								
								mypy.ini
								
								
								
								
							
							
						
						
									
										6
									
								
								mypy.ini
								
								
								
								
							|  | @ -158,6 +158,12 @@ disallow_untyped_defs = True | |||
| [mypy-synapse.events.*] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.federation.*] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
| [mypy-synapse.federation.transport.client] | ||||
| disallow_untyped_defs = False | ||||
| 
 | ||||
| [mypy-synapse.handlers.*] | ||||
| disallow_untyped_defs = True | ||||
| 
 | ||||
|  |  | |||
|  | @ -128,7 +128,7 @@ class FederationClient(FederationBase): | |||
|             reset_expiry_on_get=False, | ||||
|         ) | ||||
| 
 | ||||
|     def _clear_tried_cache(self): | ||||
|     def _clear_tried_cache(self) -> None: | ||||
|         """Clear pdu_destination_tried cache""" | ||||
|         now = self._clock.time_msec() | ||||
| 
 | ||||
|  | @ -800,7 +800,7 @@ class FederationClient(FederationBase): | |||
|                 no servers successfully handle the request. | ||||
|         """ | ||||
| 
 | ||||
|         async def send_request(destination) -> SendJoinResult: | ||||
|         async def send_request(destination: str) -> SendJoinResult: | ||||
|             response = await self._do_send_join(room_version, destination, pdu) | ||||
| 
 | ||||
|             # If an event was returned (and expected to be returned): | ||||
|  |  | |||
|  | @ -1,6 +1,6 @@ | |||
| # Copyright 2015, 2016 OpenMarket Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # Copyright 2019 Matrix.org Federation C.I.C | ||||
| # Copyright 2019-2021 Matrix.org Federation C.I.C | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -450,7 +450,7 @@ class FederationServer(FederationBase): | |||
|         # require callouts to other servers to fetch missing events), but | ||||
|         # impose a limit to avoid going too crazy with ram/cpu. | ||||
| 
 | ||||
|         async def process_pdus_for_room(room_id: str): | ||||
|         async def process_pdus_for_room(room_id: str) -> None: | ||||
|             with nested_logging_context(room_id): | ||||
|                 logger.debug("Processing PDUs for %s", room_id) | ||||
| 
 | ||||
|  | @ -547,7 +547,7 @@ class FederationServer(FederationBase): | |||
| 
 | ||||
|     async def on_state_ids_request( | ||||
|         self, origin: str, room_id: str, event_id: str | ||||
|     ) -> Tuple[int, Dict[str, Any]]: | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         if not event_id: | ||||
|             raise NotImplementedError("Specify an event") | ||||
| 
 | ||||
|  | @ -567,7 +567,9 @@ class FederationServer(FederationBase): | |||
| 
 | ||||
|         return 200, resp | ||||
| 
 | ||||
|     async def _on_state_ids_request_compute(self, room_id, event_id): | ||||
|     async def _on_state_ids_request_compute( | ||||
|         self, room_id: str, event_id: str | ||||
|     ) -> JsonDict: | ||||
|         state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) | ||||
|         auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) | ||||
|         return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} | ||||
|  |  | |||
|  | @ -1,4 +1,5 @@ | |||
| # Copyright 2014-2016 OpenMarket Ltd | ||||
| # Copyright 2021 The Matrix.org Foundation C.I.C. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -23,6 +24,7 @@ from typing import Optional, Tuple | |||
| 
 | ||||
| from synapse.federation.units import Transaction | ||||
| from synapse.logging.utils import log_function | ||||
| from synapse.storage.databases.main import DataStore | ||||
| from synapse.types import JsonDict | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
|  | @ -31,7 +33,7 @@ logger = logging.getLogger(__name__) | |||
| class TransactionActions: | ||||
|     """Defines persistence actions that relate to handling Transactions.""" | ||||
| 
 | ||||
|     def __init__(self, datastore): | ||||
|     def __init__(self, datastore: DataStore): | ||||
|         self.store = datastore | ||||
| 
 | ||||
|     @log_function | ||||
|  |  | |||
|  | @ -1,4 +1,5 @@ | |||
| # Copyright 2014-2016 OpenMarket Ltd | ||||
| # Copyright 2021 The Matrix.org Foundation C.I.C. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -350,7 +351,7 @@ class BaseFederationRow: | |||
|     TypeId = ""  # Unique string that ids the type. Must be overridden in sub classes. | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def from_data(data): | ||||
|     def from_data(data: JsonDict) -> "BaseFederationRow": | ||||
|         """Parse the data from the federation stream into a row. | ||||
| 
 | ||||
|         Args: | ||||
|  | @ -359,7 +360,7 @@ class BaseFederationRow: | |||
|         """ | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
|     def to_data(self): | ||||
|     def to_data(self) -> JsonDict: | ||||
|         """Serialize this row to be sent over the federation stream. | ||||
| 
 | ||||
|         Returns: | ||||
|  | @ -368,7 +369,7 @@ class BaseFederationRow: | |||
|         """ | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
|     def add_to_buffer(self, buff): | ||||
|     def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: | ||||
|         """Add this row to the appropriate field in the buffer ready for this | ||||
|         to be sent over federation. | ||||
| 
 | ||||
|  | @ -391,15 +392,15 @@ class PresenceDestinationsRow( | |||
|     TypeId = "pd" | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def from_data(data): | ||||
|     def from_data(data: JsonDict) -> "PresenceDestinationsRow": | ||||
|         return PresenceDestinationsRow( | ||||
|             state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"] | ||||
|         ) | ||||
| 
 | ||||
|     def to_data(self): | ||||
|     def to_data(self) -> JsonDict: | ||||
|         return {"state": self.state.as_dict(), "dests": self.destinations} | ||||
| 
 | ||||
|     def add_to_buffer(self, buff): | ||||
|     def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: | ||||
|         buff.presence_destinations.append((self.state, self.destinations)) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -417,13 +418,13 @@ class KeyedEduRow( | |||
|     TypeId = "k" | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def from_data(data): | ||||
|     def from_data(data: JsonDict) -> "KeyedEduRow": | ||||
|         return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"])) | ||||
| 
 | ||||
|     def to_data(self): | ||||
|     def to_data(self) -> JsonDict: | ||||
|         return {"key": self.key, "edu": self.edu.get_internal_dict()} | ||||
| 
 | ||||
|     def add_to_buffer(self, buff): | ||||
|     def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: | ||||
|         buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu | ||||
| 
 | ||||
| 
 | ||||
|  | @ -433,13 +434,13 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))):  # Edu | |||
|     TypeId = "e" | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def from_data(data): | ||||
|     def from_data(data: JsonDict) -> "EduRow": | ||||
|         return EduRow(Edu(**data)) | ||||
| 
 | ||||
|     def to_data(self): | ||||
|     def to_data(self) -> JsonDict: | ||||
|         return self.edu.get_internal_dict() | ||||
| 
 | ||||
|     def add_to_buffer(self, buff): | ||||
|     def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: | ||||
|         buff.edus.setdefault(self.edu.destination, []).append(self.edu) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # Copyright 2014-2016 OpenMarket Ltd | ||||
| # Copyright 2019 New Vector Ltd | ||||
| # Copyright 2021 The Matrix.org Foundation C.I.C. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -14,7 +15,8 @@ | |||
| # limitations under the License. | ||||
| import datetime | ||||
| import logging | ||||
| from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple | ||||
| from types import TracebackType | ||||
| from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type | ||||
| 
 | ||||
| import attr | ||||
| from prometheus_client import Counter | ||||
|  | @ -213,7 +215,7 @@ class PerDestinationQueue: | |||
|         self._pending_edus_keyed[(edu.edu_type, key)] = edu | ||||
|         self.attempt_new_transaction() | ||||
| 
 | ||||
|     def send_edu(self, edu) -> None: | ||||
|     def send_edu(self, edu: Edu) -> None: | ||||
|         self._pending_edus.append(edu) | ||||
|         self.attempt_new_transaction() | ||||
| 
 | ||||
|  | @ -701,7 +703,12 @@ class _TransactionQueueManager: | |||
| 
 | ||||
|         return self._pdus, pending_edus | ||||
| 
 | ||||
|     async def __aexit__(self, exc_type, exc, tb): | ||||
|     async def __aexit__( | ||||
|         self, | ||||
|         exc_type: Optional[Type[BaseException]], | ||||
|         exc: Optional[BaseException], | ||||
|         tb: Optional[TracebackType], | ||||
|     ) -> None: | ||||
|         if exc_type is not None: | ||||
|             # Failed to send transaction, so we bail out. | ||||
|             return | ||||
|  |  | |||
|  | @ -21,6 +21,7 @@ from typing import ( | |||
|     Callable, | ||||
|     Collection, | ||||
|     Dict, | ||||
|     Generator, | ||||
|     Iterable, | ||||
|     List, | ||||
|     Mapping, | ||||
|  | @ -235,11 +236,16 @@ class TransportLayerClient: | |||
| 
 | ||||
|     @log_function | ||||
|     async def make_query( | ||||
|         self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False | ||||
|     ): | ||||
|         self, | ||||
|         destination: str, | ||||
|         query_type: str, | ||||
|         args: dict, | ||||
|         retry_on_dns_fail: bool, | ||||
|         ignore_backoff: bool = False, | ||||
|     ) -> JsonDict: | ||||
|         path = _create_v1_path("/query/%s", query_type) | ||||
| 
 | ||||
|         content = await self.client.get_json( | ||||
|         return await self.client.get_json( | ||||
|             destination=destination, | ||||
|             path=path, | ||||
|             args=args, | ||||
|  | @ -248,8 +254,6 @@ class TransportLayerClient: | |||
|             ignore_backoff=ignore_backoff, | ||||
|         ) | ||||
| 
 | ||||
|         return content | ||||
| 
 | ||||
|     @log_function | ||||
|     async def make_membership_event( | ||||
|         self, | ||||
|  | @ -1317,7 +1321,7 @@ class SendJoinResponse: | |||
| 
 | ||||
| 
 | ||||
| @ijson.coroutine | ||||
| def _event_parser(event_dict: JsonDict): | ||||
| def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]: | ||||
|     """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs | ||||
|     to add them to a given dictionary. | ||||
|     """ | ||||
|  | @ -1328,7 +1332,9 @@ def _event_parser(event_dict: JsonDict): | |||
| 
 | ||||
| 
 | ||||
| @ijson.coroutine | ||||
| def _event_list_parser(room_version: RoomVersion, events: List[EventBase]): | ||||
| def _event_list_parser( | ||||
|     room_version: RoomVersion, events: List[EventBase] | ||||
| ) -> Generator[None, JsonDict, None]: | ||||
|     """Helper function for use with `ijson.items_coro` to parse an array of | ||||
|     events and add them to the given list. | ||||
|     """ | ||||
|  |  | |||
|  | @ -302,7 +302,7 @@ def register_servlets( | |||
|     authenticator: Authenticator, | ||||
|     ratelimiter: FederationRateLimiter, | ||||
|     servlet_groups: Optional[Iterable[str]] = None, | ||||
| ): | ||||
| ) -> None: | ||||
|     """Initialize and register servlet classes. | ||||
| 
 | ||||
|     Will by default register all servlets. For custom behaviour, pass in | ||||
|  |  | |||
|  | @ -15,10 +15,13 @@ | |||
| import functools | ||||
| import logging | ||||
| import re | ||||
| from typing import Any, Awaitable, Callable, Optional, Tuple, cast | ||||
| 
 | ||||
| from synapse.api.errors import Codes, FederationDeniedError, SynapseError | ||||
| from synapse.api.urls import FEDERATION_V1_PREFIX | ||||
| from synapse.http.server import HttpServer, ServletCallback | ||||
| from synapse.http.servlet import parse_json_object_from_request | ||||
| from synapse.http.site import SynapseRequest | ||||
| from synapse.logging import opentracing | ||||
| from synapse.logging.context import run_in_background | ||||
| from synapse.logging.opentracing import ( | ||||
|  | @ -29,6 +32,7 @@ from synapse.logging.opentracing import ( | |||
|     whitelisted_homeserver, | ||||
| ) | ||||
| from synapse.server import HomeServer | ||||
| from synapse.types import JsonDict | ||||
| from synapse.util.ratelimitutils import FederationRateLimiter | ||||
| from synapse.util.stringutils import parse_and_validate_server_name | ||||
| 
 | ||||
|  | @ -59,9 +63,11 @@ class Authenticator: | |||
|             self.replication_client = hs.get_tcp_replication() | ||||
| 
 | ||||
|     # A method just so we can pass 'self' as the authenticator to the Servlets | ||||
|     async def authenticate_request(self, request, content): | ||||
|     async def authenticate_request( | ||||
|         self, request: SynapseRequest, content: Optional[JsonDict] | ||||
|     ) -> str: | ||||
|         now = self._clock.time_msec() | ||||
|         json_request = { | ||||
|         json_request: JsonDict = { | ||||
|             "method": request.method.decode("ascii"), | ||||
|             "uri": request.uri.decode("ascii"), | ||||
|             "destination": self.server_name, | ||||
|  | @ -114,7 +120,7 @@ class Authenticator: | |||
| 
 | ||||
|         return origin | ||||
| 
 | ||||
|     async def _reset_retry_timings(self, origin): | ||||
|     async def _reset_retry_timings(self, origin: str) -> None: | ||||
|         try: | ||||
|             logger.info("Marking origin %r as up", origin) | ||||
|             await self.store.set_destination_retry_timings(origin, None, 0, 0) | ||||
|  | @ -133,14 +139,14 @@ class Authenticator: | |||
|             logger.exception("Error resetting retry timings on %s", origin) | ||||
| 
 | ||||
| 
 | ||||
| def _parse_auth_header(header_bytes): | ||||
| def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]: | ||||
|     """Parse an X-Matrix auth header | ||||
| 
 | ||||
|     Args: | ||||
|         header_bytes (bytes): header value | ||||
|         header_bytes: header value | ||||
| 
 | ||||
|     Returns: | ||||
|         Tuple[str, str, str]: origin, key id, signature. | ||||
|         origin, key id, signature. | ||||
| 
 | ||||
|     Raises: | ||||
|         AuthenticationError if the header could not be parsed | ||||
|  | @ -148,9 +154,9 @@ def _parse_auth_header(header_bytes): | |||
|     try: | ||||
|         header_str = header_bytes.decode("utf-8") | ||||
|         params = header_str.split(" ")[1].split(",") | ||||
|         param_dict = dict(kv.split("=") for kv in params) | ||||
|         param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)} | ||||
| 
 | ||||
|         def strip_quotes(value): | ||||
|         def strip_quotes(value: str) -> str: | ||||
|             if value.startswith('"'): | ||||
|                 return value[1:-1] | ||||
|             else: | ||||
|  | @ -233,23 +239,25 @@ class BaseFederationServlet: | |||
|         self.ratelimiter = ratelimiter | ||||
|         self.server_name = server_name | ||||
| 
 | ||||
|     def _wrap(self, func): | ||||
|     def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback: | ||||
|         authenticator = self.authenticator | ||||
|         ratelimiter = self.ratelimiter | ||||
| 
 | ||||
|         @functools.wraps(func) | ||||
|         async def new_func(request, *args, **kwargs): | ||||
|         async def new_func( | ||||
|             request: SynapseRequest, *args: Any, **kwargs: str | ||||
|         ) -> Optional[Tuple[int, Any]]: | ||||
|             """A callback which can be passed to HttpServer.RegisterPaths | ||||
| 
 | ||||
|             Args: | ||||
|                 request (twisted.web.http.Request): | ||||
|                 request: | ||||
|                 *args: unused? | ||||
|                 **kwargs (dict[unicode, unicode]): the dict mapping keys to path | ||||
|                     components as specified in the path match regexp. | ||||
|                 **kwargs: the dict mapping keys to path components as specified | ||||
|                     in the path match regexp. | ||||
| 
 | ||||
|             Returns: | ||||
|                 Tuple[int, object]|None: (response code, response object) as returned by | ||||
|                     the callback method. None if the request has already been handled. | ||||
|                 (response code, response object) as returned by the callback method. | ||||
|                 None if the request has already been handled. | ||||
|             """ | ||||
|             content = None | ||||
|             if request.method in [b"PUT", b"POST"]: | ||||
|  | @ -257,7 +265,9 @@ class BaseFederationServlet: | |||
|                 content = parse_json_object_from_request(request) | ||||
| 
 | ||||
|             try: | ||||
|                 origin = await authenticator.authenticate_request(request, content) | ||||
|                 origin: Optional[str] = await authenticator.authenticate_request( | ||||
|                     request, content | ||||
|                 ) | ||||
|             except NoAuthenticationError: | ||||
|                 origin = None | ||||
|                 if self.REQUIRE_AUTH: | ||||
|  | @ -301,7 +311,7 @@ class BaseFederationServlet: | |||
|                                 "client disconnected before we started processing " | ||||
|                                 "request" | ||||
|                             ) | ||||
|                             return -1, None | ||||
|                             return None | ||||
|                         response = await func( | ||||
|                             origin, content, request.args, *args, **kwargs | ||||
|                         ) | ||||
|  | @ -312,9 +322,9 @@ class BaseFederationServlet: | |||
| 
 | ||||
|             return response | ||||
| 
 | ||||
|         return new_func | ||||
|         return cast(ServletCallback, new_func) | ||||
| 
 | ||||
|     def register(self, server): | ||||
|     def register(self, server: HttpServer) -> None: | ||||
|         pattern = re.compile("^" + self.PREFIX + self.PATH + "$") | ||||
| 
 | ||||
|         for method in ("GET", "PUT", "POST"): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Patrick Cloke
						Patrick Cloke