Add `include_event_in_state` to _get_state_for_room (#6521)
Make it return the state *after* the requested event, rather than the one before it. This is a bit easier and requires fewer calls to get_events_from_store_or_dest.pull/6524/head
							parent
							
								
									894d2addac
								
							
						
					
					
						commit
						2045356517
					
				|  | @ -0,0 +1 @@ | |||
| Refactor some code in the event authentication path for clarity. | ||||
|  | @ -378,22 +378,10 @@ class FederationHandler(BaseHandler): | |||
|                             ( | ||||
|                                 remote_state, | ||||
|                                 got_auth_chain, | ||||
|                             ) = await self._get_state_for_room(origin, room_id, p) | ||||
| 
 | ||||
|                             # we want the state *after* p; _get_state_for_room returns the | ||||
|                             # state *before* p. | ||||
|                             remote_event = await self.federation_client.get_pdu( | ||||
|                                 [origin], p, room_version, outlier=True | ||||
|                             ) = await self._get_state_for_room( | ||||
|                                 origin, room_id, p, include_event_in_state=True | ||||
|                             ) | ||||
| 
 | ||||
|                             if remote_event is None: | ||||
|                                 raise Exception( | ||||
|                                     "Unable to get missing prev_event %s" % (p,) | ||||
|                                 ) | ||||
| 
 | ||||
|                             if remote_event.is_state(): | ||||
|                                 remote_state.append(remote_event) | ||||
| 
 | ||||
|                             # XXX hrm I'm not convinced that duplicate events will compare | ||||
|                             # for equality, so I'm not sure this does what the author | ||||
|                             # hoped. | ||||
|  | @ -579,20 +567,25 @@ class FederationHandler(BaseHandler): | |||
|                     else: | ||||
|                         raise | ||||
| 
 | ||||
|     @log_function | ||||
|     async def _get_state_for_room( | ||||
|         self, destination: str, room_id: str, event_id: str | ||||
|         self, | ||||
|         destination: str, | ||||
|         room_id: str, | ||||
|         event_id: str, | ||||
|         include_event_in_state: bool = False, | ||||
|     ) -> Tuple[List[EventBase], List[EventBase]]: | ||||
|         """Requests all of the room state at a given event from a remote homeserver. | ||||
| 
 | ||||
|         Args: | ||||
|             destination:: The remote homeserver to query for the state. | ||||
|             destination: The remote homeserver to query for the state. | ||||
|             room_id: The id of the room we're interested in. | ||||
|             event_id: The id of the event we want the state at. | ||||
|             include_event_in_state: if true, the event itself will be included in the | ||||
|                 returned state event list. | ||||
| 
 | ||||
|         Returns: | ||||
|             A list of events in the state, and a list of events in the auth chain | ||||
|             for the given event. | ||||
|             A list of events in the state, possibly including the event itself, and | ||||
|             a list of events in the auth chain for the given event. | ||||
|         """ | ||||
|         ( | ||||
|             state_event_ids, | ||||
|  | @ -602,6 +595,10 @@ class FederationHandler(BaseHandler): | |||
|         ) | ||||
| 
 | ||||
|         desired_events = set(state_event_ids + auth_event_ids) | ||||
| 
 | ||||
|         if include_event_in_state: | ||||
|             desired_events.add(event_id) | ||||
| 
 | ||||
|         event_map = await self._get_events_from_store_or_dest( | ||||
|             destination, room_id, desired_events | ||||
|         ) | ||||
|  | @ -614,12 +611,21 @@ class FederationHandler(BaseHandler): | |||
|                 failed_to_fetch, | ||||
|             ) | ||||
| 
 | ||||
|         pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] | ||||
|         auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] | ||||
|         remote_state = [ | ||||
|             event_map[e_id] for e_id in state_event_ids if e_id in event_map | ||||
|         ] | ||||
| 
 | ||||
|         if include_event_in_state: | ||||
|             remote_event = event_map.get(event_id) | ||||
|             if not remote_event: | ||||
|                 raise Exception("Unable to get missing prev_event %s" % (event_id,)) | ||||
|             if remote_event.is_state(): | ||||
|                 remote_state.append(remote_event) | ||||
| 
 | ||||
|         auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] | ||||
|         auth_chain.sort(key=lambda e: e.depth) | ||||
| 
 | ||||
|         return pdus, auth_chain | ||||
|         return remote_state, auth_chain | ||||
| 
 | ||||
|     async def _get_events_from_store_or_dest( | ||||
|         self, destination: str, room_id: str, event_ids: Iterable[str] | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Richard van der Hoff
						Richard van der Hoff