Persist auth/state events at backwards extremities when we fetch them (#6526)
The main point here is to make sure that the state returned by _get_state_in_room has been authed before we try to use it as state in the room.pull/6527/head
							parent
							
								
									9d173b312c
								
							
						
					
					
						commit
						bc7de87650
					
				|  | @ -0,0 +1 @@ | |||
| Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs. | ||||
|  | @ -65,8 +65,7 @@ from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRes | |||
| from synapse.state import StateResolutionStore, resolve_events_with_store | ||||
| from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour | ||||
| from synapse.types import UserID, get_domain_from_id | ||||
| from synapse.util import batch_iter, unwrapFirstError | ||||
| from synapse.util.async_helpers import Linearizer | ||||
| from synapse.util.async_helpers import Linearizer, concurrently_execute | ||||
| from synapse.util.distributor import user_joined_room | ||||
| from synapse.util.retryutils import NotRetryingDestination | ||||
| from synapse.visibility import filter_events_for_server | ||||
|  | @ -238,7 +237,6 @@ class FederationHandler(BaseHandler): | |||
|             return None | ||||
| 
 | ||||
|         state = None | ||||
|         auth_chain = [] | ||||
| 
 | ||||
|         # Get missing pdus if necessary. | ||||
|         if not pdu.internal_metadata.is_outlier(): | ||||
|  | @ -348,7 +346,6 @@ class FederationHandler(BaseHandler): | |||
| 
 | ||||
|                 # Calculate the state after each of the previous events, and | ||||
|                 # resolve them to find the correct state at the current event. | ||||
|                 auth_chains = set() | ||||
|                 event_map = {event_id: pdu} | ||||
|                 try: | ||||
|                     # Get the state of the events we know about | ||||
|  | @ -369,24 +366,14 @@ class FederationHandler(BaseHandler): | |||
|                             "Requesting state at missing prev_event %s", event_id, | ||||
|                         ) | ||||
| 
 | ||||
|                         room_version = await self.store.get_room_version(room_id) | ||||
| 
 | ||||
|                         with nested_logging_context(p): | ||||
|                             # note that if any of the missing prevs share missing state or | ||||
|                             # auth events, the requests to fetch those events are deduped | ||||
|                             # by the get_pdu_cache in federation_client. | ||||
|                             ( | ||||
|                                 remote_state, | ||||
|                                 got_auth_chain, | ||||
|                             ) = await self._get_state_for_room( | ||||
|                             (remote_state, _,) = await self._get_state_for_room( | ||||
|                                 origin, room_id, p, include_event_in_state=True | ||||
|                             ) | ||||
| 
 | ||||
|                             # XXX hrm I'm not convinced that duplicate events will compare | ||||
|                             # for equality, so I'm not sure this does what the author | ||||
|                             # hoped. | ||||
|                             auth_chains.update(got_auth_chain) | ||||
| 
 | ||||
|                             remote_state_map = { | ||||
|                                 (x.type, x.state_key): x.event_id for x in remote_state | ||||
|                             } | ||||
|  | @ -395,6 +382,7 @@ class FederationHandler(BaseHandler): | |||
|                             for x in remote_state: | ||||
|                                 event_map[x.event_id] = x | ||||
| 
 | ||||
|                     room_version = await self.store.get_room_version(room_id) | ||||
|                     state_map = await resolve_events_with_store( | ||||
|                         room_id, | ||||
|                         room_version, | ||||
|  | @ -416,7 +404,6 @@ class FederationHandler(BaseHandler): | |||
|                     event_map.update(evs) | ||||
| 
 | ||||
|                     state = [event_map[e] for e in six.itervalues(state_map)] | ||||
|                     auth_chain = list(auth_chains) | ||||
|                 except Exception: | ||||
|                     logger.warning( | ||||
|                         "[%s %s] Error attempting to resolve state at missing " | ||||
|  | @ -432,9 +419,7 @@ class FederationHandler(BaseHandler): | |||
|                         affected=event_id, | ||||
|                     ) | ||||
| 
 | ||||
|         await self._process_received_pdu( | ||||
|             origin, pdu, state=state, auth_chain=auth_chain | ||||
|         ) | ||||
|         await self._process_received_pdu(origin, pdu, state=state) | ||||
| 
 | ||||
|     async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): | ||||
|         """ | ||||
|  | @ -633,10 +618,7 @@ class FederationHandler(BaseHandler): | |||
|     ) -> Dict[str, EventBase]: | ||||
|         """Fetch events from a remote destination, checking if we already have them. | ||||
| 
 | ||||
|         Args: | ||||
|             destination | ||||
|             room_id | ||||
|             event_ids | ||||
|         Persists any events we don't already have as outliers. | ||||
| 
 | ||||
|         If we fail to fetch any of the events, a warning will be logged, and the event | ||||
|         will be omitted from the result. Likewise, any events which turn out not to | ||||
|  | @ -656,27 +638,15 @@ class FederationHandler(BaseHandler): | |||
|                 room_id, | ||||
|             ) | ||||
| 
 | ||||
