diff --git a/changelog.d/10640.misc b/changelog.d/10640.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10640.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index afd8f8580a..e1b58d40c5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1005,9 +1005,7 @@ class FederationServer(FederationBase): async with lock: logger.info("handling received PDU: %s", event) try: - await self.handler.on_receive_pdu( - origin, event, sent_to_us_directly=True - ) + await self.handler.on_receive_pdu(origin, event) except FederationError as e: # XXX: Ideally we'd inform the remote we failed to process # the event, but we can't return an error in the transaction diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index de86918b7b..246df43501 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -203,18 +203,13 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - async def on_receive_pdu( - self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False - ) -> None: - """Process a PDU received via a federation /send/ transaction, or - via backfill of missing prev_events + async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: + """Process a PDU received via a federation /send/ transaction Args: origin: server which initiated the /send/ transaction. Will be used to fetch missing events or state. pdu: received PDU - sent_to_us_directly: True if this event was pushed to us; False if - we pulled it as the result of a missing prev_event. """ room_id = pdu.room_id @@ -276,8 +271,6 @@ class FederationHandler(BaseHandler): ) return None - state = None - # Check that the event passes auth based on the state at the event. This is # done for events that are to be added to the timeline (non-outliers). # @@ -285,83 +278,72 @@ class FederationHandler(BaseHandler): # - Fetching any missing prev events to fill in gaps in the graph # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): - if sent_to_us_directly: - prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen + prevs = set(pdu.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if missing_prevs: + # We only backfill backwards to the min depth. + min_depth = await self.get_min_depth_for_context(pdu.room_id) + logger.debug("min_depth: %d", min_depth) + + if min_depth is not None and pdu.depth > min_depth: + # If we're missing stuff, ensure we only fetch stuff one + # at a time. + logger.info( + "Acquiring room lock to fetch %d missing prev_events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + with (await self._room_pdu_linearizer.queue(pdu.room_id)): + logger.info( + "Acquired room lock to fetch %d missing prev_events", + len(missing_prevs), + ) + + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) from e + + # Update the set of things we've seen after trying to + # fetch the missing stuff + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + logger.info("Found all missing prev_events") if missing_prevs: - # We only backfill backwards to the min depth. - min_depth = await self.get_min_depth_for_context(pdu.room_id) - logger.debug("min_depth: %d", min_depth) + # since this event was pushed to us, it is possible for it to + # become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very + # bad. Further, if the event was pushed to us, there is no excuse + # for us not to have all the prev_events. (XXX: apart from + # min_depth?) + # + # We therefore reject any such events. + logger.warning( + "Rejecting: failed to fetch %d prev events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, + ) - if min_depth is not None and pdu.depth > min_depth: - # If we're missing stuff, ensure we only fetch stuff one - # at a time. - logger.info( - "Acquiring room lock to fetch %d missing prev_events: %s", - len(missing_prevs), - shortstr(missing_prevs), - ) - with (await self._room_pdu_linearizer.queue(pdu.room_id)): - logger.info( - "Acquired room lock to fetch %d missing prev_events", - len(missing_prevs), - ) - - try: - await self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth - ) - except Exception as e: - raise Exception( - "Error fetching missing prev_events for %s: %s" - % (event_id, e) - ) from e - - # Update the set of things we've seen after trying to - # fetch the missing stuff - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if not missing_prevs: - logger.info("Found all missing prev_events") - - if missing_prevs: - # since this event was pushed to us, it is possible for it to - # become the only forward-extremity in the room, and we would then - # trust its state to be the state for the whole room. This is very - # bad. Further, if the event was pushed to us, there is no excuse - # for us not to have all the prev_events. (XXX: apart from - # min_depth?) - # - # We therefore reject any such events. - logger.warning( - "Rejecting: failed to fetch %d prev events: %s", - len(missing_prevs), - shortstr(missing_prevs), - ) - raise FederationError( - "ERROR", - 403, - ( - "Your server isn't divulging details about prev_events " - "referenced in this event." - ), - affected=pdu.event_id, - ) - - else: - state = await self._resolve_state_at_missing_prevs(origin, pdu) - - # A second round of checks for all events. Check that the event passes auth - # based on `auth_events`, this allows us to assert that the event would - # have been allowed at some point. If an event passes this check its OK - # for it to be used as part of a returned `/state` request, as either - # a) we received the event as part of the original join and so trust it, or - # b) we'll do a state resolution with existing state before it becomes - # part of the "current state", which adds more protection. - await self._process_received_pdu(origin, pdu, state=state) + await self._process_received_pdu(origin, pdu, state=None) async def _get_missing_events_for_pdu( self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int @@ -461,24 +443,7 @@ class FederationHandler(BaseHandler): return logger.info("Got %d prev_events", len(missing_events)) - - # We want to sort these by depth so we process them and - # tell clients about them in order. - missing_events.sort(key=lambda x: x.depth) - - for ev in missing_events: - logger.info("Handling received prev_event %s", ev) - with nested_logging_context(ev.event_id): - try: - await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) - except FederationError as e: - if e.code == 403: - logger.warning( - "Received prev_event %s failed history check.", - ev.event_id, - ) - else: - raise + await self._process_pulled_events(origin, missing_events) async def _get_state_for_room( self, @@ -1395,6 +1360,81 @@ class FederationHandler(BaseHandler): event_infos, ) + async def _process_pulled_events( + self, origin: str, events: Iterable[EventBase] + ) -> None: + """Process a batch of events we have pulled from a remote server + + Pulls in any events required to auth the events, persists the received events, + and notifies clients, if appropriate. + + Assumes the events have already had their signatures and hashes checked. + + Params: + origin: The server we received these events from + events: The received events. + """ + + # We want to sort these by depth so we process them and + # tell clients about them in order. + sorted_events = sorted(events, key=lambda x: x.depth) + + for ev in sorted_events: + with nested_logging_context(ev.event_id): + await self._process_pulled_event(origin, ev) + + async def _process_pulled_event(self, origin: str, event: EventBase) -> None: + """Process a single event that we have pulled from a remote server + + Pulls in any events required to auth the event, persists the received event, + and notifies clients, if appropriate. + + Assumes the event has already had its signatures and hashes checked. + + This is somewhat equivalent to on_receive_pdu, but applies somewhat different + logic in the case that we are missing prev_events (in particular, it just + requests the state at that point, rather than triggering a get_missing_events) - + so is appropriate when we have pulled the event from a remote server, rather + than having it pushed to us. + + Params: + origin: The server we received this event from + events: The received event + """ + logger.info("Processing pulled event %s", event) + + # these should not be outliers. + assert not event.internal_metadata.is_outlier() + + event_id = event.event_id + + existing = await self.store.get_event( + event_id, allow_none=True, allow_rejected=True + ) + if existing: + if not existing.internal_metadata.is_outlier(): + logger.info( + "Ignoring received event %s which we have already seen", + event_id, + ) + return + logger.info("De-outliering event %s", event_id) + + try: + self._sanity_check_event(event) + except SynapseError as err: + logger.warning("Event %s failed sanity check: %s", event_id, err) + return + + try: + state = await self._resolve_state_at_missing_prevs(origin, event) + await self._process_received_pdu(origin, event, state=state) + except FederationError as e: + if e.code == 403: + logger.warning("Pulled event %s failed history check.", event_id) + else: + raise + async def _resolve_state_at_missing_prevs( self, dest: str, event: EventBase ) -> Optional[Iterable[EventBase]]: @@ -1780,7 +1820,7 @@ class FederationHandler(BaseHandler): p, ) with nested_logging_context(p.event_id): - await self.on_receive_pdu(origin, p, sent_to_us_directly=True) + await self.on_receive_pdu(origin, p) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e diff --git a/tests/test_federation.py b/tests/test_federation.py index 3785799f46..348fcb72a7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -85,11 +85,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Send the join, it should return None (which is not an error) self.assertEqual( - self.get_success( - self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True - ) - ), + self.get_success(self.handler.on_receive_pdu("test.serv", join_event)), None, ) @@ -135,9 +131,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): with LoggingContext("test-context"): failure = self.get_failure( - self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True - ), + self.handler.on_receive_pdu("test.serv", lying_event), FederationError, )