1524 lines
		
	
	
		
			54 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			1524 lines
		
	
	
		
			54 KiB
		
	
	
	
		
			Python
		
	
	
# Copyright 2014-2016 OpenMarket Ltd
 | 
						|
# Copyright 2018 New Vector Ltd
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
import logging
 | 
						|
from typing import (
 | 
						|
    TYPE_CHECKING,
 | 
						|
    AbstractSet,
 | 
						|
    Collection,
 | 
						|
    Dict,
 | 
						|
    FrozenSet,
 | 
						|
    Iterable,
 | 
						|
    List,
 | 
						|
    Mapping,
 | 
						|
    Optional,
 | 
						|
    Sequence,
 | 
						|
    Set,
 | 
						|
    Tuple,
 | 
						|
    Union,
 | 
						|
)
 | 
						|
 | 
						|
import attr
 | 
						|
 | 
						|
from synapse.api.constants import EventTypes, Membership
 | 
						|
from synapse.metrics import LaterGauge
 | 
						|
from synapse.metrics.background_process_metrics import wrap_as_background_process
 | 
						|
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 | 
						|
from synapse.storage.database import (
 | 
						|
    DatabasePool,
 | 
						|
    LoggingDatabaseConnection,
 | 
						|
    LoggingTransaction,
 | 
						|
)
 | 
						|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 | 
						|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
 | 
						|
from synapse.storage.engines import Sqlite3Engine
 | 
						|
from synapse.storage.roommember import (
 | 
						|
    GetRoomsForUserWithStreamOrdering,
 | 
						|
    MemberSummary,
 | 
						|
    ProfileInfo,
 | 
						|
    RoomsForUser,
 | 
						|
)
 | 
						|
from synapse.types import (
 | 
						|
    JsonDict,
 | 
						|
    PersistedEventPosition,
 | 
						|
    StateMap,
 | 
						|
    StrCollection,
 | 
						|
    get_domain_from_id,
 | 
						|
)
 | 
						|
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
 | 
						|
from synapse.util.iterutils import batch_iter
 | 
						|
from synapse.util.metrics import Measure
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    from synapse.server import HomeServer
 | 
						|
 | 
						|
logger = logging.getLogger(__name__)
 | 
						|
 | 
						|
 | 
						|
# Presumably names of background updates (judging by the `_UPDATE_NAME`
# suffix). Neither constant is referenced in the visible part of this file —
# TODO confirm where they are registered.
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 | 
						|
 | 
						|
 | 
						|
@attr.s(frozen=True, slots=True)
class EventIdMembership:
    """An immutable (user_id, membership) pair.

    Returned by `get_membership_from_event_ids`.
    """

    # The user ID associated with the membership.
    user_id = attr.ib(type=str)
    # The membership value (a `Membership` constant such as "join").
    membership = attr.ib(type=str)
 | 
						|
 | 
						|
 | 
						|