|             room_version = await self.store.get_room_version(room_id) | ||||
|             await self._get_events_and_persist( | ||||
|                 destination=destination, room_id=room_id, events=missing_events | ||||
|             ) | ||||
| 
 | ||||
|             # XXX 20 requests at once? really? | ||||
|             for batch in batch_iter(missing_events, 20): | ||||
|                 deferreds = [ | ||||
|                     run_in_background( | ||||
|                         self.federation_client.get_pdu, | ||||
|                         destinations=[destination], | ||||
|                         event_id=e_id, | ||||
|                         room_version=room_version, | ||||
|                     ) | ||||
|                     for e_id in batch | ||||
|                 ] | ||||
| 
 | ||||
|                 res = await make_deferred_yieldable( | ||||
|                     defer.DeferredList(deferreds, consumeErrors=True) | ||||
|                 ) | ||||
| 
 | ||||
|                 for success, result in res: | ||||
|                     if success and result: | ||||
|                         fetched_events[result.event_id] = result | ||||
|             # we need to make sure we re-load from the database to get the rejected | ||||
|             # state correct. | ||||
|             fetched_events.update( | ||||
|                 (await self.store.get_events(missing_events, allow_rejected=True)) | ||||
|             ) | ||||
| 
 | ||||
|         # check for events which were in the wrong room. | ||||
|         # | ||||
|  | @ -705,50 +675,26 @@ class FederationHandler(BaseHandler): | |||
| 
 | ||||
|         return fetched_events | ||||
| 
 | ||||
|     async def _process_received_pdu(self, origin, event, state, auth_chain): | ||||
|     async def _process_received_pdu( | ||||
|         self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], | ||||
|     ): | ||||
|         """ Called when we have a new pdu. We need to do auth checks and put it | ||||
|         through the StateHandler. | ||||
| 
 | ||||
|         Args: | ||||
|             origin: server sending the event | ||||
| 
 | ||||
|             event: event to be persisted | ||||
| 
 | ||||
|             state: Normally None, but if we are handling a gap in the graph | ||||
|                 (ie, we are missing one or more prev_events), the resolved state at the | ||||
|                 event | ||||
|         """ | ||||
|         room_id = event.room_id | ||||
|         event_id = event.event_id | ||||
| 
 | ||||
|         logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) | ||||
| 
 | ||||
|         event_ids = set() | ||||
|         if state: | ||||
|             event_ids |= {e.event_id for e in state} | ||||
|         if auth_chain: | ||||
|             event_ids |= {e.event_id for e in auth_chain} | ||||
| 
 | ||||
|         seen_ids = await self.store.have_seen_events(event_ids) | ||||
| 
 | ||||
|         if state and auth_chain is not None: | ||||
|             # If we have any state or auth_chain given to us by the replication | ||||
|             # layer, then we should handle them (if we haven't before.) | ||||
| 
 | ||||
|             event_infos = [] | ||||
| 
 | ||||
|             for e in itertools.chain(auth_chain, state): | ||||
|                 if e.event_id in seen_ids: | ||||
|                     continue | ||||
|                 e.internal_metadata.outlier = True | ||||
|                 auth_ids = e.auth_event_ids() | ||||
|                 auth = { | ||||
|                     (e.type, e.state_key): e | ||||
|                     for e in auth_chain | ||||
|                     if e.event_id in auth_ids or e.type == EventTypes.Create | ||||
|                 } | ||||
|                 event_infos.append(_NewEventInfo(event=e, auth_events=auth)) | ||||
|                 seen_ids.add(e.event_id) | ||||
| 
 | ||||
