Add type hints to the federation handler and server. (#9743)

pull/9765/head
Patrick Cloke 2021-04-06 07:21:57 -04:00 committed by GitHub
parent e7b769aea1
commit d959d28730
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 95 deletions

1
changelog.d/9743.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to federation handler and server.

View File

@ -739,22 +739,20 @@ class FederationServer(FederationBase):
await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
def __str__(self): def __str__(self) -> str:
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name
async def exchange_third_party_invite( async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
): ) -> None:
ret = await self.handler.exchange_third_party_invite( await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed sender_user_id, target_user_id, room_id, signed
) )
return ret
async def on_exchange_third_party_invite_request(self, event_dict: Dict): async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
ret = await self.handler.on_exchange_third_party_invite_request(event_dict) await self.handler.on_exchange_third_party_invite_request(event_dict)
return ret
async def check_server_matches_acl(self, server_name: str, room_id: str): async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
"""Check if the given server is allowed by the server ACLs in the room """Check if the given server is allowed by the server ACLs in the room
Args: Args:
@ -878,7 +876,7 @@ class FederationHandlerRegistry:
def register_edu_handler( def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
): ) -> None:
"""Sets the handler callable that will be used to handle an incoming """Sets the handler callable that will be used to handle an incoming
federation EDU of the given type. federation EDU of the given type.
@ -897,7 +895,7 @@ class FederationHandlerRegistry:
def register_query_handler( def register_query_handler(
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
): ) -> None:
"""Sets the handler callable that will be used to handle an incoming """Sets the handler callable that will be used to handle an incoming
federation query of the given type. federation query of the given type.
@ -915,15 +913,17 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
def register_instance_for_edu(self, edu_type: str, instance_name: str): def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
"""Register that the EDU handler is on a different instance than master.""" """Register that the EDU handler is on a different instance than master."""
self._edu_type_to_instance[edu_type] = [instance_name] self._edu_type_to_instance[edu_type] = [instance_name]
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]): def register_instances_for_edu(
self, edu_type: str, instance_names: List[str]
) -> None:
"""Register that the EDU handler is on multiple instances.""" """Register that the EDU handler is on multiple instances."""
self._edu_type_to_instance[edu_type] = instance_names self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict): async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
if not self.config.use_presence and edu_type == EduTypes.Presence: if not self.config.use_presence and edu_type == EduTypes.Presence:
return return

View File

@ -620,8 +620,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id): async def on_PUT(self, origin, content, query, room_id):
content = await self.handler.on_exchange_third_party_invite_request(content) await self.handler.on_exchange_third_party_invite_request(content)
return 200, content return 200, {}
class FederationClientKeysQueryServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet):

View File

