Make `backfill` and `get_missing_events` use the same codepath (#10645)
Given that backfill and get_missing_events are basically the same thing, it's somewhat crazy that we have entirely separate code paths for them. This makes backfill use the existing get_missing_events code, and then clears up all the unused code.pull/10708/head
							parent
							
								
									40f619eaa5
								
							
						
					
					
						commit
						96715d7633
					
				|  | @ -0,0 +1 @@ | |||
| Make `backfill` and `get_missing_events` use the same codepath. | ||||
|  | @ -65,6 +65,7 @@ from synapse.event_auth import auth_types_for_event | |||
| from synapse.events import EventBase | ||||
| from synapse.events.snapshot import EventContext | ||||
| from synapse.events.validator import EventValidator | ||||
| from synapse.federation.federation_client import InvalidResponseError | ||||
| from synapse.handlers._base import BaseHandler | ||||
| from synapse.http.servlet import assert_params_in_dict | ||||
| from synapse.logging.context import ( | ||||
|  | @ -116,10 +117,6 @@ class _NewEventInfo: | |||
|     Attributes: | ||||
|         event: the received event | ||||
| 
 | ||||
|         state: the state at that event, according to /state_ids from a remote | ||||
|            homeserver. Only populated for backfilled events which are going to be a | ||||
|            new backwards extremity. | ||||
| 
 | ||||
|         claimed_auth_event_map: a map of (type, state_key) => event for the event's | ||||
|             claimed auth_events. | ||||
| 
 | ||||
|  | @ -134,7 +131,6 @@ class _NewEventInfo: | |||
|     """ | ||||
| 
 | ||||
|     event: EventBase | ||||
|     state: Optional[Sequence[EventBase]] | ||||
|     claimed_auth_event_map: StateMap[EventBase] | ||||
| 
 | ||||
| 
 | ||||
|  | @ -443,113 +439,7 @@ class FederationHandler(BaseHandler): | |||
|             return | ||||
| 
 | ||||
|         logger.info("Got %d prev_events", len(missing_events)) | ||||
|         await self._process_pulled_events(origin, missing_events) | ||||
| 
 | ||||
|     async def _get_state_for_room( | ||||
|         self, | ||||
|         destination: str, | ||||
|         room_id: str, | ||||
|         event_id: str, | ||||
|     ) -> List[EventBase]: | ||||
|         """Requests all of the room state at a given event from a remote | ||||
|         homeserver. | ||||
| 
 | ||||
|         Will also fetch any missing events reported in the `auth_chain_ids` | ||||
|         section of `/state_ids`. | ||||
| 
 | ||||
|         Args: | ||||
|             destination: The remote homeserver to query for the state. | ||||
|             room_id: The id of the room we're interested in. | ||||
|             event_id: The id of the event we want the state at. | ||||
| 
 | ||||
|         Returns: | ||||
|             A list of events in the state, not including the event itself. | ||||
|         """ | ||||
|         ( | ||||
|             state_event_ids, | ||||
|             auth_event_ids, | ||||
|         ) = await self.federation_client.get_room_state_ids( | ||||
|             destination, room_id, event_id=event_id | ||||
|         ) | ||||
| 
 | ||||
|         # Fetch the state events from the DB, and check we have the auth events. | ||||
|         event_map = await self.store.get_events(state_event_ids, allow_rejected=True) | ||||
|         auth_events_in_store = await self.store.have_seen_events( | ||||
|             room_id, auth_event_ids | ||||
|         ) | ||||
| 
 | ||||
|         # Check for missing events. We handle state and auth event seperately, | ||||
|         # as we want to pull the state from the DB, but we don't for the auth | ||||
|         # events. (Note: we likely won't use the majority of the auth chain, and | ||||
|         # it can be *huge* for large rooms, so it's worth ensuring that we don't | ||||
|         # unnecessarily pull it from the DB). | ||||
|         missing_state_events = set(state_event_ids) - set(event_map) | ||||
|         missing_auth_events = set(auth_event_ids) - set(auth_events_in_store) | ||||
|         if missing_state_events or missing_auth_events: | ||||
|             await self._get_events_and_persist( | ||||
|                 destination=destination, | ||||
|                 room_id=room_id, | ||||
|                 events=missing_state_events | missing_auth_events, | ||||
|             ) | ||||
| 
 | ||||
|             if missing_state_events: | ||||
|                 new_events = await self.store.get_events( | ||||
|                     missing_state_events, allow_rejected=True | ||||
|                 ) | ||||
|                 event_map.update(new_events) | ||||
| 
 | ||||
|                 missing_state_events.difference_update(new_events) | ||||
| 
 | ||||
|                 if missing_state_events: | ||||
|                     logger.warning( | ||||
|                         "Failed to fetch missing state events for %s %s", | ||||
|                         event_id, | ||||
|                         missing_state_events, | ||||
|                     ) | ||||
| 
 | ||||
|             if missing_auth_events: | ||||
|                 auth_events_in_store = await self.store.have_seen_events( | ||||
|                     room_id, missing_auth_events | ||||
|                 ) | ||||
|                 missing_auth_events.difference_update(auth_events_in_store) | ||||
| 
 | ||||
|                 if missing_auth_events: | ||||
|                     logger.warning( | ||||
|                         "Failed to fetch missing auth events for %s %s", | ||||
|                         event_id, | ||||
|                         missing_auth_events, | ||||
|                     ) | ||||
| 
 | ||||
|         remote_state = list(event_map.values()) | ||||
| 
 | ||||
|         # check for events which were in the wrong room. | ||||
|         # | ||||
|         # this can happen if a remote server claims that the state or | ||||
|         # auth_events at an event in room A are actually events in room B | ||||
| 
 | ||||
|         bad_events = [ | ||||
|             (event.event_id, event.room_id) | ||||
|             for event in remote_state | ||||
|             if event.room_id != room_id | ||||
|         ] | ||||
| 
 | ||||
|         for bad_event_id, bad_room_id in bad_events: | ||||
|             # This is a bogus situation, but since we may only discover it a long time | ||||
|             # after it happened, we try our best to carry on, by just omitting the | ||||
|             # bad events from the returned auth/state set. | ||||
|             logger.warning( | ||||
|                 "Remote server %s claims event %s in room %s is an auth/state " | ||||
|                 "event in room %s", | ||||
|                 destination, | ||||
|                 bad_event_id, | ||||
|                 bad_room_id, | ||||
|                 room_id, | ||||
|             ) | ||||
| 
 | ||||
|         if bad_events: | ||||
|             remote_state = [e for e in remote_state if e.room_id == room_id] | ||||
| 
 | ||||
|         return remote_state | ||||
|         await self._process_pulled_events(origin, missing_events, backfilled=False) | ||||
| 
 | ||||
|     async def _get_state_after_missing_prev_event( | ||||
|         self, | ||||
|  | @ -567,10 +457,6 @@ class FederationHandler(BaseHandler): | |||
|         Returns: | ||||
|             A list of events in the state, including the event itself | ||||
|         """ | ||||
|         # TODO: This function is basically the same as _get_state_for_room. Can | ||||
|         #   we make backfill() use it, rather than having two code paths? I think the | ||||
|         #   only difference is that backfill() persists the prev events separately. | ||||
| 
 | ||||
|         ( | ||||
|             state_event_ids, | ||||
|             auth_event_ids, | ||||
|  | @ -681,6 +567,7 @@ class FederationHandler(BaseHandler): | |||
|         origin: str, | ||||
|         event: EventBase, | ||||
|         state: Optional[Iterable[EventBase]], | ||||
|         backfilled: bool = False, | ||||
|     ) -> None: | ||||
|         """Called when we have a new pdu. We need to do auth checks and put it | ||||
|         through the StateHandler. | ||||
|  | @ -693,6 +580,9 @@ class FederationHandler(BaseHandler): | |||
|             state: Normally None, but if we are handling a gap in the graph | ||||
|                 (ie, we are missing one or more prev_events), the resolved state at the | ||||
|                 event | ||||
| 
 | ||||
|             backfilled: True if this is part of a historical batch of events (inhibits | ||||
|                 notification to clients, and validation of device keys.) | ||||
|         """ | ||||
|         logger.debug("Processing event: %s", event) | ||||
| 
 | ||||
|  | @ -700,10 +590,15 @@ class FederationHandler(BaseHandler): | |||
|             context = await self.state_handler.compute_event_context( | ||||
|                 event, old_state=state | ||||
|             ) | ||||
|             await self._auth_and_persist_event(origin, event, context, state=state) | ||||
|             await self._auth_and_persist_event( | ||||
|                 origin, event, context, state=state, backfilled=backfilled | ||||
|             ) | ||||
|         except AuthError as e: | ||||
|             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) | ||||
| 
 | ||||
|         if backfilled: | ||||
|             return | ||||
| 
 | ||||
|         # For encrypted messages we check that we know about the sending device, | ||||
|         # if we don't then we mark the device cache for that user as stale. | ||||
|         if event.type == EventTypes.Encrypted: | ||||
|  | @ -868,7 +763,7 @@ class FederationHandler(BaseHandler): | |||
|     @log_function | ||||
|     async def backfill( | ||||
|         self, dest: str, room_id: str, limit: int, extremities: List[str] | ||||
|     ) -> List[EventBase]: | ||||
|     ) -> None: | ||||
|         """Trigger a backfill request to `dest` for the given `room_id` | ||||
| 
 | ||||
|         This will attempt to get more events from the remote. If the other side | ||||
|  | @ -878,6 +773,9 @@ class FederationHandler(BaseHandler): | |||
|         sanity-checking on them. If any of the backfilled events are invalid, | ||||
|         this method throws a SynapseError. | ||||
| 
 | ||||
|         We might also raise an InvalidResponseError if the response from the remote | ||||
|         server is just bogus. | ||||
| 
 | ||||
|         TODO: make this more useful to distinguish failures of the remote | ||||
|         server from invalid events (there is probably no point in trying to | ||||
|         re-fetch invalid events from every other HS in the room.) | ||||
|  | @ -890,111 +788,18 @@ class FederationHandler(BaseHandler): | |||
|         ) | ||||
| 
 | ||||
|         if not events: | ||||
|             return [] | ||||
|             return | ||||
| 
 | ||||
|         # ideally we'd sanity check the events here for excess prev_events etc, | ||||
|         # but it's hard to reject events at this point without completely | ||||
|         # breaking backfill in the same way that it is currently broken by | ||||
|         # events whose signature we cannot verify (#3121). | ||||
|         # | ||||
|         # So for now we accept the events anyway. #3124 tracks this. | ||||
|         # | ||||
|         # for ev in events: | ||||
|         #     self._sanity_check_event(ev) | ||||
| 
 | ||||
|         # Don't bother processing events we already have. | ||||
|         seen_events = await self.store.have_events_in_timeline( | ||||
|             {e.event_id for e in events} | ||||
|         ) | ||||
| 
 | ||||
|         events = [e for e in events if e.event_id not in seen_events] | ||||
| 
 | ||||
|         if not events: | ||||
|             return [] | ||||
| 
 | ||||
|         event_map = {e.event_id: e for e in events} | ||||
| 
 | ||||
|         event_ids = {e.event_id for e in events} | ||||
| 
 | ||||
|         # build a list of events whose prev_events weren't in the batch. | ||||
|         # (XXX: this will include events whose prev_events we already have; that doesn't | ||||
|         # sound right?) | ||||
|         edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] | ||||
| 
 | ||||
|         logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) | ||||
| 
 | ||||
|         # For each edge get the current state. | ||||
| 
 | ||||
|         state_events = {} | ||||
|         events_to_state = {} | ||||
|         for e_id in edges: | ||||
|             state = await self._get_state_for_room( | ||||
|                 destination=dest, | ||||
|                 room_id=room_id, | ||||
|                 event_id=e_id, | ||||
|             ) | ||||
|             state_events.update({s.event_id: s for s in state}) | ||||
|             events_to_state[e_id] = state | ||||
| 
 | ||||
|         required_auth = { | ||||
|             a_id | ||||
|             for event in events + list(state_events.values()) | ||||
|             for a_id in event.auth_event_ids() | ||||
|         } | ||||
|         auth_events = await self.store.get_events(required_auth, allow_rejected=True) | ||||
|         auth_events.update( | ||||
|             {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} | ||||
|         ) | ||||
| 
 | ||||
|         ev_infos = [] | ||||
| 
 | ||||
|         # Step 1: persist the events in the chunk we fetched state for (i.e. | ||||
|         # the backwards extremities), with custom auth events and state | ||||
|         for e_id in events_to_state: | ||||
|             # For paranoia we ensure that these events are marked as | ||||
|             # non-outliers | ||||
|             ev = event_map[e_id] | ||||
|             assert not ev.internal_metadata.is_outlier() | ||||
| 
 | ||||
|             ev_infos.append( | ||||
|                 _NewEventInfo( | ||||
|                     event=ev, | ||||
|                     state=events_to_state[e_id], | ||||
|                     claimed_auth_event_map={ | ||||
|                         ( | ||||
|                             auth_events[a_id].type, | ||||
|                             auth_events[a_id].state_key, | ||||
|                         ): auth_events[a_id] | ||||
|                         for a_id in ev.auth_event_ids() | ||||
|                         if a_id in auth_events | ||||
|                     }, | ||||
|         # if there are any events in the wrong room, the remote server is buggy and | ||||
|         # should not be trusted. | ||||
|         for ev in events: | ||||
|             if ev.room_id != room_id: | ||||
|                 raise InvalidResponseError( | ||||
|                     f"Remote server {dest} returned event {ev.event_id} which is in " | ||||
|                     f"room {ev.room_id}, when we were backfilling in {room_id}" | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|         if ev_infos: | ||||
|             await self._auth_and_persist_events( | ||||
|                 dest, room_id, ev_infos, backfilled=True | ||||
|             ) | ||||
| 
 | ||||
|         # Step 2: Persist the rest of the events in the chunk one by one | ||||
|         events.sort(key=lambda e: e.depth) | ||||
| 
 | ||||
|         for event in events: | ||||
|             if event in events_to_state: | ||||
|                 continue | ||||
| 
 | ||||
|             # For paranoia we ensure that these events are marked as | ||||
|             # non-outliers | ||||
|             assert not event.internal_metadata.is_outlier() | ||||
| 
 | ||||
|             context = await self.state_handler.compute_event_context(event) | ||||
| 
 | ||||
|             # We store these one at a time since each event depends on the | ||||
|             # previous to work out the state. | ||||
|             # TODO: We can probably do something more clever here. | ||||
|             await self._auth_and_persist_event(dest, event, context, backfilled=True) | ||||
| 
 | ||||
|         return events | ||||
|         await self._process_pulled_events(dest, events, backfilled=True) | ||||
| 
 | ||||
|     async def maybe_backfill( | ||||
|         self, room_id: str, current_depth: int, limit: int | ||||
|  | @ -1197,7 +1002,7 @@ class FederationHandler(BaseHandler): | |||
|                     # appropriate stuff. | ||||
|                     # TODO: We can probably do something more intelligent here. | ||||
|                     return True | ||||
|                 except SynapseError as e: | ||||
|                 except (SynapseError, InvalidResponseError) as e: | ||||
|                     logger.info("Failed to backfill from %s because %s", dom, e) | ||||
|                     continue | ||||
|                 except HttpResponseException as e: | ||||
|  | @ -1351,7 +1156,7 @@ class FederationHandler(BaseHandler): | |||
|                 else: | ||||
|                     logger.info("Missing auth event %s", auth_event_id) | ||||
| 
 | ||||
|             event_infos.append(_NewEventInfo(event, None, auth)) | ||||
|             event_infos.append(_NewEventInfo(event, auth)) | ||||
| 
 | ||||
|         if event_infos: | ||||
|             await self._auth_and_persist_events( | ||||
|  | @ -1361,7 +1166,7 @@ class FederationHandler(BaseHandler): | |||
|             ) | ||||
| 
 | ||||
|     async def _process_pulled_events( | ||||
|         self, origin: str, events: Iterable[EventBase] | ||||
|         self, origin: str, events: Iterable[EventBase], backfilled: bool | ||||
|     ) -> None: | ||||
|         """Process a batch of events we have pulled from a remote server | ||||
| 
 | ||||
