Factor out common code for persisting fetched auth events (#10896)
* Factor more stuff out of `_get_events_and_persist` It turns out that the event-sorting algorithm in `_get_events_and_persist` is also useful in other circumstances. Here we move the current `_auth_and_persist_fetched_events` to `_auth_and_persist_fetched_events_inner`, and then factor the sorting part out to `_auth_and_persist_fetched_events`. * `_get_remote_auth_chain_for_event`: remove redundant `outlier` assignment `get_event_auth` returns events with the outlier flag already set, so this is redundant (though we need to update a test where `get_event_auth` is mocked). * `_get_remote_auth_chain_for_event`: move existing-event tests earlier Move a couple of tests outside the loop. This is a bit inefficient for now, but a future commit will make it better. It should be functionally identical. * `_get_remote_auth_chain_for_event`: use `_auth_and_persist_fetched_events` We can use the same codepath for persisting the events fetched as part of an auth chain as for those fetched individually by `_get_events_and_persist` for building the state at a backwards extremity. * `_get_remote_auth_chain_for_event`: use a dict for efficiency `_auth_and_persist_fetched_events` sorts the events itself, so we no longer need to care about maintaining the ordering from `get_event_auth` (and no longer need to sort by depth in `get_event_auth`). That means that we can use a map, making it easier to filter out events we already have, etc. * changelog * `_auth_and_persist_fetched_events`: improve docstringpull/10907/head
							parent
							
								
									261c9763c4
								
							
						
					
					
						commit
						85551b7a85
					
				|  | @ -0,0 +1 @@ | |||
|  Clean up some of the federation event authentication code for clarity. | ||||
|  | @ -501,8 +501,6 @@ class FederationClient(FederationBase): | |||
|             destination, auth_chain, outlier=True, room_version=room_version | ||||
|         ) | ||||
| 
 | ||||
|         signed_auth.sort(key=lambda e: e.depth) | ||||
| 
 | ||||
|         return signed_auth | ||||
| 
 | ||||
|     def _is_unknown_endpoint( | ||||
|  |  | |||
|  | @ -1080,7 +1080,7 @@ class FederationEventHandler: | |||
| 
 | ||||
|         room_version = await self._store.get_room_version(room_id) | ||||
| 
 | ||||
|         event_map: Dict[str, EventBase] = {} | ||||
|         events: List[EventBase] = [] | ||||
| 
 | ||||
|         async def get_event(event_id: str) -> None: | ||||
|             with nested_logging_context(event_id): | ||||
|  | @ -1098,8 +1098,7 @@ class FederationEventHandler: | |||
|                             event_id, | ||||
|                         ) | ||||
|                         return | ||||
| 
 | ||||
|                     event_map[event.event_id] = event | ||||
|                     events.append(event) | ||||
| 
 | ||||
|                 except Exception as e: | ||||
|                     logger.warning( | ||||
|  | @ -1110,11 +1109,29 @@ class FederationEventHandler: | |||
|                     ) | ||||
| 
 | ||||
|         await concurrently_execute(get_event, event_ids, 5) | ||||
|         logger.info("Fetched %i events of %i requested", len(event_map), len(event_ids)) | ||||
|         logger.info("Fetched %i events of %i requested", len(events), len(event_ids)) | ||||
|         await self._auth_and_persist_fetched_events(destination, room_id, events) | ||||
| 
 | ||||
|     async def _auth_and_persist_fetched_events( | ||||
|         self, origin: str, room_id: str, events: Iterable[EventBase] | ||||
|     ) -> None: | ||||
|         """Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event | ||||
| 
 | ||||
|         The events to be persisted must be outliers. | ||||
| 
 | ||||
|         We first sort the events to make sure that we process each event's auth_events | ||||
|         before the event itself, and then auth and persist them. | ||||
| 
 | ||||
|         Notifies about the events where appropriate. | ||||
| 
 | ||||
|         Params: | ||||
|             origin: where the events came from | ||||
|             room_id: the room that the events are meant to be in (though this has | ||||
|                not yet been checked) | ||||
|             events: the events that have been fetched | ||||
|         """ | ||||
|         event_map = {event.event_id: event for event in events} | ||||
| 
 | ||||
|         # we now need to auth the events in an order which ensures that each event's | ||||
|         # auth_events are authed before the event itself. | ||||
|         # | ||||
|         # XXX: it might be possible to kick this process off in parallel with fetching | ||||
|         # the events. | ||||
|         while event_map: | ||||
|  | @ -1141,22 +1158,18 @@ class FederationEventHandler: | |||
|                 "Persisting %i of %i remaining events", len(roots), len(event_map) | ||||
|             ) | ||||
| 
 | ||||
|             await self._auth_and_persist_fetched_events(destination, room_id, roots) | ||||
|             await self._auth_and_persist_fetched_events_inner(origin, room_id, roots) | ||||
| 
 | ||||
|             for ev in roots: | ||||
|                 del event_map[ev.event_id] | ||||
| 
 | ||||
