Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens. (#12775)
							parent
							
								
									3d8839c30c
								
							
						
					
					
						commit
						19d79b6ebe
					
				|  | @ -0,0 +1 @@ | |||
| Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens. | ||||
|  | @ -288,7 +288,6 @@ class StateHandler: | |||
|         # | ||||
|         # first of all, figure out the state before the event | ||||
|         # | ||||
| 
 | ||||
|         if old_state: | ||||
|             # if we're given the state before the event, then we use that | ||||
|             state_ids_before_event: StateMap[str] = { | ||||
|  | @ -419,33 +418,37 @@ class StateHandler: | |||
|         """ | ||||
|         logger.debug("resolve_state_groups event_ids %s", event_ids) | ||||
| 
 | ||||
|         # map from state group id to the state in that state group (where | ||||
|         # 'state' is a map from state key to event id) | ||||
|         # dict[int, dict[(str, str), str]] | ||||
|         state_groups_ids = await self.state_store.get_state_groups_ids( | ||||
|             room_id, event_ids | ||||
|         ) | ||||
|         state_groups = await self.state_store.get_state_group_for_events(event_ids) | ||||
| 
 | ||||
|         if len(state_groups_ids) == 0: | ||||
|             return _StateCacheEntry(state={}, state_group=None) | ||||
|         elif len(state_groups_ids) == 1: | ||||
|             name, state_list = list(state_groups_ids.items()).pop() | ||||
| 
 | ||||
|             prev_group, delta_ids = await self.state_store.get_state_group_delta(name) | ||||
|         state_group_ids = state_groups.values() | ||||
| 
 | ||||
|         # check if each event has same state group id, if so there's no state to resolve | ||||
|         state_group_ids_set = set(state_group_ids) | ||||
|         if len(state_group_ids_set) == 1: | ||||
|             (state_group_id,) = state_group_ids_set | ||||
|             state = await self.state_store.get_state_for_groups(state_group_ids_set) | ||||
|             prev_group, delta_ids = await self.state_store.get_state_group_delta( | ||||
|                 state_group_id | ||||
|             ) | ||||
|             return _StateCacheEntry( | ||||
|                 state=state_list, | ||||
|                 state_group=name, | ||||
|                 state=state[state_group_id], | ||||
|                 state_group=state_group_id, | ||||
|                 prev_group=prev_group, | ||||
|                 delta_ids=delta_ids, | ||||
|             ) | ||||
|         elif len(state_group_ids_set) == 0: | ||||
|             return _StateCacheEntry(state={}, state_group=None) | ||||
| 
 | ||||
|         room_version = await self.store.get_room_version_id(room_id) | ||||
| 
 | ||||
|         state_to_resolve = await self.state_store.get_state_for_groups( | ||||
|             state_group_ids_set | ||||
|         ) | ||||
| 
 | ||||
|         result = await self._state_resolution_handler.resolve_state_groups( | ||||
|             room_id, | ||||
|             room_version, | ||||
|             state_groups_ids, | ||||
|             state_to_resolve, | ||||
|             None, | ||||
|             state_res_store=StateResolutionStore(self.store), | ||||
|         ) | ||||
|  |  | |||
|  | @ -189,7 +189,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): | |||
|         group: int, | ||||
|         state_filter: StateFilter, | ||||
|     ) -> Tuple[MutableStateMap[str], bool]: | ||||
|         """Checks if group is in cache. See `_get_state_for_groups` | ||||
|         """Checks if group is in cache. See `get_state_for_groups` | ||||
| 
 | ||||
|         Args: | ||||
|             cache: the state group cache to use | ||||
|  |  | |||
|  | @ -586,7 +586,7 @@ class StateGroupStorage: | |||
|         if not event_ids: | ||||
|             return {} | ||||
| 
 | ||||
|         event_to_groups = await self._get_state_group_for_events(event_ids) | ||||
|         event_to_groups = await self.get_state_group_for_events(event_ids) | ||||
| 
 | ||||
|         groups = set(event_to_groups.values()) | ||||
|         group_to_state = await self.stores.state._get_state_for_groups(groups) | ||||
|  | @ -602,7 +602,7 @@ class StateGroupStorage: | |||
|         Returns: | ||||
|             Resolves to a map of (type, state_key) -> event_id | ||||
|         """ | ||||
|         group_to_state = await self._get_state_for_groups((state_group,)) | ||||
|         group_to_state = await self.get_state_for_groups((state_group,)) | ||||
| 
 | ||||
|         return group_to_state[state_group] | ||||
| 
 | ||||
|  | @ -675,7 +675,7 @@ class StateGroupStorage: | |||
|             RuntimeError if we don't have a state group for one or more of the events | ||||
|                (ie they are outliers or unknown) | ||||
|         """ | ||||
|         event_to_groups = await self._get_state_group_for_events(event_ids) | ||||
|         event_to_groups = await self.get_state_group_for_events(event_ids) | ||||
| 
 | ||||
|         groups = set(event_to_groups.values()) | ||||
|         group_to_state = await self.stores.state._get_state_for_groups( | ||||
|  | @ -716,7 +716,7 @@ class StateGroupStorage: | |||
|             RuntimeError if we don't have a state group for one or more of the events | ||||
|                 (ie they are outliers or unknown) | ||||
|         """ | ||||
|         event_to_groups = await self._get_state_group_for_events(event_ids) | ||||
|         event_to_groups = await self.get_state_group_for_events(event_ids) | ||||
| 
 | ||||
|         groups = set(event_to_groups.values()) | ||||
|         group_to_state = await self.stores.state._get_state_for_groups( | ||||
|  | @ -774,7 +774,7 @@ class StateGroupStorage: | |||
|         ) | ||||
|         return state_map[event_id] | ||||
| 
 | ||||
|     def _get_state_for_groups( | ||||
|     def get_state_for_groups( | ||||
|         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None | ||||
|     ) -> Awaitable[Dict[int, MutableStateMap[str]]]: | ||||
|         """Gets the state at each of a list of state groups, optionally | ||||
|  | @ -792,7 +792,7 @@ class StateGroupStorage: | |||
|             groups, state_filter or StateFilter.all() | ||||
|         ) | ||||
| 
 | ||||
|     async def _get_state_group_for_events( | ||||
|     async def get_state_group_for_events( | ||||
|         self, | ||||
|         event_ids: Collection[str], | ||||
|         await_full_state: bool = True, | ||||
|  |  | |||
|  | @ -129,6 +129,19 @@ class _DummyStore: | |||
|     async def get_room_version_id(self, room_id): | ||||
|         return RoomVersions.V1.identifier | ||||
| 
 | ||||
|     async def get_state_group_for_events(self, event_ids): | ||||
|         res = {} | ||||
|         for event in event_ids: | ||||
|             res[event] = self._event_to_state_group[event] | ||||
|         return res | ||||
| 
 | ||||
|     async def get_state_for_groups(self, groups): | ||||
|         res = {} | ||||
|         for group in groups: | ||||
|             state = self._group_to_state[group] | ||||
|             res[group] = state | ||||
|         return res | ||||
| 
 | ||||
| 
 | ||||
| class DictObj(dict): | ||||
|     def __init__(self, **kwargs): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Shay
						Shay