|  | @ -1373,6 +1178,8 @@ class FederationHandler(BaseHandler): | |||
|         Params: | ||||
|             origin: The server we received these events from | ||||
|             events: The received events. | ||||
|             backfilled: True if this is part of a historical batch of events (inhibits | ||||
|                 notification to clients, and validation of device keys.) | ||||
|         """ | ||||
| 
 | ||||
|         # We want to sort these by depth so we process them and | ||||
|  | @ -1381,9 +1188,11 @@ class FederationHandler(BaseHandler): | |||
| 
 | ||||
|         for ev in sorted_events: | ||||
|             with nested_logging_context(ev.event_id): | ||||
|                 await self._process_pulled_event(origin, ev) | ||||
|                 await self._process_pulled_event(origin, ev, backfilled=backfilled) | ||||
| 
 | ||||
|     async def _process_pulled_event(self, origin: str, event: EventBase) -> None: | ||||
|     async def _process_pulled_event( | ||||
|         self, origin: str, event: EventBase, backfilled: bool | ||||
|     ) -> None: | ||||
|         """Process a single event that we have pulled from a remote server | ||||
| 
 | ||||
|         Pulls in any events required to auth the event, persists the received event, | ||||
|  | @ -1400,6 +1209,8 @@ class FederationHandler(BaseHandler): | |||
|         Params: | ||||
|             origin: The server we received this event from | ||||
|             events: The received event | ||||
|             backfilled: True if this is part of a historical batch of events (inhibits | ||||
|                 notification to clients, and validation of device keys.) | ||||
|         """ | ||||
|         logger.info("Processing pulled event %s", event) | ||||
| 
 | ||||
|  | @ -1428,7 +1239,9 @@ class FederationHandler(BaseHandler): | |||
| 
 | ||||
|         try: | ||||
|             state = await self._resolve_state_at_missing_prevs(origin, event) | ||||
|             await self._process_received_pdu(origin, event, state=state) | ||||
|             await self._process_received_pdu( | ||||
|                 origin, event, state=state, backfilled=backfilled | ||||
|             ) | ||||
|         except FederationError as e: | ||||
|             if e.code == 403: | ||||
|                 logger.warning("Pulled event %s failed history check.", event_id) | ||||
|  | @ -2451,7 +2264,6 @@ class FederationHandler(BaseHandler): | |||
|         origin: str, | ||||
|         room_id: str, | ||||
|         event_infos: Collection[_NewEventInfo], | ||||
|         backfilled: bool = False, | ||||
|     ) -> None: | ||||
|         """Creates the appropriate contexts and persists events. The events | ||||
|         should not depend on one another, e.g. this should be used to persist | ||||
|  | @ -2467,16 +2279,12 @@ class FederationHandler(BaseHandler): | |||
|         async def prep(ev_info: _NewEventInfo): | ||||
|             event = ev_info.event | ||||
|             with nested_logging_context(suffix=event.event_id): | ||||
|                 res = await self.state_handler.compute_event_context( | ||||
|                     event, old_state=ev_info.state | ||||
|                 ) | ||||
|                 res = await self.state_handler.compute_event_context(event) | ||||
|                 res = await self._check_event_auth( | ||||
|                     origin, | ||||
|                     event, | ||||
|                     res, | ||||
|                     state=ev_info.state, | ||||
|                     claimed_auth_event_map=ev_info.claimed_auth_event_map, | ||||
|                     backfilled=backfilled, | ||||
|                 ) | ||||
|             return res | ||||
| 
 | ||||
|  | @ -2493,7 +2301,6 @@ class FederationHandler(BaseHandler): | |||
|                 (ev_info.event, context) | ||||
|                 for ev_info, context in zip(event_infos, contexts) | ||||
|             ], | ||||
|             backfilled=backfilled, | ||||
|         ) | ||||
| 
 | ||||
|     async def _persist_auth_tree( | ||||
|  |  | |||
|  | @ -295,6 +295,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): | |||
|                 self._invalidate_cache_and_stream( | ||||
|                     txn, self.have_seen_event, (room_id, event_id) | ||||
|                 ) | ||||
|                 self._invalidate_get_event_cache(event_id) | ||||
| 
 | ||||
|         logger.info("[purge] done") | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Richard van der Hoff
						Richard van der Hoff