|     async def _auth_and_persist_fetched_events( | ||||
|     async def _auth_and_persist_fetched_events_inner( | ||||
|         self, origin: str, room_id: str, fetched_events: Collection[EventBase] | ||||
|     ) -> None: | ||||
|         """Persist the events fetched by _get_events_and_persist. | ||||
|         """Helper for _auth_and_persist_fetched_events | ||||
| 
 | ||||
|         The events should not depend on one another, e.g. this should be used to persist | ||||
|         a bunch of outliers, but not a chunk of individual events that depend | ||||
|         on each other for state calculations. | ||||
| 
 | ||||
|         We also assume that all of the auth events for all of the events have already | ||||
|         been persisted. | ||||
|         Persists a batch of events where we have (theoretically) already persisted all | ||||
|         of their auth events. | ||||
| 
 | ||||
|         Notifies about the events where appropriate. | ||||
| 
 | ||||
|  | @ -1164,7 +1177,7 @@ class FederationEventHandler: | |||
|             origin: where the events came from | ||||
|             room_id: the room that the events are meant to be in (though this has | ||||
|                not yet been checked) | ||||
|             event_id: map from event_id -> event for the fetched events | ||||
|             fetched_events: the events to persist | ||||
|         """ | ||||
|         # get all the auth events for all the events in this batch. By now, they should | ||||
|         # have been persisted. | ||||
|  | @ -1558,53 +1571,33 @@ class FederationEventHandler: | |||
|             event_id: the event for which we are lacking auth events | ||||
|         """ | ||||
|         try: | ||||
|             remote_auth_chain = await self._federation_client.get_event_auth( | ||||
|                 destination, room_id, event_id | ||||
|             ) | ||||
|             remote_event_map = { | ||||
|                 e.event_id: e | ||||
|                 for e in await self._federation_client.get_event_auth( | ||||
|                     destination, room_id, event_id | ||||
|                 ) | ||||
|             } | ||||
|         except RequestSendFailed as e1: | ||||
|             # The other side isn't around or doesn't implement the | ||||
|             # endpoint, so lets just bail out. | ||||
|             logger.info("Failed to get event auth from remote: %s", e1) | ||||
|             return | ||||
| 
 | ||||
|         logger.info("/event_auth returned %i events", len(remote_event_map)) | ||||
| 
 | ||||
|         # `event` may be returned, but we should not yet process it. | ||||
|         remote_event_map.pop(event_id, None) | ||||
| 
 | ||||
|         # nor should we reprocess any events we have already seen. | ||||
|         seen_remotes = await self._store.have_seen_events( | ||||
|             room_id, [e.event_id for e in remote_auth_chain] | ||||
|             room_id, remote_event_map.keys() | ||||
|         ) | ||||
|         for s in seen_remotes: | ||||
|             remote_event_map.pop(s, None) | ||||
| 
 | ||||
|         for auth_event in remote_auth_chain: | ||||
|             if auth_event.event_id in seen_remotes: | ||||
|                 continue | ||||
| 
 | ||||
|             if auth_event.event_id == event_id: | ||||
|                 continue | ||||
| 
 | ||||
|             try: | ||||
|                 auth_ids = auth_event.auth_event_ids() | ||||
|                 auth = { | ||||
|                     (e.type, e.state_key): e | ||||
|                     for e in remote_auth_chain | ||||
|                     if e.event_id in auth_ids or e.type == EventTypes.Create | ||||
|                 } | ||||
|                 auth_event.internal_metadata.outlier = True | ||||
| 
 | ||||
|                 logger.debug( | ||||
|                     "_check_event_auth %s missing_auth: %s", | ||||
|                     event_id, | ||||
|                     auth_event.event_id, | ||||
|                 ) | ||||
|                 missing_auth_event_context = EventContext.for_outlier() | ||||
|                 missing_auth_event_context = await self._check_event_auth( | ||||
|                     destination, | ||||
|                     auth_event, | ||||
|                     missing_auth_event_context, | ||||
|                     claimed_auth_event_map=auth, | ||||
|                 ) | ||||
|                 await self.persist_events_and_notify( | ||||
|                     room_id, | ||||
|                     [(auth_event, missing_auth_event_context)], | ||||
|                 ) | ||||
|             except AuthError: | ||||
|                 pass | ||||
|         await self._auth_and_persist_fetched_events( | ||||
|             destination, room_id, remote_event_map.values() | ||||
|         ) | ||||
| 
 | ||||
|     async def _update_context_for_auth_events( | ||||
|         self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] | ||||
|  |  | |||
|  | @ -308,7 +308,12 @@ class FederationTestCase(unittest.HomeserverTestCase): | |||
|         async def get_event_auth( | ||||
|             destination: str, room_id: str, event_id: str | ||||
|         ) -> List[EventBase]: | ||||
|             return auth_events | ||||
|             return [ | ||||
|                 event_from_pdu_json( | ||||
|                     ae.get_pdu_json(), room_version=room_version, outlier=True | ||||
|                 ) | ||||
|                 for ae in auth_events | ||||
|             ] | ||||
| 
 | ||||
|         self.handler.federation_client.get_event_auth = get_event_auth | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Richard van der Hoff
						Richard van der Hoff