class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
 | 
						|
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        """Initialise the store and, if enabled, the known-servers metric.

        The `synapse_federation_known_servers` gauge is only set up when this
        process runs background tasks *and* the `known_servers` metrics flag is
        enabled. In that case `_count_known_servers` is scheduled both as a
        looping call (interval `60 * 1000`, presumably milliseconds) and once
        via `call_later(1, ...)`.
        """
        super().__init__(database, db_conn, hs)

        self._server_notices_mxid = hs.config.servernotices.server_notices_mxid

        config = self.hs.config
        track_known_servers = (
            config.worker.run_background_tasks
            and config.metrics.metrics_flags.known_servers
        )
        if not track_known_servers:
            return

        # Start at 1: this server always knows about itself.
        self._known_servers_count = 1

        clock = self.hs.get_clock()
        clock.looping_call(self._count_known_servers, 60 * 1000)
        clock.call_later(1, self._count_known_servers)

        LaterGauge(
            "synapse_federation_known_servers",
            "",
            [],
            lambda: self._known_servers_count,
        )
 | 
						|
 | 
						|
    @wrap_as_background_process("_count_known_servers")
    async def _count_known_servers(self) -> int:
        """
        Count the servers that this server knows about.

        A server counts as known when at least one of its users has a joined
        `m.room.member` event in `current_state_events`. The server name is
        everything after the first ':' of the user ID, so a port suffix is kept.

        The statistic is stored on the class for the
        `synapse_federation_known_servers` LaterGauge to collect.

        Returns:
            The number of known servers, never less than 1.
        """

        def _transact(txn: LoggingTransaction) -> int:
            # Both engines apply the same definition: only *joined* members are
            # counted, and the server name is everything after the first ':'.
            # (Previously the SQLite query counted every membership state, and
            # Postgres' `split_part(state_key, ':', 2)` dropped any port.)
            if isinstance(self.database_engine, Sqlite3Engine):
                query = """
                    SELECT COUNT(DISTINCT substr(state_key, instr(state_key, ':') + 1))
                    FROM current_state_events
                    WHERE type = 'm.room.member' AND membership = 'join'
                """
            else:
                query = """
                    SELECT COUNT(DISTINCT
                        substring(state_key FROM position(':' IN state_key) + 1)
                    )
                    FROM current_state_events
                    WHERE type = 'm.room.member' AND membership = 'join'
                """
            txn.execute(query)
            # An aggregate without GROUP BY yields exactly one row.
            row = txn.fetchone()
            return row[0] if row else 0

        count = await self.db_pool.runInteraction("get_known_servers", _transact)

        # We always know about ourselves, even if we have no joined members in
        # current_state_events (for example, the server is new).
        self._known_servers_count = max(count, 1)
        return self._known_servers_count
 | 
						|
 | 
						|
    @cached(max_entries=100000, iterable=True)
    async def get_users_in_room(self, room_id: str) -> Sequence[str]:
        """List the user IDs currently joined to `room_id`.

        The answer comes from joined `m.room.member` rows in
        `current_state_events`. For rooms with partial state it can be wrong:
        the state at those rooms' forward extremities omits most members, and
        room state may be computed incorrectly, so a member can be reported in
        or out of the room when the opposite is true.

        If only users local to this homeserver matter, prefer
        `get_local_users_in_room(...)`, which is more performant.
        """
        joined_member_filter = {
            "type": EventTypes.Member,
            "room_id": room_id,
            "membership": Membership.JOIN,
        }
        return await self.db_pool.simple_select_onecol(
            table="current_state_events",
            keyvalues=joined_member_filter,
            retcol="state_key",
            desc="get_users_in_room",
        )
 | 
						|
 | 
						|
    def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
        """Uncached, in-transaction form of the `get_users_in_room` query.

        Args:
            txn: The transaction to run the query in.
            room_id: The room whose joined members to list.

        Returns:
            The state keys (user IDs) of joined `m.room.member` rows in the
            room's current state.
        """
        where = {
            "type": EventTypes.Member,
            "room_id": room_id,
            "membership": Membership.JOIN,
        }
        return self.db_pool.simple_select_onecol_txn(
            txn,
            table="current_state_events",
            keyvalues=where,
            retcol="state_key",
        )
 | 
						|
 | 
						|
    @cached()
    def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo:
        """Per-user cache placeholder; not meant to be called directly.

        `get_subset_users_in_room_with_profiles` names this method as its
        `cached_method_name`, so this method provides the per-(room, user) cache
        that the batched lookup populates. Invoking it raises.
        """
        raise NotImplementedError
 | 
						|
 | 
						|
    @cachedList(
        cached_method_name="get_user_in_room_with_profile", list_name="user_ids"
    )
    async def get_subset_users_in_room_with_profiles(
        self, room_id: str, user_ids: Collection[str]
    ) -> Mapping[str, ProfileInfo]:
        """Fetch room-specific profile info for some of a room's joined members.

        Display names and avatar URLs are read from this room's current
        `m.room.member` events, so they may differ from the users' global
        profiles. To avoid privacy leaks, only reveal the result to users who
        are already in this room.

        Args:
            room_id: The room whose members to look up.
            user_ids: The users to look up. Users who are not currently joined
                are absent from the result.

        Returns:
            A mapping from user ID to ProfileInfo, for joined users only.
        """

        def _fetch_profiles_txn(
            txn: LoggingTransaction,
        ) -> Dict[str, ProfileInfo]:
            in_clause, in_args = make_in_list_sql_clause(
                self.database_engine, "c.state_key", user_ids
            )

            sql = """
                SELECT state_key, display_name, avatar_url FROM room_memberships as m
                INNER JOIN current_state_events as c
                ON m.event_id = c.event_id
                AND m.room_id = c.room_id
                AND m.user_id = c.state_key
                WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? AND %s
            """ % (
                in_clause,
            )
            txn.execute(sql, (room_id, Membership.JOIN, *in_args))

            profiles: Dict[str, ProfileInfo] = {}
            for member_id, display_name, avatar_url in txn:
                profiles[member_id] = ProfileInfo(
                    display_name=display_name, avatar_url=avatar_url
                )
            return profiles

        return await self.db_pool.runInteraction(
            "get_subset_users_in_room_with_profiles",
            _fetch_profiles_txn,
        )
 | 
						|
 | 
						|
    @cached(max_entries=100000, iterable=True)
    async def get_users_in_room_with_profiles(
        self, room_id: str
    ) -> Mapping[str, ProfileInfo]:
        """Fetch room-specific profile info for every joined member of a room.

        Display names and avatar URLs are read from this room's current
        `m.room.member` events, so they may differ from the users' global
        profiles. To avoid privacy leaks, only reveal the result to users who
        are already in this room.

        Args:
            room_id: The room whose members to look up.

        Returns:
            A mapping from user ID to ProfileInfo.

        Preconditions:
          - There is full state available for the room (it is not partial-stated).
        """

        def _fetch_all_profiles_txn(
            txn: LoggingTransaction,
        ) -> Dict[str, ProfileInfo]:
            sql = """
                SELECT state_key, display_name, avatar_url FROM room_memberships as m
                INNER JOIN current_state_events as c
                ON m.event_id = c.event_id
                AND m.room_id = c.room_id
                AND m.user_id = c.state_key
                WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
            """
            txn.execute(sql, (room_id, Membership.JOIN))

            profiles: Dict[str, ProfileInfo] = {}
            for member_id, display_name, avatar_url in txn:
                profiles[member_id] = ProfileInfo(
                    display_name=display_name, avatar_url=avatar_url
                )
            return profiles

        return await self.db_pool.runInteraction(
            "get_users_in_room_with_profiles",
            _fetch_all_profiles_txn,
        )
 | 
						|
 | 
						|
    @cached(max_entries=100000)
    async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
        """Summarise a room's membership, roughly as needed by the /sync room
        summary extension. Useful when lazy loading room members.

        Args:
            room_id: The room ID to query.

        Returns:
            A dict keyed by membership state. Each MemberSummary holds the
            number of members in that state plus a list of (user_id, event_id)
            hero candidates. At most 6 candidates are fetched across all states,
            ordered joins first, then invites, then the rest, ties broken by
            event ID.
        """

        def _room_summary_txn(
            txn: LoggingTransaction,
        ) -> Dict[str, MemberSummary]:
            # Counts and heroes are fetched in a single transaction (to keep the
            # cache small, per the original note).
            # FIXME: get rid of this when we have room_stats

            # Rejected events have a NULL membership, so filter them out.
            count_sql = """
                SELECT count(*), membership FROM current_state_events
                WHERE type = 'm.room.member' AND room_id = ?
                    AND membership IS NOT NULL
                GROUP BY membership
            """
            txn.execute(count_sql, (room_id,))

            summaries: Dict[str, MemberSummary] = {}
            for member_count, membership in txn:
                summaries.setdefault(membership, MemberSummary([], member_count))

            # Order by membership, then (fairly arbitrarily) by event_id so the
            # chosen heroes are stable. Rejected events are excluded again.
            heroes_sql = """
                SELECT state_key, membership, event_id
                FROM current_state_events
                WHERE type = 'm.room.member' AND room_id = ?
                    AND membership IS NOT NULL
                ORDER BY
                    CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
                    event_id ASC
                LIMIT ?
            """
            # Five heroes, plus one spare in case the calling user is among them.
            hero_limit = 6
            txn.execute(
                heroes_sql,
                (room_id, Membership.JOIN, Membership.INVITE, hero_limit),
            )
            for member_id, membership, event_id in txn:
                # The counting pass already created an entry for every
                # membership value this query can return.
                summaries[membership].members.append((member_id, event_id))

            return summaries

        return await self.db_pool.runInteraction(
            "get_room_summary", _room_summary_txn
        )
 | 
						|
 | 
						|
    @cached()
    async def get_number_joined_users_in_room(self, room_id: str) -> int:
        """Count the joined members of `room_id` from `current_state_events`.

        Args:
            room_id: The room to count joined members of.

        Returns:
            The number of current-state rows for the room with membership "join".
        """
        # NOTE(review): there is no `type` filter here; this relies on only
        # membership events carrying a non-NULL membership — TODO confirm.
        joined_filter = {"room_id": room_id, "membership": Membership.JOIN}
        return await self.db_pool.simple_select_one_onecol(
            table="current_state_events",
            keyvalues=joined_filter,
            retcol="COUNT(*)",
            desc="get_number_joined_users_in_room",
        )
 | 
						|
 | 
						|
    @cached()
    async def get_invited_rooms_for_local_user(
        self, user_id: str
    ) -> Sequence[RoomsForUser]:
        """List the rooms a *local* user is currently invited to.

        Args:
            user_id: The local user's ID.

        Returns:
            One RoomsForUser per room where the user's membership is "invite".
        """
        invite_only = [Membership.INVITE]
        return await self.get_rooms_for_local_user_where_membership_is(
            user_id, invite_only
        )
 | 
						|
 | 
						|
    async def get_invite_for_local_user_in_room(
        self, user_id: str, room_id: str
    ) -> Optional[RoomsForUser]:
        """Look up a *local* user's invite to one specific room.

        Args:
            user_id: The local user whose invite to find.
            room_id: The room the user was invited to.

        Returns:
            The first matching RoomsForUser, or None if no invite was found.
        """
        pending_invites = await self.get_invited_rooms_for_local_user(user_id)
        return next(
            (invite for invite in pending_invites if invite.room_id == room_id),
            None,
        )
 | 
						|
 | 
						|
    async def get_rooms_for_local_user_where_membership_is(
        self,
        user_id: str,
        membership_list: Collection[str],
        excluded_rooms: StrCollection = (),
    ) -> List[RoomsForUser]:
        """Get all the rooms for this *local* user where the membership for this user
        matches one in the membership list.

        Filters out forgotten rooms.

        Args:
            user_id: The user ID.
            membership_list: A list of synapse.api.constants.Membership
                values which the user must be in.
            excluded_rooms: A list of rooms to ignore.

        Returns:
            The RoomsForUser that the user matches the membership types.
        """
        if not membership_list:
            return []

        rooms = await self.db_pool.runInteraction(
            "get_rooms_for_local_user_where_membership_is",
            self._get_rooms_for_local_user_where_membership_is_txn,
            user_id,
            membership_list,
        )

        # Now we filter out forgotten and excluded rooms
        rooms_to_exclude: AbstractSet[str] = set()

        # Users can't forget joined/invited rooms, so we skip the check for such look ups.
        if any(m not in (Membership.JOIN, Membership.INVITE) for m in membership_list):
            rooms_to_exclude = await self.get_forgotten_rooms_for_user(user_id)

        # Only copy-and-merge when there is actually something to exclude. The
        # previous `is not None` test was always true for the default `()`, so
        # every call needlessly copied the forgotten-rooms set.
        if excluded_rooms:
            # Take a copy to avoid mutating the in-cache set
            rooms_to_exclude = set(rooms_to_exclude)
            rooms_to_exclude.update(excluded_rooms)

        return [room for room in rooms if room.room_id not in rooms_to_exclude]
 | 
						|
 | 
						|
    def _get_rooms_for_local_user_where_membership_is_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        membership_list: Collection[str],
    ) -> List[RoomsForUser]:
        """Get all the rooms for this *local* user where the membership for this user
        matches one in the membership list.

        Args:
            txn: The transaction to run the query in.
            user_id: The user ID.
            membership_list: A collection of synapse.api.constants.Membership
                values which the user must be in. (Widened from `List[str]`
                to match what `get_rooms_for_local_user_where_membership_is`
                passes in.)

        Returns:
            The RoomsForUser that the user matches the membership types.

        Raises:
            Exception: if `user_id` does not belong to this homeserver.
        """
        # Paranoia check.
        if not self.hs.is_mine_id(user_id):
            raise Exception(
                "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
                % (user_id,),
            )

        clause, args = make_in_list_sql_clause(
            self.database_engine, "c.membership", membership_list
        )

        sql = """
            SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version
            FROM local_current_membership AS c
            INNER JOIN events AS e USING (room_id, event_id)
            INNER JOIN rooms AS r USING (room_id)
            WHERE
                user_id = ?
                AND %s
        """ % (
            clause,
        )

        txn.execute(sql, (user_id, *args))
        # The SELECT column order must match RoomsForUser's positional fields.
        return [RoomsForUser(*r) for r in txn]
 | 
						|
 | 
						|
    @cached(iterable=True)
    async def get_local_users_in_room(self, room_id: str) -> Sequence[str]:
        """List the current members of `room_id` who are local to this server.

        Reads joined rows from `local_current_membership`.
        """
        joined_in_room = {"room_id": room_id, "membership": Membership.JOIN}
        return await self.db_pool.simple_select_onecol(
            table="local_current_membership",
            keyvalues=joined_in_room,
            retcol="user_id",
            desc="get_local_users_in_room",
        )
 | 
						|
 | 
						|
    async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
        """
        Check whether a given local user is currently joined to the given room.

        Args:
            user_id: A user ID belonging to this homeserver.
            room_id: The room to check.

        Returns:
            A boolean indicating whether the user is currently joined to the room

        Raises:
            Exception: when called with a user who is not local to this homeserver
        """
        if not self.hs.is_mine_id(user_id):
            raise Exception(
                "Cannot call 'check_local_user_in_room' on "
                "non-local user %s" % (user_id,),
            )

        # Only the membership is needed; the membership event ID is discarded.
        membership, _ = await self.get_local_current_membership_for_user_in_room(
            user_id=user_id,
            room_id=room_id,
        )

        return membership == Membership.JOIN
 | 
						|
 | 
						|
    async def is_server_notice_room(self, room_id: str) -> bool:
        """Tell whether `room_id` is a 'Server Notices' room.

        Such rooms are used to send server notices to a user. A room qualifies
        when the configured server notices user is currently joined to it. If no
        server notices user is configured, no room qualifies.
        """
        notices_user = self._server_notices_mxid
        if notices_user is None:
            return False
        return await self.check_local_user_in_room(
            user_id=notices_user, room_id=room_id
        )
 | 
						|
 | 
						|
    async def get_local_current_membership_for_user_in_room(
        self, user_id: str, room_id: str
    ) -> Tuple[Optional[str], Optional[str]]:
        """Fetch a local user's current membership and membership event ID.

        Args:
            user_id: The ID of the user; must belong to this homeserver.
            room_id: The ID of the room.

        Returns:
            A (membership_type, event_id) tuple. Both elements are None when
            `local_current_membership` has no row for this room_id/user_id pair.

        Raises:
            Exception: if `user_id` is not local to this homeserver.
        """
        # Paranoia check.
        if not self.hs.is_mine_id(user_id):
            raise Exception(
                "Cannot call 'get_local_current_membership_for_user_in_room' on "
                "non-local user %s" % (user_id,),
            )

        row = await self.db_pool.simple_select_one(
            "local_current_membership",
            {"room_id": room_id, "user_id": user_id},
            ("membership", "event_id"),
            allow_none=True,
            desc="get_local_current_membership_for_user_in_room",
        )
        if not row:
            return None, None
        return row.get("membership"), row.get("event_id")
 | 
						|
 | 
						|
    @cached(max_entries=500000, iterable=True)
    async def get_rooms_for_user_with_stream_ordering(
        self, user_id: str
    ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
        """Returns the rooms the user is currently joined to, with the position
        of their current join event.

        For a remote user, only rooms this server is currently participating in
        are returned.

        Args:
            user_id: The user to look up (local or remote).

        Returns:
            A frozenset of GetRoomsForUserWithStreamOrdering, one per joined
            room, each holding the room ID and the PersistedEventPosition
            (instance name and stream ordering) of the user's current
            `m.room.member` join event in that room. The room version is NOT
            included.
        """
        return await self.db_pool.runInteraction(
            "get_rooms_for_user_with_stream_ordering",
            self._get_rooms_for_user_with_stream_ordering_txn,
            user_id,
        )
 | 
						|
 | 
						|
    def _get_rooms_for_user_with_stream_ordering_txn(
        self, txn: LoggingTransaction, user_id: str
    ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
        """Transaction body for `get_rooms_for_user_with_stream_ordering`.

        Pairs each room the user is joined to with the position (instance name,
        stream ordering) of the join event in current state.
        """
        # We use `current_state_events` here and not `local_current_membership`
        # as a) this gets called with remote users and b) this only gets called
        # for rooms the server is participating in.
        sql = """
            SELECT room_id, e.instance_name, e.stream_ordering
            FROM current_state_events AS c
            INNER JOIN events AS e USING (room_id, event_id)
            WHERE
                c.type = 'm.room.member'
                AND c.state_key = ?
                AND c.membership = ?
        """
        txn.execute(sql, (user_id, Membership.JOIN))

        joined: Set[GetRoomsForUserWithStreamOrdering] = set()
        for room_id, instance_name, stream_ordering in txn:
            position = PersistedEventPosition(instance_name, stream_ordering)
            joined.add(GetRoomsForUserWithStreamOrdering(room_id, position))
        return frozenset(joined)
 | 
						|
 | 
						|
    async def get_users_server_still_shares_room_with(
        self, user_ids: Collection[str]
    ) -> Set[str]:
        """Given a list of users, return the subset this server still shares a
        room with.

        An empty input short-circuits to an empty set without touching the
        database.
        """
        if user_ids:
            return await self.db_pool.runInteraction(
                "get_users_server_still_shares_room_with",
                self.get_users_server_still_shares_room_with_txn,
                user_ids,
            )
        return set()
 | 
						|
 | 
						|
    def get_users_server_still_shares_room_with_txn(
        self,
        txn: LoggingTransaction,
        user_ids: Collection[str],
    ) -> Set[str]:
        """Transaction body of `get_users_server_still_shares_room_with`.

        Returns the members of `user_ids` that have a joined `m.room.member`
        row in `current_state_events` for at least one room.
        """
        if not user_ids:
            return set()

        in_clause, in_args = make_in_list_sql_clause(
            self.database_engine, "state_key", user_ids
        )
        sql = """
            SELECT state_key FROM current_state_events
            WHERE
                type = 'm.room.member'
                AND membership = 'join'
                AND %s
            GROUP BY state_key
        """
        txn.execute(sql % (in_clause,), in_args)

        return {state_key for (state_key,) in txn}
 | 
						|
 | 
						|
    @cached(max_entries=500000, iterable=True)
    async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]:
        """Returns a set of room_ids the user is currently joined to.

        If a remote user only returns rooms this server is currently
        participating in.

        When `get_rooms_for_user_with_stream_ordering` already holds a cache
        entry for this user, the result is derived from it without querying the
        database.
        """
        rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate(
            (user_id,),
            None,
            update_metrics=False,
        )
        # Compare against the `None` miss sentinel rather than truthiness: a
        # cached empty frozenset is a valid "joined to no rooms" answer and
        # should not trigger a redundant database query.
        if rooms is not None:
            return frozenset(r.room_id for r in rooms)

        room_ids = await self.db_pool.simple_select_onecol(
            table="current_state_events",
            keyvalues={
                "type": EventTypes.Member,
                "membership": Membership.JOIN,
                "state_key": user_id,
            },
            retcol="room_id",
            desc="get_rooms_for_user",
        )

        return frozenset(room_ids)
 | 
						|
 | 
						|
    @cachedList(
        cached_method_name="get_rooms_for_user",
        list_name="user_ids",
    )
    async def _get_rooms_for_users(
        self, user_ids: Collection[str]
    ) -> Mapping[str, FrozenSet[str]]:
        """Batched form of `get_rooms_for_user`.

        Args:
            user_ids: The users to look up.

        Returns:
            A map from every requested user ID to the frozenset of room IDs the
            user is joined to according to `current_state_events`. Users in no
            rooms map to an empty frozenset.
        """
        rows = await self.db_pool.simple_select_many_batch(
            table="current_state_events",
            column="state_key",
            iterable=user_ids,
            retcols=("state_key", "room_id"),
            keyvalues={
                "type": EventTypes.Member,
                "membership": Membership.JOIN,
            },
            desc="get_rooms_for_users",
        )

        # Seed every requested user so that users in no rooms are still present.
        rooms_by_user: Dict[str, Set[str]] = {}
        for requested_user in user_ids:
            rooms_by_user[requested_user] = set()

        for row in rows:
            rooms_by_user[row["state_key"]].add(row["room_id"])

        return {
            member: frozenset(room_ids) for member, room_ids in rooms_by_user.items()
        }
 | 
						|
 | 
						|
    async def get_rooms_for_users(
        self, user_ids: Collection[str]
    ) -> Dict[str, FrozenSet[str]]:
        """Look up the joined rooms of many users, a chunk at a time.

        Wraps `_get_rooms_for_users` so that a very large user list does not
        lock out other calls to `get_rooms_for_user` while it is processed.
        """
        # The chunk size is arbitrary. Chunks are kept modest because users can
        # be joined to a great many rooms, making each chunk's result large.
        chunk_size = 250

        rooms_by_user: Dict[str, FrozenSet[str]] = {}
        for chunk in batch_iter(user_ids, chunk_size):
            chunk_result = await self._get_rooms_for_users(chunk)
            rooms_by_user.update(chunk_result)
        return rooms_by_user
 | 
						|
 | 
						|
    @cached(max_entries=10000)
    async def does_pair_of_users_share_a_room(
        self, user_id: str, other_user_id: str
    ) -> bool:
        """Cache-only entry point; its body is a stub that always raises.

        Entries in this cache are presumably populated by
        `_do_users_share_a_room`, which is decorated with `@cachedList` naming
        this method as its `cached_method_name`.
        """
        raise NotImplementedError()
 | 
						|
 | 
						|
    @cachedList(
        cached_method_name="does_pair_of_users_share_a_room", list_name="other_user_ids"
    )
    async def _do_users_share_a_room(
        self, user_id: str, other_user_ids: Collection[str]
    ) -> Mapping[str, Optional[bool]]:
        """Determine which of `other_user_ids` share a joined room with `user_id`.

        Only users that do share a room appear in the returned map (mapped to
        `True`). A value of `None` or `False` likewise means "no shared room".
        """

        def _shared_with_txn(
            txn: LoggingTransaction, candidate_ids: Collection[str]
        ) -> Dict[str, bool]:
            in_clause, in_args = make_in_list_sql_clause(
                self.database_engine, "state_key", candidate_ids
            )

            # Join the target user's joined rooms against the joined rooms of
            # the candidate users. Any candidate surviving the join shares at
            # least one room with the target.
            sql = f"""
                SELECT DISTINCT b.state_key
                FROM (
                    SELECT room_id FROM current_state_events
                    WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
                ) AS a
                INNER JOIN (
                    SELECT room_id, state_key FROM current_state_events
                    WHERE type = 'm.room.member' AND membership = 'join' AND {in_clause}
                ) AS b using (room_id)
            """

            txn.execute(sql, (user_id, *in_args))
            return dict.fromkeys((row[0] for row in txn), True)

        shared: Dict[str, bool] = {}
        # Query in chunks of 1000 candidates to keep the IN-list bounded.
        for chunk in batch_iter(other_user_ids, 1000):
            shared.update(
                await self.db_pool.runInteraction(
                    "do_users_share_a_room", _shared_with_txn, chunk
                )
            )
        return shared
 | 
						|
 | 
						|
    async def do_users_share_a_room(
        self, user_id: str, other_user_ids: Collection[str]
    ) -> Set[str]:
        """Return the subset of `other_user_ids` that share a room with `user_id`."""
        sharing = await self._do_users_share_a_room(user_id, other_user_ids)

        result: Set[str] = set()
        for other_id, shares in sharing.items():
            # `None` and `False` both mean "no shared room".
            if shares:
                result.add(other_id)
        return result
 | 
						|
 | 
						|
    async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]:
        """Return every user joined to at least one room that `user_id` is in.

        This is the union of `get_users_in_room` over each room returned by
        `get_rooms_for_user(user_id)`.
        """
        co_members: Set[str] = set()
        for joined_room in await self.get_rooms_for_user(user_id):
            co_members.update(await self.get_users_in_room(joined_room))
        return co_members
 | 
						|
 | 
						|
    @cached(cache_context=True, iterable=True)
    async def get_mutual_rooms_between_users(
        self, user_ids: FrozenSet[str], cache_context: _CacheContext
    ) -> FrozenSet[str]:
        """
        Get the rooms that every user in `user_ids` is joined to.

        Args:
            user_ids: frozen set of users whose joined rooms are intersected.
            cache_context: its `invalidate` callback is registered on each
              per-user `get_rooms_for_user` lookup, so this cached result is
              dropped when any of those entries is invalidated.

        Returns:
            The intersection of the users' joined rooms. The result is empty
            when `user_ids` is empty.
        """
        intersection: Optional[FrozenSet[str]] = None
        for uid in user_ids:
            joined = await self.get_rooms_for_user(
                uid, on_invalidate=cache_context.invalidate
            )
            intersection = joined if intersection is None else intersection & joined
        return intersection or frozenset()
 | 
						|
 | 
						|
    async def get_joined_user_ids_from_state(
        self, room_id: str, state: StateMap[str]
    ) -> Set[str]:
        """
        Work out which users are joined, given a map of state event IDs.

        Membership events already in the local event cache are inspected
        directly. Events absent from the cache, or cached but rejected, are
        resolved through `_get_user_ids_from_membership_event_ids`.

        Args:
            room_id: the room the state belongs to (not used by the lookup
              itself).
            state: map from state key tuple (event type first) to event ID.

        Returns:
            The user IDs with a join membership among the `m.room.member`
            entries of `state`.
        """
        with Measure(self._clock, "get_joined_user_ids_from_state"):
            joined: Set[str] = set()
            membership_event_ids = [
                event_id
                for state_key_tuple, event_id in state.items()
                if state_key_tuple[0] == EventTypes.Member
            ]

            # Look in the in-memory event cache before going to the database.
            # Metrics are not updated: this path never fills the cache on a
            # miss, so counting these lookups would distort the hit ratio.
            cached_entries = self._get_events_from_local_cache(
                membership_event_ids, update_metrics=False
            )

            uncached_event_ids = []
            for event_id in membership_event_ids:
                entry = cached_entries.get(event_id)
                if not entry or entry.event.rejected_reason:
                    uncached_event_ids.append(event_id)
                elif entry.event.membership == Membership.JOIN:
                    joined.add(entry.event.state_key)

            if uncached_event_ids:
                user_by_event = await self._get_user_ids_from_membership_event_ids(
                    uncached_event_ids
                )
                joined.update(uid for uid in user_by_event.values() if uid)

            return joined
 | 
						|
 | 
						|
    @cached(
        max_entries=10000,
        # This name matches the old function that has been replaced - the cache name
        # is kept here to maintain backwards compatibility.
        name="_get_joined_profile_from_event_id",
    )
    def _get_user_id_from_membership_event_id(
        self, event_id: str
    ) -> Optional[str]:
        """Cache-only entry point; its body is a stub that always raises.

        Entries are filled by `_get_user_ids_from_membership_event_ids` (the
        `@cachedList` naming this method), which maps each event ID to a
        `user_id` string. The return annotation therefore reflects that value
        type.
        """
        raise NotImplementedError()
 | 
						|
 | 
						|
    @cachedList(
        cached_method_name="_get_user_id_from_membership_event_id",
        list_name="event_ids",
    )
    async def _get_user_ids_from_membership_event_ids(
        self, event_ids: Iterable[str]
    ) -> Mapping[str, Optional[str]]:
        """Resolve membership event IDs to the user IDs they join.

        Args:
            event_ids: IDs of the membership events to look up.

        Returns:
            Map from event ID to the `user_id` of that membership. Only join
            events are included. Other event IDs are omitted from the map and
            are presumably reported as `None` by the `@cachedList` wrapper.
        """
        join_rows = await self.db_pool.simple_select_many_batch(
            table="room_memberships",
            column="event_id",
            iterable=event_ids,
            retcols=("user_id", "event_id"),
            keyvalues={"membership": Membership.JOIN},
            batch_size=1000,
            desc="_get_user_ids_from_membership_event_ids",
        )

        user_by_event: Dict[str, Optional[str]] = {}
        for join_row in join_rows:
            user_by_event[join_row["event_id"]] = join_row["user_id"]
        return user_by_event
 | 
						|
 | 
						|
    @cached(max_entries=10000)
    async def is_host_joined(self, room_id: str, host: str) -> bool:
        """Whether any user on `host` is joined to `room_id` in its current state."""
        return await self._check_host_room_membership(
            room_id=room_id, host=host, membership=Membership.JOIN
        )
 | 
						|
 | 
						|
    @cached(max_entries=10000)
    async def is_host_invited(self, room_id: str, host: str) -> bool:
        """Whether any user on `host` is invited to `room_id` in its current state."""
        return await self._check_host_room_membership(
            room_id=room_id, host=host, membership=Membership.INVITE
        )
 | 
						|
 | 
						|
    async def _check_host_room_membership(
        self, room_id: str, host: str, membership: str
    ) -> bool:
        """Check whether some user on `host` has `membership` in the room's
        current state.

        Args:
            room_id: the room to inspect.
            host: server name matched against the part of member state keys
              after the colon.
            membership: the membership value to look for.

        Returns:
            True if at least one matching membership row exists.

        Raises:
            Exception: if `host` contains a LIKE wildcard (`%` or `_`), or if
              the matched user's domain turns out not to equal `host`.
        """
        # `host` is spliced into a LIKE pattern, so the SQL wildcard characters
        # are rejected up front. The domain of the returned row is re-checked
        # below as a second safeguard.
        if any(wildcard in host for wildcard in ("%", "_")):
            raise Exception("Invalid host name")

        sql = """
            SELECT state_key FROM current_state_events
            WHERE membership = ?
                AND type = 'm.room.member'
                AND room_id = ?
                AND state_key LIKE ?
            LIMIT 1
        """

        rows = await self.db_pool.execute(
            "is_host_joined", None, sql, membership, room_id, "%:" + host
        )
        if not rows:
            return False

        matched_user_id = rows[0][0]
        if get_domain_from_id(matched_user_id) != host:
            # Only reachable when the host name contains something unusual.
            raise Exception("Invalid host name")

        return True
 | 
						|
 | 
						|
    @cached(iterable=True, max_entries=10000)
    async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
        """Get current hosts in room based on current state.

        Returns:
            The set of server names that have at least one joined member in
            the room's current state. On PostgreSQL, members whose state key
            has no domain part (malformed MXIDs) are skipped rather than
            contributing a `None` entry.
        """

        # First we check if we already have `get_users_in_room` in the cache, as
        # we can just calculate result from that
        users = self.get_users_in_room.cache.get_immediate(
            (room_id,), None, update_metrics=False
        )
        if users is not None:
            return {get_domain_from_id(u) for u in users}

        if isinstance(self.database_engine, Sqlite3Engine):
            # If we're using SQLite then let's just always use
            # `get_users_in_room` rather than funky SQL.
            users = await self.get_users_in_room(room_id)
            return {get_domain_from_id(u) for u in users}

        # For PostgreSQL we can use a regex to pull out the domains from the
        # joined users in `current_state_events`.

        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
            sql = """
                SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
                FROM current_state_events
                WHERE
                    type = 'm.room.member'
                    AND membership = 'join'
                    AND room_id = ?
            """
            txn.execute(sql, (room_id,))
            # `substring` yields NULL when the regex doesn't match (malformed
            # MXIDs with no colon). Drop those rather than returning `None` as
            # a host, consistent with `get_current_hosts_in_room_ordered`.
            return {d for d, in txn if d is not None}

        return await self.db_pool.runInteraction(
            "get_current_hosts_in_room", get_current_hosts_in_room_txn
        )
 | 
						|
 | 
						|
    @cached(iterable=True, max_entries=10000)
    async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
        """
        Get the servers currently in the room, according to current state.

        Membership is taken from the `m.room.member` events in the room state
        at the current forward extremities. Servers are ordered by how long
        they have been in the room, longest first (i.e. by the lowest depth of
        their members' join events). Such servers are the most likely to have
        whatever we might ask them about.

        On SQLite the result is not ordered, because SQLite lacks the SQL
        needed for the ordering.

        The result is inaccurate for rooms with partial state. The state at
        their forward extremities omits most members, and the computed state
        may wrongly include or exclude a host.

        Returns:
            Server names, longest-in-the-room first (unordered on SQLite).
        """
        if isinstance(self.database_engine, Sqlite3Engine):
            # No ordering query on SQLite: reuse the unordered set of hosts.
            return tuple(await self.get_current_hosts_in_room(room_id))

        # PostgreSQL: extract each joined member's domain with a regex, group
        # by domain and order by the shallowest join event per domain.
        def _ordered_hosts_txn(
            txn: LoggingTransaction,
        ) -> Tuple[str, ...]:
            sql = """
                SELECT
                    /* Match the domain part of the MXID */
                    substring(c.state_key FROM '@[^:]*:(.*)$') as server_domain
                FROM current_state_events c
                /* Get the depth of the event from the events table */
                INNER JOIN events AS e USING (event_id)
                WHERE
                    /* Find any join state events in the room */
                    c.type = 'm.room.member'
                    AND c.membership = 'join'
                    AND c.room_id = ?
                /* Group all state events from the same domain into their own buckets (groups) */
                GROUP BY server_domain
                /* Sorted by lowest depth first */
                ORDER BY min(e.depth) ASC;
            """
            txn.execute(sql, (room_id,))
            # A malformed MXID with no colon yields a NULL `server_domain`.
            # Skip it.
            return tuple(domain for (domain,) in txn if domain is not None)

        return await self.db_pool.runInteraction(
            "get_current_hosts_in_room_ordered", _ordered_hosts_txn
        )
 | 
						|
 | 
						|
    async def _get_approximate_current_memberships_in_room(
 | 
						|
        self, room_id: str
 | 
						|
    ) -> Mapping[str, Optional[str]]:
 | 
						|
        """Build a map from event id to membership, for all events in the current state.
 | 
						|
 | 
						|
        The event ids of non-memberships events (e.g. `m.room.power_levels`) are present
 | 
						|
        in the result, mapped to values of `None`.
 | 
						|
 | 
						|
        The result is approximate for partially-joined rooms. It is fully accurate
 | 
						|
        for fully-joined rooms.
 | 
						|
        """
 | 
						|
 | 
						|
        rows = await self.db_pool.simple_select_list(
 | 
						|
            "current_state_events",
 | 
						|
            keyvalues={"room_id": room_id},
 | 
						|
            retcols=("event_id", "membership"),
 | 
						|
            desc="has_completed_background_updates",
 | 
						|
        )
 | 
						|
        return {row["event_id"]: row["membership"] for row in rows}
 | 
						|
 | 
						|
    @cached(max_entries=10000)
    def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
        """Create an empty `_JoinedHostsCache`.

        The `@cached` decorator memoises the returned object by `room_id`
        (up to 10000 rooms). The body itself never reads `room_id`.
        """
        fresh_cache = _JoinedHostsCache()
        return fresh_cache
 | 
						|
 | 
						|
    @cached(num_args=2)
    async def did_forget(self, user_id: str, room_id: str) -> bool:
        """Whether `user_id` has chosen to discard their history for `room_id`.

        True only when none of the user's `room_memberships` rows for the room
        still has `forgotten = 0`. In particular, it is False if they have
        since re-joined.
        """

        def _count_unforgotten_txn(txn: LoggingTransaction) -> int:
            sql = (
                "SELECT  COUNT(*) FROM  room_memberships"
                " WHERE  user_id = ? AND  room_id = ? AND  forgotten = 0"
            )
            txn.execute(sql, (user_id, room_id))
            return txn.fetchall()[0][0]

        unforgotten = await self.db_pool.runInteraction(
            "did_forget_membership", _count_unforgotten_txn
        )
        return unforgotten == 0
 | 
						|
 | 
						|
    @cached()
    async def get_forgotten_rooms_for_user(self, user_id: str) -> AbstractSet[str]:
        """Gets all rooms the user has forgotten.

        A room counts as forgotten only if *every* membership row for the
        user and room has `forgotten` set. Forgetting a room updates all of
        the user's rows for it, not just the current one.

        Args:
            user_id: The user ID to query the rooms of.

        Returns:
            The forgotten rooms.
        """

        def _forgotten_rooms_txn(txn: LoggingTransaction) -> Set[str]:
            # First find candidate rooms via rows with `forgotten = 1`, which
            # lets a partial index on that condition be used (few users forget
            # many rooms). Then, per candidate, count any rows still
            # unforgotten.
            sql = """
                SELECT room_id, (
                    SELECT count(*) FROM room_memberships
                    WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
                ) AS count
                FROM room_memberships AS m
                WHERE user_id = ? AND forgotten = 1
                GROUP BY room_id, user_id;
            """
            txn.execute(sql, (user_id,))
            return {
                forgotten_room_id
                for forgotten_room_id, unforgotten_count in txn
                if unforgotten_count == 0
            }

        return await self.db_pool.runInteraction(
            "get_forgotten_rooms_for_user", _forgotten_rooms_txn
        )
 | 
						|
 | 
						|
    async def is_locally_forgotten_room(self, room_id: str) -> bool:
        """Whether every local user has forgotten `room_id`.

        Args:
            room_id: The room ID to query.

        Returns:
            True if no local current membership for the room is backed by an
            unforgotten `room_memberships` row.
        """
        sql = """
            SELECT count(*) > 0 FROM local_current_membership
            INNER JOIN room_memberships USING (room_id, event_id)
            WHERE
                room_id = ?
                AND forgotten = 0;
        """

        rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)

        # An aggregate query always yields exactly one row. A truthy value
        # means some local user still remembers the room.
        still_remembered = rows[0][0]
        return not still_remembered
 | 
						|
 | 
						|
    async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
        """Get all rooms that `user_id` has ever joined.

        The lookup reads join rows from `room_memberships` rather than current
        state, so it presumably also includes rooms the user has since left.

        Args:
            user_id: The user ID to get the rooms of.

        Returns:
            Set of room IDs.
        """
        return set(
            await self.db_pool.simple_select_onecol(
                table="room_memberships",
                keyvalues={"membership": Membership.JOIN, "user_id": user_id},
                retcol="room_id",
                desc="get_rooms_user_has_been_in",
            )
        )
 | 
						|
 | 
						|
    @cached(max_entries=5000)
    async def _get_membership_from_event_id(
        self, member_event_id: str
    ) -> Optional[EventIdMembership]:
        """Cache-only entry point; its body is a stub that always raises.

        Entries in this cache are presumably populated by
        `get_membership_from_event_ids`, the `@cachedList` that names this
        method as its `cached_method_name`.
        """
        raise NotImplementedError()
 | 
						|
 | 
						|
    @cachedList(
        cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
    )
    async def get_membership_from_event_ids(
        self, member_event_ids: Iterable[str]
    ) -> Mapping[str, Optional[EventIdMembership]]:
        """Look up the user ID and membership for each of `member_event_ids`.

        Returns:
            Map from event ID to an `EventIdMembership` for every event found
            in `room_memberships`. Event IDs that are not membership events are
            omitted here and presumably surface as `None` via the
            `@cachedList` wrapper.
        """
        membership_rows = await self.db_pool.simple_select_many_batch(
            table="room_memberships",
            column="event_id",
            iterable=member_event_ids,
            retcols=("user_id", "membership", "event_id"),
            keyvalues={},
            batch_size=500,
            desc="get_membership_from_event_ids",
        )

        memberships: Dict[str, Optional[EventIdMembership]] = {}
        for membership_row in membership_rows:
            memberships[membership_row["event_id"]] = EventIdMembership(
                membership=membership_row["membership"],
                user_id=membership_row["user_id"],
            )
        return memberships
 | 
						|
 | 
						|
    async def is_local_host_in_room_ignoring_users(
        self, room_id: str, ignore_users: Collection[str]
    ) -> bool:
        """Check whether any local user other than those in `ignore_users` is
        joined to `room_id`, according to `local_current_membership`.
        """
        ignore_clause, ignore_args = make_in_list_sql_clause(
            self.database_engine, "user_id", ignore_users
        )

        sql = """
            SELECT 1 FROM local_current_membership
            WHERE
                room_id = ? AND membership = ?
                AND NOT (%s)
                LIMIT 1
        """ % (
            ignore_clause,
        )
        params = (room_id, Membership.JOIN, *ignore_args)

        def _any_other_local_member_txn(
            txn: LoggingTransaction,
        ) -> bool:
            txn.execute(sql, params)
            # Any returned row means such a user exists.
            return bool(txn.fetchone())

        return await self.db_pool.runInteraction(
            "is_local_host_in_room_ignoring_users",
            _any_other_local_member_txn,
        )
 | 
						|
 | 
						|
    async def forget(self, user_id: str, room_id: str) -> None:
 | 
						|
        """Indicate that user_id wishes to discard history for room_id."""
 | 
						|
 | 
						|
        def f(txn: LoggingTransaction) -> None:
 | 
						|
            self.db_pool.simple_update_txn(
 | 
						|
                txn,
 | 
						|
                table="room_memberships",
 | 
						|
                keyvalues={"user_id": user_id, "room_id": room_id},
 | 
						|
                updatevalues={"forgotten": 1},
 | 
						|
            )
 | 
						|
 | 
						|
            self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
 | 
						|
            self._invalidate_cache_and_stream(
 | 
						|
                txn, self.get_forgotten_rooms_for_user, (user_id,)
 | 
						|
            )
 | 
						|
 | 
						|
        await self.db_pool.runInteraction("forget_membership", f)
 | 
						|
 | 
						|
    async def get_room_forgetter_stream_pos(self) -> int:
        """Read the stream position reached by the background process that
        forgets rooms once users have left them.
        """
        stream_pos = await self.db_pool.simple_select_one_onecol(
            table="room_forgetter_stream_pos",
            keyvalues={},
            retcol="stream_id",
            desc="room_forgetter_stream_pos",
        )
        return stream_pos
 | 
						|
 | 
						|
    async def update_room_forgetter_stream_pos(self, stream_id: int) -> None:
        """Store `stream_id` as the stream position of the background process
        that forgets rooms once users have left them.

        Must only be used by the worker running the background process. This
        is asserted via `run_background_tasks`.
        """
        assert self.hs.config.worker.run_background_tasks

        new_values = {"stream_id": stream_id}
        await self.db_pool.simple_update_one(
            table="room_forgetter_stream_pos",
            keyvalues={},
            updatevalues=new_values,
            desc="room_forgetter_stream_pos",
        )
 | 
						|
 | 
						|
 | 
						|
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
    """Registers and implements the background updates for room membership data:
    backfilling profile columns on `room_memberships`, populating the `membership`
    column of `current_state_events`, and building `room_memberships` indexes.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)
        self.db_pool.updates.register_background_update_handler(
            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
        )
        self.db_pool.updates.register_background_update_handler(
            _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
            self._background_current_state_membership,
        )
        # Partial index: only covers memberships that have been forgotten.
        self.db_pool.updates.register_background_index_update(
            "room_membership_forgotten_idx",
            index_name="room_memberships_user_room_forgotten",
            table="room_memberships",
            columns=["user_id", "room_id"],
            where_clause="forgotten = 1",
        )
        self.db_pool.updates.register_background_index_update(
            "room_membership_user_room_index",
            index_name="room_membership_user_room_idx",
            table="room_memberships",
            columns=["user_id", "room_id"],
        )

    async def _background_add_membership_profile(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Backfill `display_name` and `avatar_url` on `room_memberships` from the
        content of the corresponding `m.room.member` events.

        Events are processed in descending `stream_ordering`. Each batch walks down
        from `max_stream_id_exclusive` towards `target_min_stream_id_inclusive`.

        Args:
            progress: the stored progress of this background update.
            batch_size: the maximum number of events to examine in this batch.

        Returns:
            The number of events examined. When this is 0, the background update has
            been marked as finished.
        """
        target_min_stream_id = progress.get(
            "target_min_stream_id_inclusive", self._min_stream_order_on_start  # type: ignore[attr-defined]
        )
        max_stream_id = progress.get(
            "max_stream_id_exclusive", self._stream_order_on_start + 1  # type: ignore[attr-defined]
        )

        def add_membership_profile_txn(txn: LoggingTransaction) -> int:
            sql = """
                SELECT stream_ordering, event_id, events.room_id, event_json.json
                FROM events
                INNER JOIN event_json USING (event_id)
                WHERE ? <= stream_ordering AND stream_ordering < ?
                AND type = 'm.room.member'
                ORDER BY stream_ordering DESC
                LIMIT ?
            """
            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

            rows = self.db_pool.cursor_to_dict(txn)
            if not rows:
                return 0

            # Rows are ordered descending, so the last row has the lowest stream
            # ordering; the next batch resumes strictly below it.
            min_stream_id = rows[-1]["stream_ordering"]

            to_update = []
            for row in rows:
                event_id = row["event_id"]
                room_id = row["room_id"]
                try:
                    event_json = db_to_json(row["json"])
                    content = event_json["content"]
                except Exception:
                    # Best-effort: skip events whose JSON is unusable.
                    continue

                # Malformed events may carry non-object content. Calling `.get` on
                # it would raise, abort this transaction, and leave the update
                # retrying the same batch forever, so skip such events instead.
                if not isinstance(content, dict):
                    continue

                display_name = content.get("displayname", None)
                avatar_url = content.get("avatar_url", None)

                if display_name or avatar_url:
                    to_update.append((display_name, avatar_url, event_id, room_id))

            to_update_sql = """
                UPDATE room_memberships SET display_name = ?, avatar_url = ?
                WHERE event_id = ? AND room_id = ?
            """
            txn.execute_batch(to_update_sql, to_update)

            # Named distinctly to avoid shadowing the outer `progress` argument.
            new_progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
            }
            self.db_pool.updates._background_update_progress_txn(
                txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, new_progress
            )

            return len(rows)

        result = await self.db_pool.runInteraction(
            _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
        )

        if not result:
            await self.db_pool.updates._end_background_update(
                _MEMBERSHIP_PROFILE_UPDATE_NAME
            )

        return result

    async def _background_current_state_membership(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Update the new membership column on current_state_events.

        This works by iterating over all rooms in alphabetical order.

        Args:
            progress: the stored progress of this background update. It may contain
                `last_processed_room`.
            batch_size: the number of `current_state_events` rows to update before
                recording progress. Whole rooms are always processed, so a batch may
                overshoot this.

        Returns:
            The number of `current_state_events` rows updated in this batch.
        """

        def _background_current_state_membership_txn(
            txn: LoggingTransaction, last_processed_room: str
        ) -> Tuple[int, bool]:
            processed = 0
            while processed < batch_size:
                # Find the next room ID after the last one handled.
                txn.execute(
                    """
                        SELECT MIN(room_id) FROM current_state_events WHERE room_id > ?
                    """,
                    (last_processed_room,),
                )
                row = txn.fetchone()
                if not row or not row[0]:
                    # No rooms left: the update is complete.
                    return processed, True

                (next_room,) = row

                sql = """
                    UPDATE current_state_events
                    SET membership = (
                        SELECT membership FROM room_memberships
                        WHERE event_id = current_state_events.event_id
                    )
                    WHERE room_id = ?
                """
                txn.execute(sql, (next_room,))
                processed += txn.rowcount

                last_processed_room = next_room

            self.db_pool.updates._background_update_progress_txn(
                txn,
                _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
                {"last_processed_room": last_processed_room},
            )

            return processed, False

        # If we haven't got a last processed room then just use the empty
        # string, which will compare before all room IDs correctly.
        last_processed_room = progress.get("last_processed_room", "")

        row_count, finished = await self.db_pool.runInteraction(
            "_background_current_state_membership_update",
            _background_current_state_membership_txn,
            last_processed_room,
        )

        if finished:
            await self.db_pool.updates._end_background_update(
                _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
            )

        return row_count
 | 
						|
 | 
						|
 | 
						|
class RoomMemberStore(
    RoomMemberWorkerStore,
    RoomMemberBackgroundUpdateStore,
    CacheInvalidationWorkerStore,
):
    """Combined room membership store.

    Mixes the worker-side membership store, the membership background updates and
    cache-invalidation support into one class.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        # No extra set-up here; delegate to the next class in the MRO.
        super().__init__(database, db_conn, hs)
 | 
						|
 | 
						|
 | 
						|
def extract_heroes_from_room_summary(
    details: Mapping[str, MemberSummary], me: str
) -> List[str]:
    """Pick the users that represent a room, as seen by the user `me`.

    Selection follows the "Room Summary" section of
    https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3sync:

    - Joined and invited members are preferred.
    - Left and banned members are used only when there are no joined or invited
      ones.
    - `me` is never included.

    At most five user IDs are returned, sorted lexicographically.

    Returns a list (possibly empty) of heroes' mxids.
    """
    no_members = MemberSummary([], 0)

    def others_with(*memberships: str) -> List[str]:
        # Collect the user IDs (first element of each member entry) holding any
        # of `memberships`, excluding `me`.
        return [
            member[0]
            for membership in memberships
            for member in details.get(membership, no_members).members
            if member[0] != me
        ]

    candidates = others_with(Membership.JOIN, Membership.INVITE)
    if not candidates:
        candidates = others_with(Membership.LEAVE, Membership.BAN)

    # FIXME: order by stream ordering rather than as returned by SQL
    return sorted(candidates)[:5]
 | 
						|
 | 
						|
 | 
						|
@attr.s(slots=True, auto_attribs=True)
class _JoinedHostsCache:
    """Cached value held by `_get_joined_hosts_cache`.

    It maps each host to its joined users, tagged with the state group the mapping
    was computed from.
    """

    # Maps each host to the set of its users in the room at `state_group`.
    hosts_to_joined_users: Dict[str, Set[str]] = attr.Factory(dict)

    # The state group `hosts_to_joined_users` was derived from. A freshly created
    # instance, or one whose state is not based on a state group, instead holds a
    # unique bare `object()`. That sentinel never compares equal to anything else.
    state_group: Union[object, int] = attr.Factory(object)

    def __len__(self) -> int:
        """Return the total number of users tracked across every host."""
        return sum(map(len, self.hosts_to_joined_users.values()))
 |