|             logger.info( | ||||
|                 "[%s %s] persisting newly-received auth/state events %s", | ||||
|                 room_id, | ||||
|                 event_id, | ||||
|                 [e.event.event_id for e in event_infos], | ||||
|             ) | ||||
|             await self._handle_new_events(origin, event_infos) | ||||
| 
 | ||||
|         try: | ||||
|             context = await self._handle_new_event(origin, event, state=state) | ||||
|         except AuthError as e: | ||||
|  | @ -803,8 +749,6 @@ class FederationHandler(BaseHandler): | |||
|         if dest == self.server_name: | ||||
|             raise SynapseError(400, "Can't backfill from self.") | ||||
| 
 | ||||
|         room_version = await self.store.get_room_version(room_id) | ||||
| 
 | ||||
|         events = await self.federation_client.backfill( | ||||
|             dest, room_id, limit=limit, extremities=extremities | ||||
|         ) | ||||
|  | @ -833,6 +777,9 @@ class FederationHandler(BaseHandler): | |||
| 
 | ||||
|         event_ids = set(e.event_id for e in events) | ||||
| 
 | ||||
|         # build a list of events whose prev_events weren't in the batch. | ||||
|         # (XXX: this will include events whose prev_events we already have; that doesn't | ||||
|         # sound right?) | ||||
|         edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] | ||||
| 
 | ||||
|         logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) | ||||
|  | @ -861,95 +808,11 @@ class FederationHandler(BaseHandler): | |||
|         auth_events.update( | ||||
|             {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} | ||||
|         ) | ||||
|         missing_auth = required_auth - set(auth_events) | ||||
|         failed_to_fetch = set() | ||||
| 
 | ||||
|         # Try and fetch any missing auth events from both DB and remote servers. | ||||
|         # We repeatedly do this until we stop finding new auth events. | ||||
|         while missing_auth - failed_to_fetch: | ||||
|             logger.info("Missing auth for backfill: %r", missing_auth) | ||||
|             ret_events = await self.store.get_events(missing_auth - failed_to_fetch) | ||||
|             auth_events.update(ret_events) | ||||
| 
 | ||||
|             required_auth.update( | ||||
|                 a_id for event in ret_events.values() for a_id in event.auth_event_ids() | ||||
|             ) | ||||
|             missing_auth = required_auth - set(auth_events) | ||||
| 
 | ||||
|             if missing_auth - failed_to_fetch: | ||||
|                 logger.info( | ||||
|                     "Fetching missing auth for backfill: %r", | ||||
|                     missing_auth - failed_to_fetch, | ||||
|                 ) | ||||
| 
 | ||||
|                 results = await make_deferred_yieldable( | ||||
|                     defer.gatherResults( | ||||
|                         [ | ||||
|                             run_in_background( | ||||
|                                 self.federation_client.get_pdu, | ||||
|                                 [dest], | ||||
|                                 event_id, | ||||
|                                 room_version=room_version, | ||||
|                                 outlier=True, | ||||
|                                 timeout=10000, | ||||
|                             ) | ||||
|                             for event_id in missing_auth - failed_to_fetch | ||||
|                         ], | ||||
|                         consumeErrors=True, | ||||
|                     ) | ||||
|                 ).addErrback(unwrapFirstError) | ||||
|                 auth_events.update({a.event_id: a for a in results if a}) | ||||
|                 required_auth.update( | ||||
|                     a_id | ||||
|                     for event in results | ||||
|                     if event | ||||
|                     for a_id in event.auth_event_ids() | ||||
|                 ) | ||||
|                 missing_auth = required_auth - set(auth_events) | ||||
| 
 | ||||
|                 failed_to_fetch = missing_auth - set(auth_events) | ||||
| 
 | ||||
|         seen_events = await self.store.have_seen_events( | ||||
|             set(auth_events.keys()) | set(state_events.keys()) | ||||
|         ) | ||||
| 
 | ||||
|         # We now have a chunk of events plus associated state and auth chain to | ||||
|         # persist. We do the persistence in two steps: | ||||
|         #   1. Auth events and state get persisted as outliers, plus the | ||||
|         #      backward extremities get persisted (as non-outliers). | ||||
|         #   2. The rest of the events in the chunk get persisted one by one, as | ||||
|         #      each one depends on the previous event for its state. | ||||
|         # | ||||
|         # The important thing is that events in the chunk get persisted as | ||||
|         # non-outliers, including when those events are also in the state or | ||||
|         # auth chain. Caution must therefore be taken to ensure that they are | ||||
|         # not accidentally marked as outliers. | ||||
| 
 | ||||