@ -21,7 +21,17 @@ import itertools
import logging import logging
from collections.abc import Container from collections.abc import Container
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import attr import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: async def on_receive_pdu(
self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
) -> None:
"""Process a PDU received via a federation /send/ transaction, or """Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events via backfill of missing prev_events
Args: Args:
origin (str): server which initiated the /send/ transaction. Will origin: server which initiated the /send/ transaction. Will
be used to fetch missing events or state. be used to fetch missing events or state.
pdu (FrozenEvent): received PDU pdu: received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if sent_to_us_directly: True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event. we pulled it as the result of a missing prev_event.
""" """
@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
await self._process_received_pdu(origin, pdu, state=state) await self._process_received_pdu(origin, pdu, state=state)
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): async def _get_missing_events_for_pdu(
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
) -> None:
""" """
Args: Args:
origin (str): Origin of the pdu. Will be called to get the missing events origin: Origin of the pdu. Will be called to get the missing events
pdu: received pdu pdu: received pdu
prevs (set(str)): List of event ids which we are missing prevs: List of event ids which we are missing
min_depth (int): Minimum depth of events to return. min_depth: Minimum depth of events to return.
""" """
room_id = pdu.room_id room_id = pdu.room_id
@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
origin: str, origin: str,
event: EventBase, event: EventBase,
state: Optional[Iterable[EventBase]], state: Optional[Iterable[EventBase]],
): ) -> None:
"""Called when we have a new pdu. We need to do auth checks and put it """Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler. through the StateHandler.
@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
logger.exception("Failed to resync device for %s", sender) logger.exception("Failed to resync device for %s", sender)
@log_function @log_function
async def backfill(self, dest, room_id, limit, extremities): async def backfill(
self, dest: str, room_id: str, limit: int, extremities: List[str]
) -> List[EventBase]:
"""Trigger a backfill request to `dest` for the given `room_id` """Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side This will attempt to get more events from the remote. If the other side
@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
curr_state = await self.state_handler.get_current_state(room_id) curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state): def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
"""Get joined domains from state """Get joined domains from state
Args: Args:
state (dict[tuple, FrozenEvent]): State map from type/state state: State map from type/state key to event.
key to event.
Returns: Returns:
list[tuple[str, int]]: Returns a list of servers with the Returns a list of servers with the lowest depth of their joins.
lowest depth of their joins. Sorted by lowest depth first. Sorted by lowest depth first.
""" """
joined_users = [ joined_users = [
(state_key, int(event.depth)) (state_key, int(event.depth))
@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name domain for domain, depth in curr_domains if domain != self.server_name
] ]
async def try_backfill(domains): async def try_backfill(domains: List[str]) -> bool:
# TODO: Should we try multiple of these at a time? # TODO: Should we try multiple of these at a time?
for dom in domains: for dom in domains:
try: try:
@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
} }
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id]) likely_extremeties_domains = get_domains_from_state(states[e_id])
success = await try_backfill( success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains] [
dom
for dom, _ in likely_extremeties_domains
if dom not in tried_domains
]
) )
if success: if success:
return True return True
tried_domains.update(dom for dom, _ in likely_domains) tried_domains.update(dom for dom, _ in likely_extremeties_domains)
return False return False
async def _get_events_and_persist( async def _get_events_and_persist(
self, destination: str, room_id: str, events: Iterable[str] self, destination: str, room_id: str, events: Iterable[str]
): ) -> None:
"""Fetch the given events from a server, and persist them as outliers. """Fetch the given events from a server, and persist them as outliers.
This function *does not* recursively get missing auth events of the This function *does not* recursively get missing auth events of the
@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
event_infos, event_infos,
) )
def _sanity_check_event(self, ev): def _sanity_check_event(self, ev: EventBase) -> None:
""" """
Do some early sanity checks of a received event Do some early sanity checks of a received event
@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
or cascade of event fetches. or cascade of event fetches.
Args: Args:
ev (synapse.events.EventBase): event to be checked ev: event to be checked
Returns: None
Raises: Raises:
SynapseError if the event does not pass muster SynapseError if the event does not pass muster
@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
) )
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
async def send_invite(self, target_host, event): async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
"""Sends the invite to the remote server for signing. """Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution. Invites must be signed by the invitee's server before distribution.
@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue) run_in_background(self._handle_queued_pdus, room_queue)
async def _handle_queued_pdus(self, room_queue): async def _handle_queued_pdus(
self, room_queue: List[Tuple[EventBase, str]]
) -> None:
"""Process PDUs which got queued up while we were busy send_joining. """Process PDUs which got queued up while we were busy send_joining.
Args: Args:
room_queue (list[FrozenEvent, str]): list of PDUs to be processed room_queue: list of PDUs to be processed and the servers that sent them
and the servers that sent them
""" """
for p, origin in room_queue: for p, origin in room_queue:
try: try:
@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
return event return event
async def on_send_join_request(self, origin, pdu): async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
"""We have received a join event for a room. Fully process it and """We have received a join event for a room. Fully process it and
respond with the current state and auth chains. respond with the current state and auth chains.
""" """
@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
async def on_invite_request( async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion self, origin: str, event: EventBase, room_version: RoomVersion
): ) -> EventBase:
"""We've got an invite event. Process and persist it. Sign it. """We've got an invite event. Process and persist it. Sign it.
Respond with the now signed event. Respond with the now signed event.
@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
return event return event
async def on_send_leave_request(self, origin, pdu): async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
""" We have received a leave event for a room. Fully process it.""" """ We have received a leave event for a room. Fully process it."""
event = pdu event = pdu
@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
else: else:
return None return None
async def get_min_depth_for_context(self, context): async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context) return await self.store.get_min_depth(context)
async def _handle_new_event( async def _handle_new_event(
self, origin, event, state=None, auth_events=None, backfilled=False self,
): origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None,
backfilled: bool = False,
) -> EventContext:
context = await self._prep_event( context = await self._prep_event(
origin, event, state=state, auth_events=auth_events, backfilled=backfilled origin, event, state=state, auth_events=auth_events, backfilled=backfilled
) )
@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
logger.warning("Soft-failing %r because %s", event, e) logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True event.internal_metadata.soft_failed = True
async def on_query_auth(
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
):
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
event = await self.store.get_event(event_id, check_room_id=room_id)
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
for e in remote_auth_chain:
try:
await self._handle_new_event(origin, e)
except AuthError:
pass
# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
room_id, list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
logger.debug("on_query_auth returning: %s", ret)
return ret
async def on_get_missing_events( async def on_get_missing_events(
self, origin, room_id, earliest_events, latest_events, limit self,
): origin: str,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
limit: int,
) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin) in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
assumes that we have already processed all events in remote_auth assumes that we have already processed all events in remote_auth
Params: Params:
local_auth (list) local_auth
remote_auth (list) remote_auth
Returns: Returns:
dict dict
@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
@log_function @log_function
async def exchange_third_party_invite( async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
): ) -> None:
third_party_invite = {"signed": signed} third_party_invite = {"signed": signed}
event_dict = { event_dict = {
@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
await member_handler.send_membership_event(None, event, context) await member_handler.send_membership_event(None, event, context)
async def add_display_name_to_third_party_invite( async def add_display_name_to_third_party_invite(
self, room_version, event_dict, event, context self,
): room_version: str,
event_dict: JsonDict,
event: EventBase,
context: EventContext,
) -> Tuple[EventBase, EventContext]:
key = ( key = (
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"], event.content["third_party_invite"]["signed"]["token"],
@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
EventValidator().validate_new(event, self.config) EventValidator().validate_new(event, self.config)
return (event, context) return (event, context)
async def _check_signature(self, event, context): async def _check_signature(self, event: EventBase, context: EventContext) -> None:
""" """
Checks that the signature in the event is consistent with its invite. Checks that the signature in the event is consistent with its invite.
Args: Args:
event (Event): The m.room.member event to check event: The m.room.member event to check
context (EventContext): context:
Raises: Raises:
AuthError: if signature didn't match any keys, or key has been AuthError: if signature didn't match any keys, or key has been
@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
raise last_exception raise last_exception
async def _check_key_revocation(self, public_key, url): async def _check_key_revocation(self, public_key: str, url: str) -> None:
""" """
Checks whether public_key has been revoked. Checks whether public_key has been revoked.
Args: Args:
public_key (str): base-64 encoded public key. public_key: base-64 encoded public key.
url (str): Key revocation URL. url: Key revocation URL.
Raises: Raises:
AuthError: if they key has been revoked. AuthError: if they key has been revoked.