Fix type of `events` in `StateGroupStorage` and `StateHandler` (#12156)

We make multiple passes over this, so a regular iterable won't do.
pull/12159/head
Richard van der Hoff 2022-03-04 10:25:18 +00:00 committed by GitHub
parent 8533c8b03d
commit d56202b038
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 7 deletions

1
changelog.d/12156.misc Normal file
View File

@ -0,0 +1 @@
Fix some type annotations.

View File

@ -194,7 +194,7 @@ class StateHandler:
} }
async def get_current_state_ids( async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
) -> StateMap[str]: ) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room """Get the current state, or the state at a set of events, for a room
@ -243,7 +243,7 @@ class StateHandler:
return await self.get_hosts_in_room_at_events(room_id, event_ids) return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events( async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Iterable[str] self, room_id: str, event_ids: Collection[str]
) -> Set[str]: ) -> Set[str]:
"""Get the hosts that were in a room at the given event ids """Get the hosts that were in a room at the given event ids
@ -404,7 +404,7 @@ class StateHandler:
@measure_func() @measure_func()
async def resolve_state_groups_for_events( async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Iterable[str] self, room_id: str, event_ids: Collection[str]
) -> _StateCacheEntry: ) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each """Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.

View File

@ -561,7 +561,7 @@ class StateGroupStorage:
return state_group_delta.prev_group, state_group_delta.delta_ids return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids( async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str] self, _room_id: str, event_ids: Collection[str]
) -> Dict[int, MutableStateMap[str]]: ) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events """Get the event IDs of all the state for the state groups for the given events
@ -596,7 +596,7 @@ class StateGroupStorage:
return group_to_state[state_group] return group_to_state[state_group]
async def get_state_groups( async def get_state_groups(
self, room_id: str, event_ids: Iterable[str] self, room_id: str, event_ids: Collection[str]
) -> Dict[int, List[EventBase]]: ) -> Dict[int, List[EventBase]]:
"""Get the state groups for the given list of event_ids """Get the state groups for the given list of event_ids
@ -648,7 +648,7 @@ class StateGroupStorage:
return self.stores.state._get_state_groups_from_groups(groups, state_filter) return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events( async def get_state_for_events(
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]: ) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
dicts for each event. dicts for each event.
@ -684,7 +684,7 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events( async def get_state_ids_for_events(
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[str]]: ) -> Dict[str, StateMap[str]]:
""" """
Get the state dicts corresponding to a list of events, containing the event_ids Get the state dicts corresponding to a list of events, containing the event_ids