|         # Step 1a: persist auth events that *don't* appear in the chunk | ||||
|         ev_infos = [] | ||||
|         for a in auth_events.values(): | ||||
|             # We only want to persist auth events as outliers that we haven't | ||||
|             # seen and aren't about to persist as part of the backfilled chunk. | ||||
|             if a.event_id in seen_events or a.event_id in event_map: | ||||
|                 continue | ||||
| 
 | ||||
|             a.internal_metadata.outlier = True | ||||
|             ev_infos.append( | ||||
|                 _NewEventInfo( | ||||
|                     event=a, | ||||
|                     auth_events={ | ||||
|                         ( | ||||
|                             auth_events[a_id].type, | ||||
|                             auth_events[a_id].state_key, | ||||
|                         ): auth_events[a_id] | ||||
|                         for a_id in a.auth_event_ids() | ||||
|                         if a_id in auth_events | ||||
|                     }, | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|         # Step 1b: persist the events in the chunk we fetched state for (i.e. | ||||
|         # the backwards extremities) as non-outliers. | ||||
|         # Step 1: persist the events in the chunk we fetched state for (i.e. | ||||
|         # the backwards extremities), with custom auth events and state | ||||
|         for e_id in events_to_state: | ||||
|             # For paranoia we ensure that these events are marked as | ||||
|             # non-outliers | ||||
|  | @ -1191,6 +1054,56 @@ class FederationHandler(BaseHandler): | |||
| 
 | ||||
|         return False | ||||
| 
 | ||||
|     async def _get_events_and_persist( | ||||
|         self, destination: str, room_id: str, events: Iterable[str] | ||||
|     ): | ||||
|         """Fetch the given events from a server, and persist them as outliers. | ||||
| 
 | ||||
|         Logs a warning if we can't find the given event. | ||||
|         """ | ||||
| 
 | ||||
|         room_version = await self.store.get_room_version(room_id) | ||||
| 
 | ||||
|         event_infos = [] | ||||
| 
 | ||||
|         async def get_event(event_id: str): | ||||
|             with nested_logging_context(event_id): | ||||
|                 try: | ||||
|                     event = await self.federation_client.get_pdu( | ||||
|                         [destination], event_id, room_version, outlier=True, | ||||
|                     ) | ||||
|                     if event is None: | ||||
|                         logger.warning( | ||||
|                             "Server %s didn't return event %s", destination, event_id, | ||||
|                         ) | ||||
|                         return | ||||
| 
 | ||||
|                     # recursively fetch the auth events for this event | ||||
|                     auth_events = await self._get_events_from_store_or_dest( | ||||
|                         destination, room_id, event.auth_event_ids() | ||||
|                     ) | ||||
|                     auth = {} | ||||
|                     for auth_event_id in event.auth_event_ids(): | ||||
|                         ae = auth_events.get(auth_event_id) | ||||
|                         if ae: | ||||
|                             auth[(ae.type, ae.state_key)] = ae | ||||
| 
 | ||||
|                     event_infos.append(_NewEventInfo(event, None, auth)) | ||||
| 
 | ||||
|                 except Exception as e: | ||||
|                     logger.warning( | ||||
|                         "Error fetching missing state/auth event %s: %s %s", | ||||
|                         event_id, | ||||
|                         type(e), | ||||
|                         e, | ||||
|                     ) | ||||
| 
 | ||||
|         await concurrently_execute(get_event, events, 5) | ||||
| 
 | ||||
|         await self._handle_new_events( | ||||
|             destination, event_infos, | ||||
|         ) | ||||
| 
 | ||||
|     def _sanity_check_event(self, ev): | ||||
|         """ | ||||
|         Do some early sanity checks of a received event | ||||
|  |  | |||
|  | @ -140,8 +140,8 @@ def concurrently_execute(func, args, limit): | |||
| 
 | ||||
|     Args: | ||||
|         func (func): Function to execute, should return a deferred or coroutine. | ||||
|         args (list): List of arguments to pass to func, each invocation of func | ||||
|             gets a signle argument. | ||||
|         args (Iterable): List of arguments to pass to func, each invocation of func | ||||
|             gets a single argument. | ||||
|         limit (int): Maximum number of conccurent executions. | ||||
| 
 | ||||
|     Returns: | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Richard van der Hoff
						Richard van der Hoff