diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index b4298d33a6..255e3f8d12 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -95,7 +95,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): self._joined_host_linearizer = Linearizer("_JoinedHostsCache") self._server_notices_mxid = hs.config.servernotices.server_notices_mxid - self._storage_controllers = hs.get_storage_controllers() if ( self.hs.config.worker.run_background_tasks @@ -1138,19 +1137,20 @@ class RoomMemberWorkerStore(EventsWorkerStore): # inspecting all join memberships in `state`. However, if the `state` is # relatively recent then many of its events are likely to be held in # the current state of the room, which is easily available and likely - # cached. We therefore compute the set of `state` events not in the + # cached. + # + # We therefore compute the set of `state` events not in the # current state and only fetch those. - current_state = await self._storage_controllers.state.get_current_state( - room_id + current_memberships = ( + await self._get_approximate_current_memberships_in_room(room_id) ) unknown_state_events = {} joined_users_in_current_state = [] for (type, state_key), event_id in state.items(): - current_event = current_state.get((type, state_key)) - if current_event is None or current_event.event_id != event_id: + if event_id not in current_memberships: unknown_state_events[type, state_key] = event_id - elif current_event.membership == Membership.JOIN: + elif current_memberships[event_id] == Membership.JOIN: joined_users_in_current_state.append(state_key) joined_user_ids = await self.get_joined_user_ids_from_state( @@ -1169,6 +1169,33 @@ class RoomMemberWorkerStore(EventsWorkerStore): return frozenset(cache.hosts_to_joined_users) + # TODO: this _might_ turn out to need caching, let's see + async def _get_approximate_current_memberships_in_room( + self, room_id: str + ) -> Mapping[str, Optional[str]]: + """Build a map from event id to membership, for all events in the current state. + + The event ids of non-memberships events (e.g. `m.room.power_levels`) are present + in the result, mapped to values of `None`. + + The result is approximate for partially-joined rooms. It is fully accurate + for fully-joined rooms. + """ + + def f(txn: LoggingTransaction) -> List[Tuple[str, str]]: + sql = """ + SELECT event_id, membership + FROM current_state_events + WHERE room_id = ?; + """ + txn.execute(sql, (room_id,)) + return txn.fetchall() # type: ignore[return-value] + + rows = await self.db_pool.runInteraction( + "_get_approimate_current_memberships_in_room", f + ) + return {row[0]: row[1] for row in rows} + @cached(max_entries=10000) def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": return _JoinedHostsCache()