Convert simple_select_list and simple_select_list_txn to return lists of tuples (#16505)
This should use fewer allocations and improves type hints.pull/16559/head
							parent
							
								
									c14a7de6af
								
							
						
					
					
						commit
						9407d5ba78
					
				| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Reduce memory allocations.
 | 
			
		||||
| 
						 | 
				
			
			@ -103,10 +103,10 @@ class DeactivateAccountHandler:
 | 
			
		|||
        # Attempt to unbind any known bound threepids to this account from identity
 | 
			
		||||
        # server(s).
 | 
			
		||||
        bound_threepids = await self.store.user_get_bound_threepids(user_id)
 | 
			
		||||
        for threepid in bound_threepids:
 | 
			
		||||
        for medium, address in bound_threepids:
 | 
			
		||||
            try:
 | 
			
		||||
                result = await self._identity_handler.try_unbind_threepid(
 | 
			
		||||
                    user_id, threepid["medium"], threepid["address"], id_server
 | 
			
		||||
                    user_id, medium, address, id_server
 | 
			
		||||
                )
 | 
			
		||||
            except Exception:
 | 
			
		||||
                # Do we want this to be a fatal error or should we carry on?
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1206,10 +1206,7 @@ class SsoHandler:
 | 
			
		|||
        # We have no guarantee that all the devices of that session are for the same
 | 
			
		||||
        # `user_id`. Hence, we have to iterate over the list of devices and log them out
 | 
			
		||||
        # one by one.
 | 
			
		||||
        for device in devices:
 | 
			
		||||
            user_id = device["user_id"]
 | 
			
		||||
            device_id = device["device_id"]
 | 
			
		||||
 | 
			
		||||
        for user_id, device_id in devices:
 | 
			
		||||
            # If the user_id associated with that device/session is not the one we got
 | 
			
		||||
            # out of the `sub` claim, skip that device and show log an error.
 | 
			
		||||
            if expected_user_id is not None and user_id != expected_user_id:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -606,13 +606,16 @@ class DatabasePool:
 | 
			
		|||
 | 
			
		||||
        If the background updates have not completed, wait 15 sec and check again.
 | 
			
		||||
        """
 | 
			
		||||
        updates = await self.simple_select_list(
 | 
			
		||||
            "background_updates",
 | 
			
		||||
            keyvalues=None,
 | 
			
		||||
            retcols=["update_name"],
 | 
			
		||||
            desc="check_background_updates",
 | 
			
		||||
        updates = cast(
 | 
			
		||||
            List[Tuple[str]],
 | 
			
		||||
            await self.simple_select_list(
 | 
			
		||||
                "background_updates",
 | 
			
		||||
                keyvalues=None,
 | 
			
		||||
                retcols=["update_name"],
 | 
			
		||||
                desc="check_background_updates",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        background_update_names = [x["update_name"] for x in updates]
 | 
			
		||||
        background_update_names = [x[0] for x in updates]
 | 
			
		||||
 | 
			
		||||
        for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
 | 
			
		||||
            if update_name not in background_update_names:
 | 
			
		||||
| 
						 | 
				
			
			@ -1804,9 +1807,9 @@ class DatabasePool:
 | 
			
		|||
        keyvalues: Optional[Dict[str, Any]],
 | 
			
		||||
        retcols: Collection[str],
 | 
			
		||||
        desc: str = "simple_select_list",
 | 
			
		||||
    ) -> List[Dict[str, Any]]:
 | 
			
		||||
    ) -> List[Tuple[Any, ...]]:
 | 
			
		||||
        """Executes a SELECT query on the named table, which may return zero or
 | 
			
		||||
        more rows, returning the result as a list of dicts.
 | 
			
		||||
        more rows, returning the result as a list of tuples.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            table: the table name
 | 
			
		||||
| 
						 | 
				
			
			@ -1817,8 +1820,7 @@ class DatabasePool:
 | 
			
		|||
            desc: description of the transaction, for logging and metrics
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A list of dictionaries, one per result row, each a mapping between the
 | 
			
		||||
            column names from `retcols` and that column's value for the row.
 | 
			
		||||
            A list of tuples, one per result row, each the retcolumn's value for the row.
 | 
			
		||||
        """
 | 
			
		||||
        return await self.runInteraction(
 | 
			
		||||
            desc,
 | 
			
		||||
| 
						 | 
				
			
			@ -1836,9 +1838,9 @@ class DatabasePool:
 | 
			
		|||
        table: str,
 | 
			
		||||
        keyvalues: Optional[Dict[str, Any]],
 | 
			
		||||
        retcols: Iterable[str],
 | 
			
		||||
    ) -> List[Dict[str, Any]]:
 | 
			
		||||
    ) -> List[Tuple[Any, ...]]:
 | 
			
		||||
        """Executes a SELECT query on the named table, which may return zero or
 | 
			
		||||
        more rows, returning the result as a list of dicts.
 | 
			
		||||
        more rows, returning the result as a list of tuples.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            txn: Transaction object
 | 
			
		||||
| 
						 | 
				
			
			@ -1849,8 +1851,7 @@ class DatabasePool:
 | 
			
		|||
            retcols: the names of the columns to return
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A list of dictionaries, one per result row, each a mapping between the
 | 
			
		||||
            column names from `retcols` and that column's value for the row.
 | 
			
		||||
            A list of tuples, one per result row, each the retcolumn's value for the row.
 | 
			
		||||
        """
 | 
			
		||||
        if keyvalues:
 | 
			
		||||
            sql = "SELECT %s FROM %s WHERE %s" % (
 | 
			
		||||
| 
						 | 
				
			
			@ -1863,7 +1864,7 @@ class DatabasePool:
 | 
			
		|||
            sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
 | 
			
		||||
            txn.execute(sql)
 | 
			
		||||
 | 
			
		||||
        return cls.cursor_to_dict(txn)
 | 
			
		||||
        return txn.fetchall()
 | 
			
		||||
 | 
			
		||||
    async def simple_select_many_batch(
 | 
			
		||||
        self,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -286,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 | 
			
		|||
 | 
			
		||||
        def get_account_data_for_room_txn(
 | 
			
		||||
            txn: LoggingTransaction,
 | 
			
		||||
        ) -> Dict[str, JsonDict]:
 | 
			
		||||
            rows = self.db_pool.simple_select_list_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                "room_account_data",
 | 
			
		||||
                {"user_id": user_id, "room_id": room_id},
 | 
			
		||||
                ["account_data_type", "content"],
 | 
			
		||||
        ) -> Dict[str, JsonMapping]:
 | 
			
		||||
            rows = cast(
 | 
			
		||||
                List[Tuple[str, str]],
 | 
			
		||||
                self.db_pool.simple_select_list_txn(
 | 
			
		||||
                    txn,
 | 
			
		||||
                    table="room_account_data",
 | 
			
		||||
                    keyvalues={"user_id": user_id, "room_id": room_id},
 | 
			
		||||
                    retcols=["account_data_type", "content"],
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            return {
 | 
			
		||||
                row["account_data_type"]: db_to_json(row["content"]) for row in rows
 | 
			
		||||
                account_data_type: db_to_json(content)
 | 
			
		||||
                for account_data_type, content in rows
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        return await self.db_pool.runInteraction(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore(
 | 
			
		|||
        Returns:
 | 
			
		||||
            A list of ApplicationServices, which may be empty.
 | 
			
		||||
        """
 | 
			
		||||
        results = await self.db_pool.simple_select_list(
 | 
			
		||||
            "application_services_state", {"state": state.value}, ["as_id"]
 | 
			
		||||
        results = cast(
 | 
			
		||||
            List[Tuple[str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="application_services_state",
 | 
			
		||||
                keyvalues={"state": state.value},
 | 
			
		||||
                retcols=("as_id",),
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        # NB: This assumes this class is linked with ApplicationServiceStore
 | 
			
		||||
        as_list = self.get_app_services()
 | 
			
		||||
        services = []
 | 
			
		||||
 | 
			
		||||
        for res in results:
 | 
			
		||||
        for (as_id,) in results:
 | 
			
		||||
            for service in as_list:
 | 
			
		||||
                if service.id == res["as_id"]:
 | 
			
		||||
                if service.id == as_id:
 | 
			
		||||
                    services.append(service)
 | 
			
		||||
        return services
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
 | 
			
		|||
        if device_id is not None:
 | 
			
		||||
            keyvalues["device_id"] = device_id
 | 
			
		||||
 | 
			
		||||
        res = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="devices",
 | 
			
		||||
            keyvalues=keyvalues,
 | 
			
		||||
            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
 | 
			
		||||
        res = cast(
 | 
			
		||||
            List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="devices",
 | 
			
		||||
                keyvalues=keyvalues,
 | 
			
		||||
                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            (d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
 | 
			
		||||
                user_id=d["user_id"],
 | 
			
		||||
                device_id=d["device_id"],
 | 
			
		||||
                ip=d["ip"],
 | 
			
		||||
                user_agent=d["user_agent"],
 | 
			
		||||
                last_seen=d["last_seen"],
 | 
			
		||||
            (user_id, device_id): DeviceLastConnectionInfo(
 | 
			
		||||
                user_id=user_id,
 | 
			
		||||
                device_id=device_id,
 | 
			
		||||
                ip=ip,
 | 
			
		||||
                user_agent=user_agent,
 | 
			
		||||
                last_seen=last_seen,
 | 
			
		||||
            )
 | 
			
		||||
            for d in res
 | 
			
		||||
            for user_id, ip, user_agent, device_id, last_seen in res
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    async def _get_user_ip_and_agents_from_database(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 | 
			
		|||
            allow_none=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
 | 
			
		||||
    async def get_devices_by_user(
 | 
			
		||||
        self, user_id: str
 | 
			
		||||
    ) -> Dict[str, Dict[str, Optional[str]]]:
 | 
			
		||||
        """Retrieve all of a user's registered devices. Only returns devices
 | 
			
		||||
        that are not marked as hidden.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 | 
			
		|||
            user_id:
 | 
			
		||||
        Returns:
 | 
			
		||||
            A mapping from device_id to a dict containing "device_id", "user_id"
 | 
			
		||||
            and "display_name" for each device.
 | 
			
		||||
            and "display_name" for each device. Display name may be null.
 | 
			
		||||
        """
 | 
			
		||||
        devices = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="devices",
 | 
			
		||||
            keyvalues={"user_id": user_id, "hidden": False},
 | 
			
		||||
            retcols=("user_id", "device_id", "display_name"),
 | 
			
		||||
            desc="get_devices_by_user",
 | 
			
		||||
        devices = cast(
 | 
			
		||||
            List[Tuple[str, str, Optional[str]]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="devices",
 | 
			
		||||
                keyvalues={"user_id": user_id, "hidden": False},
 | 
			
		||||
                retcols=("user_id", "device_id", "display_name"),
 | 
			
		||||
                desc="get_devices_by_user",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return {d["device_id"]: d for d in devices}
 | 
			
		||||
        return {
 | 
			
		||||
            d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
 | 
			
		||||
            for d in devices
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    async def get_devices_by_auth_provider_session_id(
 | 
			
		||||
        self, auth_provider_id: str, auth_provider_session_id: str
 | 
			
		||||
    ) -> List[Dict[str, Any]]:
 | 
			
		||||
    ) -> List[Tuple[str, str]]:
 | 
			
		||||
        """Retrieve the list of devices associated with a SSO IdP session ID.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
| 
						 | 
				
			
			@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 | 
			
		|||
        Returns:
 | 
			
		||||
            A list of dicts containing the device_id and the user_id of each device
 | 
			
		||||
        """
 | 
			
		||||
        return await self.db_pool.simple_select_list(
 | 
			
		||||
            table="device_auth_providers",
 | 
			
		||||
            keyvalues={
 | 
			
		||||
                "auth_provider_id": auth_provider_id,
 | 
			
		||||
                "auth_provider_session_id": auth_provider_session_id,
 | 
			
		||||
            },
 | 
			
		||||
            retcols=("user_id", "device_id"),
 | 
			
		||||
            desc="get_devices_by_auth_provider_session_id",
 | 
			
		||||
        return cast(
 | 
			
		||||
            List[Tuple[str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="device_auth_providers",
 | 
			
		||||
                keyvalues={
 | 
			
		||||
                    "auth_provider_id": auth_provider_id,
 | 
			
		||||
                    "auth_provider_session_id": auth_provider_session_id,
 | 
			
		||||
                },
 | 
			
		||||
                retcols=("user_id", "device_id"),
 | 
			
		||||
                desc="get_devices_by_auth_provider_session_id",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @trace
 | 
			
		||||
| 
						 | 
				
			
			@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 | 
			
		|||
    async def get_cached_devices_for_user(
 | 
			
		||||
        self, user_id: str
 | 
			
		||||
    ) -> Mapping[str, JsonMapping]:
 | 
			
		||||
        devices = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="device_lists_remote_cache",
 | 
			
		||||
            keyvalues={"user_id": user_id},
 | 
			
		||||
            retcols=("device_id", "content"),
 | 
			
		||||
            desc="get_cached_devices_for_user",
 | 
			
		||||
        devices = cast(
 | 
			
		||||
            List[Tuple[str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="device_lists_remote_cache",
 | 
			
		||||
                keyvalues={"user_id": user_id},
 | 
			
		||||
                retcols=("device_id", "content"),
 | 
			
		||||
                desc="get_cached_devices_for_user",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return {
 | 
			
		||||
            device["device_id"]: db_to_json(device["content"]) for device in devices
 | 
			
		||||
        }
 | 
			
		||||
        return {device[0]: db_to_json(device[1]) for device in devices}
 | 
			
		||||
 | 
			
		||||
    def get_cached_device_list_changes(
 | 
			
		||||
        self,
 | 
			
		||||
| 
						 | 
				
			
			@ -1080,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 | 
			
		|||
            The IDs of users whose device lists need resync.
 | 
			
		||||
        """
 | 
			
		||||
        if user_ids:
 | 
			
		||||
            row_tuples = cast(
 | 
			
		||||
            rows = cast(
 | 
			
		||||
                List[Tuple[str]],
 | 
			
		||||
                await self.db_pool.simple_select_many_batch(
 | 
			
		||||
                    table="device_lists_remote_resync",
 | 
			
		||||
| 
						 | 
				
			
			@ -1090,11 +1102,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 | 
			
		|||
                    desc="get_user_ids_requiring_device_list_resync_with_iterable",
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            return {row[0] for row in row_tuples}
 | 
			
		||||
        else:
 | 
			
		||||
            rows = cast(
 | 
			
		||||
                List[Dict[str, str]],
 | 
			
		||||
                List[Tuple[str]],
 | 
			
		||||
                await self.db_pool.simple_select_list(
 | 
			
		||||
                    table="device_lists_remote_resync",
 | 
			
		||||
                    keyvalues=None,
 | 
			
		||||
| 
						 | 
				
			
			@ -1103,7 +1113,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 | 
			
		|||
                ),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            return {row["user_id"] for row in rows}
 | 
			
		||||
        return {row[0] for row in rows}
 | 
			
		||||
 | 
			
		||||
    async def mark_remote_users_device_caches_as_stale(
 | 
			
		||||
        self, user_ids: StrCollection
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,7 +13,7 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast
 | 
			
		||||
 | 
			
		||||
from typing_extensions import Literal, TypedDict
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -274,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
 | 
			
		|||
            if session_id:
 | 
			
		||||
                keyvalues["session_id"] = session_id
 | 
			
		||||
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="e2e_room_keys",
 | 
			
		||||
            keyvalues=keyvalues,
 | 
			
		||||
            retcols=(
 | 
			
		||||
                "user_id",
 | 
			
		||||
                "room_id",
 | 
			
		||||
                "session_id",
 | 
			
		||||
                "first_message_index",
 | 
			
		||||
                "forwarded_count",
 | 
			
		||||
                "is_verified",
 | 
			
		||||
                "session_data",
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, str, int, int, int, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="e2e_room_keys",
 | 
			
		||||
                keyvalues=keyvalues,
 | 
			
		||||
                retcols=(
 | 
			
		||||
                    "room_id",
 | 
			
		||||
                    "session_id",
 | 
			
		||||
                    "first_message_index",
 | 
			
		||||
                    "forwarded_count",
 | 
			
		||||
                    "is_verified",
 | 
			
		||||
                    "session_data",
 | 
			
		||||
                ),
 | 
			
		||||
                desc="get_e2e_room_keys",
 | 
			
		||||
            ),
 | 
			
		||||
            desc="get_e2e_room_keys",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        sessions: Dict[
 | 
			
		||||
            Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
 | 
			
		||||
        ] = {"rooms": {}}
 | 
			
		||||
        for row in rows:
 | 
			
		||||
            room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
 | 
			
		||||
            room_entry["sessions"][row["session_id"]] = {
 | 
			
		||||
                "first_message_index": row["first_message_index"],
 | 
			
		||||
                "forwarded_count": row["forwarded_count"],
 | 
			
		||||
        for (
 | 
			
		||||
            room_id,
 | 
			
		||||
            session_id,
 | 
			
		||||
            first_message_index,
 | 
			
		||||
            forwarded_count,
 | 
			
		||||
            is_verified,
 | 
			
		||||
            session_data,
 | 
			
		||||
        ) in rows:
 | 
			
		||||
            room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
 | 
			
		||||
            room_entry["sessions"][session_id] = {
 | 
			
		||||
                "first_message_index": first_message_index,
 | 
			
		||||
                "forwarded_count": forwarded_count,
 | 
			
		||||
                # is_verified must be returned to the client as a boolean
 | 
			
		||||
                "is_verified": bool(row["is_verified"]),
 | 
			
		||||
                "session_data": db_to_json(row["session_data"]),
 | 
			
		||||
                "is_verified": bool(is_verified),
 | 
			
		||||
                "session_data": db_to_json(session_data),
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        return sessions
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 | 
			
		|||
        # keeping only the forward extremities (i.e. the events not referenced
 | 
			
		||||
        # by other events in the queue). We do this so that we can always
 | 
			
		||||
        # backpaginate in all the events we have dropped.
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="federation_inbound_events_staging",
 | 
			
		||||
            keyvalues={"room_id": room_id},
 | 
			
		||||
            retcols=("event_id", "event_json"),
 | 
			
		||||
            desc="prune_staged_events_in_room_fetch",
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="federation_inbound_events_staging",
 | 
			
		||||
                keyvalues={"room_id": room_id},
 | 
			
		||||
                retcols=("event_id", "event_json"),
 | 
			
		||||
                desc="prune_staged_events_in_room_fetch",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Find the set of events referenced by those in the queue, as well as
 | 
			
		||||
        # collecting all the event IDs in the queue.
 | 
			
		||||
        referenced_events: Set[str] = set()
 | 
			
		||||
        seen_events: Set[str] = set()
 | 
			
		||||
        for row in rows:
 | 
			
		||||
            event_id = row["event_id"]
 | 
			
		||||
        for event_id, event_json in rows:
 | 
			
		||||
            seen_events.add(event_id)
 | 
			
		||||
            event_d = db_to_json(row["event_json"])
 | 
			
		||||
            event_d = db_to_json(event_json)
 | 
			
		||||
 | 
			
		||||
            # We don't bother parsing the dicts into full blown event objects,
 | 
			
		||||
            # as that is needlessly expensive.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,7 +12,7 @@
 | 
			
		|||
#  See the License for the specific language governing permissions and
 | 
			
		||||
#  limitations under the License.
 | 
			
		||||
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, FrozenSet
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
 | 
			
		||||
 | 
			
		||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 | 
			
		||||
from synapse.storage.databases.main import CacheInvalidationWorkerStore
 | 
			
		||||
| 
						 | 
				
			
			@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
 | 
			
		|||
        Returns:
 | 
			
		||||
            the features currently enabled for the user
 | 
			
		||||
        """
 | 
			
		||||
        enabled = await self.db_pool.simple_select_list(
 | 
			
		||||
            "per_user_experimental_features",
 | 
			
		||||
            {"user_id": user_id, "enabled": True},
 | 
			
		||||
            ["feature"],
 | 
			
		||||
        enabled = cast(
 | 
			
		||||
            List[Tuple[str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="per_user_experimental_features",
 | 
			
		||||
                keyvalues={"user_id": user_id, "enabled": True},
 | 
			
		||||
                retcols=("feature",),
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return frozenset(feature["feature"] for feature in enabled)
 | 
			
		||||
        return frozenset(feature[0] for feature in enabled)
 | 
			
		||||
 | 
			
		||||
    async def set_features_for_user(
 | 
			
		||||
        self,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -248,17 +248,20 @@ class KeyStore(CacheInvalidationWorkerStore):
 | 
			
		|||
 | 
			
		||||
        If we have multiple entries for a given key ID, returns the most recent.
 | 
			
		||||
        """
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="server_keys_json",
 | 
			
		||||
            keyvalues={"server_name": server_name},
 | 
			
		||||
            retcols=(
 | 
			
		||||
                "key_id",
 | 
			
		||||
                "from_server",
 | 
			
		||||
                "ts_added_ms",
 | 
			
		||||
                "ts_valid_until_ms",
 | 
			
		||||
                "key_json",
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="server_keys_json",
 | 
			
		||||
                keyvalues={"server_name": server_name},
 | 
			
		||||
                retcols=(
 | 
			
		||||
                    "key_id",
 | 
			
		||||
                    "from_server",
 | 
			
		||||
                    "ts_added_ms",
 | 
			
		||||
                    "ts_valid_until_ms",
 | 
			
		||||
                    "key_json",
 | 
			
		||||
                ),
 | 
			
		||||
                desc="get_server_keys_json_for_remote",
 | 
			
		||||
            ),
 | 
			
		||||
            desc="get_server_keys_json_for_remote",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if not rows:
 | 
			
		||||
| 
						 | 
				
			
			@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore):
 | 
			
		|||
 | 
			
		||||
        # We sort the rows by ts_added_ms so that the most recently added entry
 | 
			
		||||
        # will stomp over older entries in the dictionary.
 | 
			
		||||
        rows.sort(key=lambda r: r["ts_added_ms"])
 | 
			
		||||
        rows.sort(key=lambda r: r[2])
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            row["key_id"]: FetchKeyResultForRemote(
 | 
			
		||||
            key_id: FetchKeyResultForRemote(
 | 
			
		||||
                # Cast to bytes since postgresql returns a memoryview.
 | 
			
		||||
                key_json=bytes(row["key_json"]),
 | 
			
		||||
                valid_until_ts=row["ts_valid_until_ms"],
 | 
			
		||||
                added_ts=row["ts_added_ms"],
 | 
			
		||||
                key_json=bytes(key_json),
 | 
			
		||||
                valid_until_ts=ts_valid_until_ms,
 | 
			
		||||
                added_ts=ts_added_ms,
 | 
			
		||||
            )
 | 
			
		||||
            for row in rows
 | 
			
		||||
            for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
 | 
			
		||||
        }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -437,25 +437,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
    async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            "local_media_repository_thumbnails",
 | 
			
		||||
            {"media_id": media_id},
 | 
			
		||||
            (
 | 
			
		||||
                "thumbnail_width",
 | 
			
		||||
                "thumbnail_height",
 | 
			
		||||
                "thumbnail_method",
 | 
			
		||||
                "thumbnail_type",
 | 
			
		||||
                "thumbnail_length",
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[int, int, str, str, int]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                "local_media_repository_thumbnails",
 | 
			
		||||
                {"media_id": media_id},
 | 
			
		||||
                (
 | 
			
		||||
                    "thumbnail_width",
 | 
			
		||||
                    "thumbnail_height",
 | 
			
		||||
                    "thumbnail_method",
 | 
			
		||||
                    "thumbnail_type",
 | 
			
		||||
                    "thumbnail_length",
 | 
			
		||||
                ),
 | 
			
		||||
                desc="get_local_media_thumbnails",
 | 
			
		||||
            ),
 | 
			
		||||
            desc="get_local_media_thumbnails",
 | 
			
		||||
        )
 | 
			
		||||
        return [
 | 
			
		||||
            ThumbnailInfo(
 | 
			
		||||
                width=row["thumbnail_width"],
 | 
			
		||||
                height=row["thumbnail_height"],
 | 
			
		||||
                method=row["thumbnail_method"],
 | 
			
		||||
                type=row["thumbnail_type"],
 | 
			
		||||
                length=row["thumbnail_length"],
 | 
			
		||||
                width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
 | 
			
		||||
            )
 | 
			
		||||
            for row in rows
 | 
			
		||||
        ]
 | 
			
		||||
| 
						 | 
				
			
			@ -568,25 +567,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 | 
			
		|||
    async def get_remote_media_thumbnails(
 | 
			
		||||
        self, origin: str, media_id: str
 | 
			
		||||
    ) -> List[ThumbnailInfo]:
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            "remote_media_cache_thumbnails",
 | 
			
		||||
            {"media_origin": origin, "media_id": media_id},
 | 
			
		||||
            (
 | 
			
		||||
                "thumbnail_width",
 | 
			
		||||
                "thumbnail_height",
 | 
			
		||||
                "thumbnail_method",
 | 
			
		||||
                "thumbnail_type",
 | 
			
		||||
                "thumbnail_length",
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[int, int, str, str, int]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                "remote_media_cache_thumbnails",
 | 
			
		||||
                {"media_origin": origin, "media_id": media_id},
 | 
			
		||||
                (
 | 
			
		||||
                    "thumbnail_width",
 | 
			
		||||
                    "thumbnail_height",
 | 
			
		||||
                    "thumbnail_method",
 | 
			
		||||
                    "thumbnail_type",
 | 
			
		||||
                    "thumbnail_length",
 | 
			
		||||
                ),
 | 
			
		||||
                desc="get_remote_media_thumbnails",
 | 
			
		||||
            ),
 | 
			
		||||
            desc="get_remote_media_thumbnails",
 | 
			
		||||
        )
 | 
			
		||||
        return [
 | 
			
		||||
            ThumbnailInfo(
 | 
			
		||||
                width=row["thumbnail_width"],
 | 
			
		||||
                height=row["thumbnail_height"],
 | 
			
		||||
                method=row["thumbnail_method"],
 | 
			
		||||
                type=row["thumbnail_type"],
 | 
			
		||||
                length=row["thumbnail_length"],
 | 
			
		||||
                width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
 | 
			
		||||
            )
 | 
			
		||||
            for row in rows
 | 
			
		||||
        ]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -179,46 +179,44 @@ class PushRulesWorkerStore(
 | 
			
		|||
 | 
			
		||||
    @cached(max_entries=5000)
 | 
			
		||||
    async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="push_rules",
 | 
			
		||||
            keyvalues={"user_name": user_id},
 | 
			
		||||
            retcols=(
 | 
			
		||||
                "user_name",
 | 
			
		||||
                "rule_id",
 | 
			
		||||
                "priority_class",
 | 
			
		||||
                "priority",
 | 
			
		||||
                "conditions",
 | 
			
		||||
                "actions",
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, int, int, str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="push_rules",
 | 
			
		||||
                keyvalues={"user_name": user_id},
 | 
			
		||||
                retcols=(
 | 
			
		||||
                    "rule_id",
 | 
			
		||||
                    "priority_class",
 | 
			
		||||
                    "priority",
 | 
			
		||||
                    "conditions",
 | 
			
		||||
                    "actions",
 | 
			
		||||
                ),
 | 
			
		||||
                desc="get_push_rules_for_user",
 | 
			
		||||
            ),
 | 
			
		||||
            desc="get_push_rules_for_user",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
 | 
			
		||||
        # Sort by highest priority_class, then highest priority.
 | 
			
		||||
        rows.sort(key=lambda row: (-int(row[1]), -int(row[2])))
 | 
			
		||||
 | 
			
		||||
        enabled_map = await self.get_push_rules_enabled_for_user(user_id)
 | 
			
		||||
 | 
			
		||||
        return _load_rules(
 | 
			
		||||
            [
 | 
			
		||||
                (
 | 
			
		||||
                    row["rule_id"],
 | 
			
		||||
                    row["priority_class"],
 | 
			
		||||
                    row["conditions"],
 | 
			
		||||
                    row["actions"],
 | 
			
		||||
                )
 | 
			
		||||
                for row in rows
 | 
			
		||||
            ],
 | 
			
		||||
            [(row[0], row[1], row[3], row[4]) for row in rows],
 | 
			
		||||
            enabled_map,
 | 
			
		||||
            self.hs.config.experimental,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
 | 
			
		||||
        results = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="push_rules_enable",
 | 
			
		||||
            keyvalues={"user_name": user_id},
 | 
			
		||||
            retcols=("rule_id", "enabled"),
 | 
			
		||||
            desc="get_push_rules_enabled_for_user",
 | 
			
		||||
        results = cast(
 | 
			
		||||
            List[Tuple[str, Optional[Union[int, bool]]]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="push_rules_enable",
 | 
			
		||||
                keyvalues={"user_name": user_id},
 | 
			
		||||
                retcols=("rule_id", "enabled"),
 | 
			
		||||
                desc="get_push_rules_enabled_for_user",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return {r["rule_id"]: bool(r["enabled"]) for r in results}
 | 
			
		||||
        return {r[0]: bool(r[1]) for r in results}
 | 
			
		||||
 | 
			
		||||
    async def have_push_rules_changed_for_user(
 | 
			
		||||
        self, user_id: str, last_id: int
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -371,18 +371,20 @@ class PusherWorkerStore(SQLBaseStore):
 | 
			
		|||
    async def get_throttle_params_by_room(
 | 
			
		||||
        self, pusher_id: int
 | 
			
		||||
    ) -> Dict[str, ThrottleParams]:
 | 
			
		||||
        res = await self.db_pool.simple_select_list(
 | 
			
		||||
            "pusher_throttle",
 | 
			
		||||
            {"pusher": pusher_id},
 | 
			
		||||
            ["room_id", "last_sent_ts", "throttle_ms"],
 | 
			
		||||
            desc="get_throttle_params_by_room",
 | 
			
		||||
        res = cast(
 | 
			
		||||
            List[Tuple[str, Optional[int], Optional[int]]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                "pusher_throttle",
 | 
			
		||||
                {"pusher": pusher_id},
 | 
			
		||||
                ["room_id", "last_sent_ts", "throttle_ms"],
 | 
			
		||||
                desc="get_throttle_params_by_room",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        params_by_room = {}
 | 
			
		||||
        for row in res:
 | 
			
		||||
            params_by_room[row["room_id"]] = ThrottleParams(
 | 
			
		||||
                row["last_sent_ts"],
 | 
			
		||||
                row["throttle_ms"],
 | 
			
		||||
        for room_id, last_sent_ts, throttle_ms in res:
 | 
			
		||||
            params_by_room[room_id] = ThrottleParams(
 | 
			
		||||
                last_sent_ts or 0, throttle_ms or 0
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return params_by_room
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -855,13 +855,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 | 
			
		|||
        Returns:
 | 
			
		||||
            Tuples of (auth_provider, external_id)
 | 
			
		||||
        """
 | 
			
		||||
        res = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="user_external_ids",
 | 
			
		||||
            keyvalues={"user_id": mxid},
 | 
			
		||||
            retcols=("auth_provider", "external_id"),
 | 
			
		||||
            desc="get_external_ids_by_user",
 | 
			
		||||
        return cast(
 | 
			
		||||
            List[Tuple[str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="user_external_ids",
 | 
			
		||||
                keyvalues={"user_id": mxid},
 | 
			
		||||
                retcols=("auth_provider", "external_id"),
 | 
			
		||||
                desc="get_external_ids_by_user",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return [(r["auth_provider"], r["external_id"]) for r in res]
 | 
			
		||||
 | 
			
		||||
    async def count_all_users(self) -> int:
 | 
			
		||||
        """Counts all users registered on the homeserver."""
 | 
			
		||||
| 
						 | 
				
			
			@ -997,13 +999,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
    async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
 | 
			
		||||
        results = await self.db_pool.simple_select_list(
 | 
			
		||||
            "user_threepids",
 | 
			
		||||
            keyvalues={"user_id": user_id},
 | 
			
		||||
            retcols=["medium", "address", "validated_at", "added_at"],
 | 
			
		||||
            desc="user_get_threepids",
 | 
			
		||||
        results = cast(
 | 
			
		||||
            List[Tuple[str, str, int, int]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                "user_threepids",
 | 
			
		||||
                keyvalues={"user_id": user_id},
 | 
			
		||||
                retcols=["medium", "address", "validated_at", "added_at"],
 | 
			
		||||
                desc="user_get_threepids",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return [ThreepidResult(**r) for r in results]
 | 
			
		||||
        return [
 | 
			
		||||
            ThreepidResult(
 | 
			
		||||
                medium=r[0],
 | 
			
		||||
                address=r[1],
 | 
			
		||||
                validated_at=r[2],
 | 
			
		||||
                added_at=r[3],
 | 
			
		||||
            )
 | 
			
		||||
            for r in results
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    async def user_delete_threepid(
 | 
			
		||||
        self, user_id: str, medium: str, address: str
 | 
			
		||||
| 
						 | 
				
			
			@ -1042,7 +1055,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 | 
			
		|||
            desc="add_user_bound_threepid",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
 | 
			
		||||
    async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
 | 
			
		||||
        """Get the threepids that a user has bound to an identity server through the homeserver
 | 
			
		||||
        The homeserver remembers where binds to an identity server occurred. Using this
 | 
			
		||||
        method can retrieve those threepids.
 | 
			
		||||
| 
						 | 
				
			
			@ -1051,15 +1064,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 | 
			
		|||
            user_id: The ID of the user to retrieve threepids for
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            List of dictionaries containing the following keys:
 | 
			
		||||
                medium (str): The medium of the threepid (e.g "email")
 | 
			
		||||
                address (str): The address of the threepid (e.g "bob@example.com")
 | 
			
		||||
            List of tuples of two strings:
 | 
			
		||||
                medium: The medium of the threepid (e.g "email")
 | 
			
		||||
                address: The address of the threepid (e.g "bob@example.com")
 | 
			
		||||
        """
 | 
			
		||||
        return await self.db_pool.simple_select_list(
 | 
			
		||||
            table="user_threepid_id_server",
 | 
			
		||||
            keyvalues={"user_id": user_id},
 | 
			
		||||
            retcols=["medium", "address"],
 | 
			
		||||
            desc="user_get_bound_threepids",
 | 
			
		||||
        return cast(
 | 
			
		||||
            List[Tuple[str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="user_threepid_id_server",
 | 
			
		||||
                keyvalues={"user_id": user_id},
 | 
			
		||||
                retcols=["medium", "address"],
 | 
			
		||||
                desc="user_get_bound_threepids",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def remove_user_bound_threepid(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -384,14 +384,17 @@ class RelationsWorkerStore(SQLBaseStore):
 | 
			
		|||
        def get_all_relation_ids_for_event_txn(
 | 
			
		||||
            txn: LoggingTransaction,
 | 
			
		||||
        ) -> List[str]:
 | 
			
		||||
            rows = self.db_pool.simple_select_list_txn(
 | 
			
		||||
                txn=txn,
 | 
			
		||||
                table="event_relations",
 | 
			
		||||
                keyvalues={"relates_to_id": event_id},
 | 
			
		||||
                retcols=["event_id"],
 | 
			
		||||
            rows = cast(
 | 
			
		||||
                List[Tuple[str]],
 | 
			
		||||
                self.db_pool.simple_select_list_txn(
 | 
			
		||||
                    txn=txn,
 | 
			
		||||
                    table="event_relations",
 | 
			
		||||
                    keyvalues={"relates_to_id": event_id},
 | 
			
		||||
                    retcols=["event_id"],
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            return [row["event_id"] for row in rows]
 | 
			
		||||
            return [row[0] for row in rows]
 | 
			
		||||
 | 
			
		||||
        return await self.db_pool.runInteraction(
 | 
			
		||||
            desc="get_all_relation_ids_for_event",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1232,28 +1232,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 | 
			
		|||
        """
 | 
			
		||||
        room_servers: Dict[str, PartialStateResyncInfo] = {}
 | 
			
		||||
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="partial_state_rooms",
 | 
			
		||||
            keyvalues={},
 | 
			
		||||
            retcols=("room_id", "joined_via"),
 | 
			
		||||
            desc="get_server_which_served_partial_join",
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="partial_state_rooms",
 | 
			
		||||
                keyvalues={},
 | 
			
		||||
                retcols=("room_id", "joined_via"),
 | 
			
		||||
                desc="get_server_which_served_partial_join",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for row in rows:
 | 
			
		||||
            room_id = row["room_id"]
 | 
			
		||||
            joined_via = row["joined_via"]
 | 
			
		||||
        for room_id, joined_via in rows:
 | 
			
		||||
            room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
 | 
			
		||||
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            "partial_state_rooms_servers",
 | 
			
		||||
            keyvalues=None,
 | 
			
		||||
            retcols=("room_id", "server_name"),
 | 
			
		||||
            desc="get_partial_state_rooms",
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                "partial_state_rooms_servers",
 | 
			
		||||
                keyvalues=None,
 | 
			
		||||
                retcols=("room_id", "server_name"),
 | 
			
		||||
                desc="get_partial_state_rooms",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for row in rows:
 | 
			
		||||
            room_id = row["room_id"]
 | 
			
		||||
            server_name = row["server_name"]
 | 
			
		||||
        for room_id, server_name in rows:
 | 
			
		||||
            entry = room_servers.get(room_id)
 | 
			
		||||
            if entry is None:
 | 
			
		||||
                # There is a foreign key constraint which enforces that every room_id in
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1070,13 +1070,16 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
 | 
			
		|||
        for fully-joined rooms.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            "current_state_events",
 | 
			
		||||
            keyvalues={"room_id": room_id},
 | 
			
		||||
            retcols=("event_id", "membership"),
 | 
			
		||||
            desc="has_completed_background_updates",
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, Optional[str]]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                "current_state_events",
 | 
			
		||||
                keyvalues={"room_id": room_id},
 | 
			
		||||
                retcols=("event_id", "membership"),
 | 
			
		||||
                desc="has_completed_background_updates",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return {row["event_id"]: row["membership"] for row in rows}
 | 
			
		||||
        return dict(rows)
 | 
			
		||||
 | 
			
		||||
    # TODO This returns a mutable object, which is generally confusing when using a cache.
 | 
			
		||||
    @cached(max_entries=10000)  # type: ignore[synapse-@cached-mutable]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore):
 | 
			
		|||
            tag content.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
 | 
			
		||||
        for row in rows:
 | 
			
		||||
            room_tags = tags_by_room.setdefault(row["room_id"], {})
 | 
			
		||||
            room_tags[row["tag"]] = db_to_json(row["content"])
 | 
			
		||||
        for room_id, tag, content in rows:
 | 
			
		||||
            room_tags = tags_by_room.setdefault(room_id, {})
 | 
			
		||||
            room_tags[tag] = db_to_json(content)
 | 
			
		||||
        return tags_by_room
 | 
			
		||||
 | 
			
		||||
    async def get_all_updated_tags(
 | 
			
		||||
| 
						 | 
				
			
			@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
 | 
			
		|||
        Returns:
 | 
			
		||||
            A mapping of tags to tag content.
 | 
			
		||||
        """
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="room_tags",
 | 
			
		||||
            keyvalues={"user_id": user_id, "room_id": room_id},
 | 
			
		||||
            retcols=("tag", "content"),
 | 
			
		||||
            desc="get_tags_for_room",
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="room_tags",
 | 
			
		||||
                keyvalues={"user_id": user_id, "room_id": room_id},
 | 
			
		||||
                retcols=("tag", "content"),
 | 
			
		||||
                desc="get_tags_for_room",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return {row["tag"]: db_to_json(row["content"]) for row in rows}
 | 
			
		||||
        return {tag: db_to_json(content) for tag, content in rows}
 | 
			
		||||
 | 
			
		||||
    async def add_tag_to_room(
 | 
			
		||||
        self, user_id: str, room_id: str, tag: str, content: JsonDict
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore):
 | 
			
		|||
            that auth-type.
 | 
			
		||||
        """
 | 
			
		||||
        results = {}
 | 
			
		||||
        for row in await self.db_pool.simple_select_list(
 | 
			
		||||
            table="ui_auth_sessions_credentials",
 | 
			
		||||
            keyvalues={"session_id": session_id},
 | 
			
		||||
            retcols=("stage_type", "result"),
 | 
			
		||||
            desc="get_completed_ui_auth_stages",
 | 
			
		||||
        ):
 | 
			
		||||
            results[row["stage_type"]] = db_to_json(row["result"])
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="ui_auth_sessions_credentials",
 | 
			
		||||
                keyvalues={"session_id": session_id},
 | 
			
		||||
                retcols=("stage_type", "result"),
 | 
			
		||||
                desc="get_completed_ui_auth_stages",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        for stage_type, result in rows:
 | 
			
		||||
            results[stage_type] = db_to_json(result)
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore):
 | 
			
		|||
        Returns:
 | 
			
		||||
            List of user_agent/ip pairs
 | 
			
		||||
        """
 | 
			
		||||
        rows = await self.db_pool.simple_select_list(
 | 
			
		||||
            table="ui_auth_sessions_ips",
 | 
			
		||||
            keyvalues={"session_id": session_id},
 | 
			
		||||
            retcols=("user_agent", "ip"),
 | 
			
		||||
            desc="get_user_agents_ips_to_ui_auth_session",
 | 
			
		||||
        return cast(
 | 
			
		||||
            List[Tuple[str, str]],
 | 
			
		||||
            await self.db_pool.simple_select_list(
 | 
			
		||||
                table="ui_auth_sessions_ips",
 | 
			
		||||
                keyvalues={"session_id": session_id},
 | 
			
		||||
                retcols=("user_agent", "ip"),
 | 
			
		||||
                desc="get_user_agents_ips_to_ui_auth_session",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return [(row["user_agent"], row["ip"]) for row in rows]
 | 
			
		||||
 | 
			
		||||
    async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -154,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 | 
			
		|||
            if not prev_group:
 | 
			
		||||
                return _GetStateGroupDelta(None, None)
 | 
			
		||||
 | 
			
		||||
            delta_ids = self.db_pool.simple_select_list_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                table="state_groups_state",
 | 
			
		||||
                keyvalues={"state_group": state_group},
 | 
			
		||||
                retcols=("type", "state_key", "event_id"),
 | 
			
		||||
            delta_ids = cast(
 | 
			
		||||
                List[Tuple[str, str, str]],
 | 
			
		||||
                self.db_pool.simple_select_list_txn(
 | 
			
		||||
                    txn,
 | 
			
		||||
                    table="state_groups_state",
 | 
			
		||||
                    keyvalues={"state_group": state_group},
 | 
			
		||||
                    retcols=("type", "state_key", "event_id"),
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            return _GetStateGroupDelta(
 | 
			
		||||
                prev_group,
 | 
			
		||||
                {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
 | 
			
		||||
                {
 | 
			
		||||
                    (event_type, state_key): event_id
 | 
			
		||||
                    for event_type, state_key, event_id in delta_ids
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return await self.db_pool.runInteraction(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,7 +12,7 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from typing import Any, Dict, List, Optional
 | 
			
		||||
from typing import Any, Dict, List, Optional, Tuple, cast
 | 
			
		||||
 | 
			
		||||
from twisted.test.proto_helpers import MemoryReactor
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 | 
			
		|||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def get_all_room_state(self) -> List[Dict[str, Any]]:
 | 
			
		||||
        return await self.store.db_pool.simple_select_list(
 | 
			
		||||
            "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
 | 
			
		||||
    async def get_all_room_state(self) -> List[Optional[str]]:
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[Optional[str]]],
 | 
			
		||||
            await self.store.db_pool.simple_select_list(
 | 
			
		||||
                "room_stats_state", None, retcols=("topic",)
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return [r[0] for r in rows]
 | 
			
		||||
 | 
			
		||||
    def _get_current_stats(
 | 
			
		||||
        self, stats_type: str, stat_id: str
 | 
			
		||||
| 
						 | 
				
			
			@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 | 
			
		|||
        r = self.get_success(self.get_all_room_state())
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(r), 1)
 | 
			
		||||
        self.assertEqual(r[0]["topic"], "foo")
 | 
			
		||||
        self.assertEqual(r[0], "foo")
 | 
			
		||||
 | 
			
		||||
    def test_create_user(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
 | 
			
		|||
            if expected_row is not None:
 | 
			
		||||
                columns += expected_row.keys()
 | 
			
		||||
 | 
			
		||||
            rows = self.get_success(
 | 
			
		||||
            row_tuples = self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table=table,
 | 
			
		||||
                    keyvalues={
 | 
			
		||||
| 
						 | 
				
			
			@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
 | 
			
		|||
 | 
			
		||||
            if expected_row is not None:
 | 
			
		||||
                self.assertEqual(
 | 
			
		||||
                    len(rows),
 | 
			
		||||
                    len(row_tuples),
 | 
			
		||||
                    1,
 | 
			
		||||
                    f"Background update did not leave behind latest receipt in {table}",
 | 
			
		||||
                )
 | 
			
		||||
                self.assertEqual(
 | 
			
		||||
                    rows[0],
 | 
			
		||||
                    {
 | 
			
		||||
                        "room_id": room_id,
 | 
			
		||||
                        "receipt_type": receipt_type,
 | 
			
		||||
                        "user_id": user_id,
 | 
			
		||||
                        **expected_row,
 | 
			
		||||
                    },
 | 
			
		||||
                    row_tuples[0],
 | 
			
		||||
                    (
 | 
			
		||||
                        room_id,
 | 
			
		||||
                        receipt_type,
 | 
			
		||||
                        user_id,
 | 
			
		||||
                        *expected_row.values(),
 | 
			
		||||
                    ),
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                self.assertEqual(
 | 
			
		||||
                    len(rows),
 | 
			
		||||
                    len(row_tuples),
 | 
			
		||||
                    0,
 | 
			
		||||
                    f"Background update did not remove all duplicate receipts from {table}",
 | 
			
		||||
                )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,7 +14,7 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import secrets
 | 
			
		||||
from typing import Generator, Tuple
 | 
			
		||||
from typing import Generator, List, Tuple, cast
 | 
			
		||||
 | 
			
		||||
from twisted.test.proto_helpers import MemoryReactor
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
    def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
 | 
			
		||||
        res = self.get_success(
 | 
			
		||||
            self.storage.db_pool.simple_select_list(
 | 
			
		||||
                self.table_name, None, ["id, username, value"]
 | 
			
		||||
            )
 | 
			
		||||
        yield from cast(
 | 
			
		||||
            List[Tuple[int, str, str]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.storage.db_pool.simple_select_list(
 | 
			
		||||
                    self.table_name, None, ["id, username, value"]
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for i in res:
 | 
			
		||||
            yield (i["id"], i["username"], i["value"])
 | 
			
		||||
 | 
			
		||||
    def test_upsert_many(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Upsert_many will perform the upsert operation across a batch of data.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,6 +12,7 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import logging
 | 
			
		||||
from typing import List, Tuple, cast
 | 
			
		||||
from unittest.mock import AsyncMock, Mock
 | 
			
		||||
 | 
			
		||||
import yaml
 | 
			
		||||
| 
						 | 
				
			
			@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
            self.wait_for_background_updates()
 | 
			
		||||
 | 
			
		||||
        # Check the correct values are in the new table.
 | 
			
		||||
        rows = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                table="test_constraint",
 | 
			
		||||
                keyvalues={},
 | 
			
		||||
                retcols=("a", "b"),
 | 
			
		||||
            )
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[int, int]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table="test_constraint",
 | 
			
		||||
                    keyvalues={},
 | 
			
		||||
                    retcols=("a", "b"),
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
 | 
			
		||||
        self.assertCountEqual(rows, [(1, 1), (3, 3)])
 | 
			
		||||
 | 
			
		||||
        # And check that invalid rows get correctly rejected.
 | 
			
		||||
        self.get_failure(
 | 
			
		||||
| 
						 | 
				
			
			@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
            self.wait_for_background_updates()
 | 
			
		||||
 | 
			
		||||
        # Check the correct values are in the new table.
 | 
			
		||||
        rows = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                table="test_constraint",
 | 
			
		||||
                keyvalues={},
 | 
			
		||||
                retcols=("a", "b"),
 | 
			
		||||
            )
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[int, int]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table="test_constraint",
 | 
			
		||||
                    keyvalues={},
 | 
			
		||||
                    retcols=("a", "b"),
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
 | 
			
		||||
        self.assertCountEqual(rows, [(1, 1), (3, 3)])
 | 
			
		||||
 | 
			
		||||
        # And check that invalid rows get correctly rejected.
 | 
			
		||||
        self.get_failure(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
 | 
			
		|||
    @defer.inlineCallbacks
 | 
			
		||||
    def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
 | 
			
		||||
        self.mock_txn.rowcount = 3
 | 
			
		||||
        self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
 | 
			
		||||
        self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)]
 | 
			
		||||
        self.mock_txn.description = (("colA", None, None, None, None, None, None),)
 | 
			
		||||
 | 
			
		||||
        ret = yield defer.ensureDeferred(
 | 
			
		||||
| 
						 | 
				
			
			@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
 | 
			
		|||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
 | 
			
		||||
        self.assertEqual([(1,), (2,), (3,)], ret)
 | 
			
		||||
        self.mock_txn.execute.assert_called_with(
 | 
			
		||||
            "SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,7 +13,7 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from typing import Any, Dict
 | 
			
		||||
from typing import Any, Dict, List, Optional, Tuple, cast
 | 
			
		||||
from unittest.mock import AsyncMock
 | 
			
		||||
 | 
			
		||||
from parameterized import parameterized
 | 
			
		||||
| 
						 | 
				
			
			@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
        self.reactor.advance(200)
 | 
			
		||||
        self.pump(0)
 | 
			
		||||
 | 
			
		||||
        result = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                table="user_ips",
 | 
			
		||||
                keyvalues={"user_id": user_id},
 | 
			
		||||
                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
 | 
			
		||||
                desc="get_user_ip_and_agents",
 | 
			
		||||
            )
 | 
			
		||||
        result = cast(
 | 
			
		||||
            List[Tuple[str, str, str, Optional[str], int]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table="user_ips",
 | 
			
		||||
                    keyvalues={"user_id": user_id},
 | 
			
		||||
                    retcols=[
 | 
			
		||||
                        "access_token",
 | 
			
		||||
                        "ip",
 | 
			
		||||
                        "user_agent",
 | 
			
		||||
                        "device_id",
 | 
			
		||||
                        "last_seen",
 | 
			
		||||
                    ],
 | 
			
		||||
                    desc="get_user_ip_and_agents",
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            result,
 | 
			
		||||
            [
 | 
			
		||||
                {
 | 
			
		||||
                    "access_token": "access_token",
 | 
			
		||||
                    "ip": "ip",
 | 
			
		||||
                    "user_agent": "user_agent",
 | 
			
		||||
                    "device_id": None,
 | 
			
		||||
                    "last_seen": 12345678000,
 | 
			
		||||
                }
 | 
			
		||||
            ],
 | 
			
		||||
            result, [("access_token", "ip", "user_agent", None, 12345678000)]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Add another & trigger the storage loop
 | 
			
		||||
| 
						 | 
				
			
			@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
        self.reactor.advance(10)
 | 
			
		||||
        self.pump(0)
 | 
			
		||||
 | 
			
		||||
        result = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                table="user_ips",
 | 
			
		||||
                keyvalues={"user_id": user_id},
 | 
			
		||||
                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
 | 
			
		||||
                desc="get_user_ip_and_agents",
 | 
			
		||||
            )
 | 
			
		||||
        result = cast(
 | 
			
		||||
            List[Tuple[str, str, str, Optional[str], int]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table="user_ips",
 | 
			
		||||
                    keyvalues={"user_id": user_id},
 | 
			
		||||
                    retcols=[
 | 
			
		||||
                        "access_token",
 | 
			
		||||
                        "ip",
 | 
			
		||||
                        "user_agent",
 | 
			
		||||
                        "device_id",
 | 
			
		||||
                        "last_seen",
 | 
			
		||||
                    ],
 | 
			
		||||
                    desc="get_user_ip_and_agents",
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        # Only one result, has been upserted.
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            result,
 | 
			
		||||
            [
 | 
			
		||||
                {
 | 
			
		||||
                    "access_token": "access_token",
 | 
			
		||||
                    "ip": "ip",
 | 
			
		||||
                    "user_agent": "user_agent",
 | 
			
		||||
                    "device_id": None,
 | 
			
		||||
                    "last_seen": 12345878000,
 | 
			
		||||
                }
 | 
			
		||||
            ],
 | 
			
		||||
            result, [("access_token", "ip", "user_agent", None, 12345878000)]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @parameterized.expand([(False,), (True,)])
 | 
			
		||||
| 
						 | 
				
			
			@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
            self.reactor.advance(10)
 | 
			
		||||
        else:
 | 
			
		||||
            # Check that the new IP and user agent has not been stored yet
 | 
			
		||||
            db_result = self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table="devices",
 | 
			
		||||
                    keyvalues={},
 | 
			
		||||
                    retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
 | 
			
		||||
            db_result = cast(
 | 
			
		||||
                List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
 | 
			
		||||
                self.get_success(
 | 
			
		||||
                    self.store.db_pool.simple_select_list(
 | 
			
		||||
                        table="devices",
 | 
			
		||||
                        keyvalues={},
 | 
			
		||||
                        retcols=(
 | 
			
		||||
                            "user_id",
 | 
			
		||||
                            "ip",
 | 
			
		||||
                            "user_agent",
 | 
			
		||||
                            "device_id",
 | 
			
		||||
                            "last_seen",
 | 
			
		||||
                        ),
 | 
			
		||||
                    ),
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
                db_result,
 | 
			
		||||
                [
 | 
			
		||||
                    {
 | 
			
		||||
                        "user_id": user_id,
 | 
			
		||||
                        "device_id": device_id,
 | 
			
		||||
                        "ip": None,
 | 
			
		||||
                        "user_agent": None,
 | 
			
		||||
                        "last_seen": None,
 | 
			
		||||
                    },
 | 
			
		||||
                ],
 | 
			
		||||
            )
 | 
			
		||||
            self.assertEqual(db_result, [(user_id, None, None, device_id, None)])
 | 
			
		||||
 | 
			
		||||
        result = self.get_success(
 | 
			
		||||
            self.store.get_last_client_ip_by_device(user_id, device_id)
 | 
			
		||||
| 
						 | 
				
			
			@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
        # Check that the new IP and user agent has not been stored yet
 | 
			
		||||
        db_result = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                table="devices",
 | 
			
		||||
                keyvalues={},
 | 
			
		||||
                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
 | 
			
		||||
        db_result = cast(
 | 
			
		||||
            List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table="devices",
 | 
			
		||||
                    keyvalues={},
 | 
			
		||||
                    retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
 | 
			
		||||
                ),
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        self.assertCountEqual(
 | 
			
		||||
            db_result,
 | 
			
		||||
            [
 | 
			
		||||
                {
 | 
			
		||||
                    "user_id": user_id,
 | 
			
		||||
                    "device_id": device_id_1,
 | 
			
		||||
                    "ip": "ip_1",
 | 
			
		||||
                    "user_agent": "user_agent_1",
 | 
			
		||||
                    "last_seen": 12345678000,
 | 
			
		||||
                },
 | 
			
		||||
                {
 | 
			
		||||
                    "user_id": user_id,
 | 
			
		||||
                    "device_id": device_id_2,
 | 
			
		||||
                    "ip": "ip_2",
 | 
			
		||||
                    "user_agent": "user_agent_2",
 | 
			
		||||
                    "last_seen": 12345678000,
 | 
			
		||||
                },
 | 
			
		||||
                (user_id, "ip_1", "user_agent_1", device_id_1, 12345678000),
 | 
			
		||||
                (user_id, "ip_2", "user_agent_2", device_id_2, 12345678000),
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
        # Check that the new IP and user agent has not been stored yet
 | 
			
		||||
        db_result = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                table="user_ips",
 | 
			
		||||
                keyvalues={},
 | 
			
		||||
                retcols=("access_token", "ip", "user_agent", "last_seen"),
 | 
			
		||||
        db_result = cast(
 | 
			
		||||
            List[Tuple[str, str, str, int]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table="user_ips",
 | 
			
		||||
                    keyvalues={},
 | 
			
		||||
                    retcols=("access_token", "ip", "user_agent", "last_seen"),
 | 
			
		||||
                ),
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            db_result,
 | 
			
		||||
            [
 | 
			
		||||
                {
 | 
			
		||||
                    "access_token": "access_token",
 | 
			
		||||
                    "ip": "ip_1",
 | 
			
		||||
                    "user_agent": "user_agent_1",
 | 
			
		||||
                    "last_seen": 12345678000,
 | 
			
		||||
                },
 | 
			
		||||
                {
 | 
			
		||||
                    "access_token": "access_token",
 | 
			
		||||
                    "ip": "ip_2",
 | 
			
		||||
                    "user_agent": "user_agent_2",
 | 
			
		||||
                    "last_seen": 12345678000,
 | 
			
		||||
                },
 | 
			
		||||
                ("access_token", "ip_1", "user_agent_1", 12345678000),
 | 
			
		||||
                ("access_token", "ip_2", "user_agent_2", 12345678000),
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
        self.reactor.advance(200)
 | 
			
		||||
 | 
			
		||||
        # We should see that in the DB
 | 
			
		||||
        result = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                table="user_ips",
 | 
			
		||||
                keyvalues={"user_id": user_id},
 | 
			
		||||
                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
 | 
			
		||||
                desc="get_user_ip_and_agents",
 | 
			
		||||
            )
 | 
			
		||||
        result = cast(
 | 
			
		||||
            List[Tuple[str, str, str, Optional[str], int]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table="user_ips",
 | 
			
		||||
                    keyvalues={"user_id": user_id},
 | 
			
		||||
                    retcols=[
 | 
			
		||||
                        "access_token",
 | 
			
		||||
                        "ip",
 | 
			
		||||
                        "user_agent",
 | 
			
		||||
                        "device_id",
 | 
			
		||||
                        "last_seen",
 | 
			
		||||
                    ],
 | 
			
		||||
                    desc="get_user_ip_and_agents",
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            result,
 | 
			
		||||
            [
 | 
			
		||||
                {
 | 
			
		||||
                    "access_token": "access_token",
 | 
			
		||||
                    "ip": "ip",
 | 
			
		||||
                    "user_agent": "user_agent",
 | 
			
		||||
                    "device_id": device_id,
 | 
			
		||||
                    "last_seen": 0,
 | 
			
		||||
                }
 | 
			
		||||
            ],
 | 
			
		||||
            [("access_token", "ip", "user_agent", device_id, 0)],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Now advance by a couple of months
 | 
			
		||||
        self.reactor.advance(60 * 24 * 60 * 60)
 | 
			
		||||
 | 
			
		||||
        # We should get no results.
 | 
			
		||||
        result = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                table="user_ips",
 | 
			
		||||
                keyvalues={"user_id": user_id},
 | 
			
		||||
                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
 | 
			
		||||
                desc="get_user_ip_and_agents",
 | 
			
		||||
            )
 | 
			
		||||
        result = cast(
 | 
			
		||||
            List[Tuple[str, str, str, Optional[str], int]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table="user_ips",
 | 
			
		||||
                    keyvalues={"user_id": user_id},
 | 
			
		||||
                    retcols=[
 | 
			
		||||
                        "access_token",
 | 
			
		||||
                        "ip",
 | 
			
		||||
                        "user_agent",
 | 
			
		||||
                        "device_id",
 | 
			
		||||
                        "last_seen",
 | 
			
		||||
                    ],
 | 
			
		||||
                    desc="get_user_ip_and_agents",
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(result, [])
 | 
			
		||||
| 
						 | 
				
			
			@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
        self.reactor.advance(200)
 | 
			
		||||
 | 
			
		||||
        # We should see that in the DB
 | 
			
		||||
        result = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                table="user_ips",
 | 
			
		||||
                keyvalues={},
 | 
			
		||||
                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
 | 
			
		||||
                desc="get_user_ip_and_agents",
 | 
			
		||||
            )
 | 
			
		||||
        result = cast(
 | 
			
		||||
            List[Tuple[str, str, str, Optional[str], int]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table="user_ips",
 | 
			
		||||
                    keyvalues={},
 | 
			
		||||
                    retcols=[
 | 
			
		||||
                        "access_token",
 | 
			
		||||
                        "ip",
 | 
			
		||||
                        "user_agent",
 | 
			
		||||
                        "device_id",
 | 
			
		||||
                        "last_seen",
 | 
			
		||||
                    ],
 | 
			
		||||
                    desc="get_user_ip_and_agents",
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # ensure user1 is filtered out
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            result,
 | 
			
		||||
            [
 | 
			
		||||
                {
 | 
			
		||||
                    "access_token": access_token2,
 | 
			
		||||
                    "ip": "ip",
 | 
			
		||||
                    "user_agent": "user_agent",
 | 
			
		||||
                    "device_id": device_id2,
 | 
			
		||||
                    "last_seen": 0,
 | 
			
		||||
                }
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,6 +12,8 @@
 | 
			
		|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
from typing import List, Optional, Tuple, cast
 | 
			
		||||
 | 
			
		||||
from twisted.test.proto_helpers import MemoryReactor
 | 
			
		||||
 | 
			
		||||
from synapse.api.constants import Membership
 | 
			
		||||
| 
						 | 
				
			
			@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
    def test__null_byte_in_display_name_properly_handled(self) -> None:
 | 
			
		||||
        room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
 | 
			
		||||
 | 
			
		||||
        res = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                "room_memberships",
 | 
			
		||||
                {"user_id": "@alice:test"},
 | 
			
		||||
                ["display_name", "event_id"],
 | 
			
		||||
            )
 | 
			
		||||
        res = cast(
 | 
			
		||||
            List[Tuple[Optional[str], str]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    "room_memberships",
 | 
			
		||||
                    {"user_id": "@alice:test"},
 | 
			
		||||
                    ["display_name", "event_id"],
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        # Check that we only got one result back
 | 
			
		||||
        self.assertEqual(len(res), 1)
 | 
			
		||||
 | 
			
		||||
        # Check that alice's display name is "alice"
 | 
			
		||||
        self.assertEqual(res[0]["display_name"], "alice")
 | 
			
		||||
        self.assertEqual(res[0][0], "alice")
 | 
			
		||||
 | 
			
		||||
        # Grab the event_id to use later
 | 
			
		||||
        event_id = res[0]["event_id"]
 | 
			
		||||
        event_id = res[0][1]
 | 
			
		||||
 | 
			
		||||
        # Create a profile with the offending null byte in the display name
 | 
			
		||||
        new_profile = {"displayname": "ali\u0000ce"}
 | 
			
		||||
| 
						 | 
				
			
			@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
 | 
			
		|||
            tok=self.t_alice,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        res2 = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                "room_memberships",
 | 
			
		||||
                {"user_id": "@alice:test"},
 | 
			
		||||
                ["display_name", "event_id"],
 | 
			
		||||
            )
 | 
			
		||||
        res2 = cast(
 | 
			
		||||
            List[Tuple[Optional[str], str]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    "room_memberships",
 | 
			
		||||
                    {"user_id": "@alice:test"},
 | 
			
		||||
                    ["display_name", "event_id"],
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        # Check that we only have two results
 | 
			
		||||
        self.assertEqual(len(res2), 2)
 | 
			
		||||
 | 
			
		||||
        # Filter out the previous event using the event_id we grabbed above
 | 
			
		||||
        row = [row for row in res2 if row["event_id"] != event_id]
 | 
			
		||||
        row = [row for row in res2 if row[1] != event_id]
 | 
			
		||||
 | 
			
		||||
        # Check that alice's display name is now None
 | 
			
		||||
        self.assertEqual(row[0]["display_name"], None)
 | 
			
		||||
        self.assertIsNone(row[0][0])
 | 
			
		||||
 | 
			
		||||
    def test_room_is_locally_forgotten(self) -> None:
 | 
			
		||||
        """Test that when the last local user has forgotten a room it is known as forgotten."""
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,6 +13,7 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
from typing import List, Tuple, cast
 | 
			
		||||
 | 
			
		||||
from immutabledict import immutabledict
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
        # check that only state events are in state_groups, and all state events are in state_groups
 | 
			
		||||
        res = self.get_success(
 | 
			
		||||
            self.store.db_pool.simple_select_list(
 | 
			
		||||
                table="state_groups",
 | 
			
		||||
                keyvalues=None,
 | 
			
		||||
                retcols=("event_id",),
 | 
			
		||||
            )
 | 
			
		||||
        res = cast(
 | 
			
		||||
            List[Tuple[str]],
 | 
			
		||||
            self.get_success(
 | 
			
		||||
                self.store.db_pool.simple_select_list(
 | 
			
		||||
                    table="state_groups",
 | 
			
		||||
                    keyvalues=None,
 | 
			
		||||
                    retcols=("event_id",),
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        events = []
 | 
			
		||||
        for result in res:
 | 
			
		||||
            self.assertNotIn(event3.event_id, result)
 | 
			
		||||
            events.append(result.get("event_id"))
 | 
			
		||||
            self.assertNotIn(event3.event_id, result)  # XXX
 | 
			
		||||
            events.append(result[0])
 | 
			
		||||
 | 
			
		||||
        for event, _ in processed_events_and_context:
 | 
			
		||||
            if event.is_state():
 | 
			
		||||
| 
						 | 
				
			
			@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase):
 | 
			
		|||
        # has an entry and prev event in state_group_edges
 | 
			
		||||
        for event, context in processed_events_and_context:
 | 
			
		||||
            if event.is_state():
 | 
			
		||||
                state = self.get_success(
 | 
			
		||||
                    self.store.db_pool.simple_select_list(
 | 
			
		||||
                        table="state_groups_state",
 | 
			
		||||
                        keyvalues={"state_group": context.state_group_after_event},
 | 
			
		||||
                        retcols=("type", "state_key"),
 | 
			
		||||
                    )
 | 
			
		||||
                state = cast(
 | 
			
		||||
                    List[Tuple[str, str]],
 | 
			
		||||
                    self.get_success(
 | 
			
		||||
                        self.store.db_pool.simple_select_list(
 | 
			
		||||
                            table="state_groups_state",
 | 
			
		||||
                            keyvalues={"state_group": context.state_group_after_event},
 | 
			
		||||
                            retcols=("type", "state_key"),
 | 
			
		||||
                        )
 | 
			
		||||
                    ),
 | 
			
		||||
                )
 | 
			
		||||
                self.assertEqual(event.type, state[0].get("type"))
 | 
			
		||||
                self.assertEqual(event.state_key, state[0].get("state_key"))
 | 
			
		||||
                self.assertEqual(event.type, state[0][0])
 | 
			
		||||
                self.assertEqual(event.state_key, state[0][1])
 | 
			
		||||
 | 
			
		||||
                groups = self.get_success(
 | 
			
		||||
                    self.store.db_pool.simple_select_list(
 | 
			
		||||
                        table="state_group_edges",
 | 
			
		||||
                        keyvalues={"state_group": str(context.state_group_after_event)},
 | 
			
		||||
                        retcols=("*",),
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
                self.assertEqual(
 | 
			
		||||
                    context.state_group_before_event, groups[0].get("prev_state_group")
 | 
			
		||||
                groups = cast(
 | 
			
		||||
                    List[Tuple[str]],
 | 
			
		||||
                    self.get_success(
 | 
			
		||||
                        self.store.db_pool.simple_select_list(
 | 
			
		||||
                            table="state_group_edges",
 | 
			
		||||
                            keyvalues={
 | 
			
		||||
                                "state_group": str(context.state_group_after_event)
 | 
			
		||||
                            },
 | 
			
		||||
                            retcols=("prev_state_group",),
 | 
			
		||||
                        )
 | 
			
		||||
                    ),
 | 
			
		||||
                )
 | 
			
		||||
                self.assertEqual(context.state_group_before_event, groups[0][0])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,7 +12,7 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import re
 | 
			
		||||
from typing import Any, Dict, Set, Tuple
 | 
			
		||||
from typing import Any, Dict, List, Optional, Set, Tuple, cast
 | 
			
		||||
from unittest import mock
 | 
			
		||||
from unittest.mock import Mock, patch
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -62,14 +62,13 @@ class GetUserDirectoryTables:
 | 
			
		|||
        Returns a list of tuples (user_id, room_id) where room_id is public and
 | 
			
		||||
        contains the user with the given id.
 | 
			
		||||
        """
 | 
			
		||||
        r = await self.store.db_pool.simple_select_list(
 | 
			
		||||
            "users_in_public_rooms", None, ("user_id", "room_id")
 | 
			
		||||
        r = cast(
 | 
			
		||||
            List[Tuple[str, str]],
 | 
			
		||||
            await self.store.db_pool.simple_select_list(
 | 
			
		||||
                "users_in_public_rooms", None, ("user_id", "room_id")
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        retval = set()
 | 
			
		||||
        for i in r:
 | 
			
		||||
            retval.add((i["user_id"], i["room_id"]))
 | 
			
		||||
        return retval
 | 
			
		||||
        return set(r)
 | 
			
		||||
 | 
			
		||||
    async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
 | 
			
		||||
        """Fetch the entire `users_who_share_private_rooms` table.
 | 
			
		||||
| 
						 | 
				
			
			@ -78,27 +77,30 @@ class GetUserDirectoryTables:
 | 
			
		|||
        to the rows of `users_who_share_private_rooms`.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        rows = await self.store.db_pool.simple_select_list(
 | 
			
		||||
            "users_who_share_private_rooms",
 | 
			
		||||
            None,
 | 
			
		||||
            ["user_id", "other_user_id", "room_id"],
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, str, str]],
 | 
			
		||||
            await self.store.db_pool.simple_select_list(
 | 
			
		||||
                "users_who_share_private_rooms",
 | 
			
		||||
                None,
 | 
			
		||||
                ["user_id", "other_user_id", "room_id"],
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        rv = set()
 | 
			
		||||
        for row in rows:
 | 
			
		||||
            rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
 | 
			
		||||
        return rv
 | 
			
		||||
        return set(rows)
 | 
			
		||||
 | 
			
		||||
    async def get_users_in_user_directory(self) -> Set[str]:
 | 
			
		||||
        """Fetch the set of users in the `user_directory` table.
 | 
			
		||||
 | 
			
		||||
        This is useful when checking we've correctly excluded users from the directory.
 | 
			
		||||
        """
 | 
			
		||||
        result = await self.store.db_pool.simple_select_list(
 | 
			
		||||
            "user_directory",
 | 
			
		||||
            None,
 | 
			
		||||
            ["user_id"],
 | 
			
		||||
        result = cast(
 | 
			
		||||
            List[Tuple[str]],
 | 
			
		||||
            await self.store.db_pool.simple_select_list(
 | 
			
		||||
                "user_directory",
 | 
			
		||||
                None,
 | 
			
		||||
                ["user_id"],
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return {row["user_id"] for row in result}
 | 
			
		||||
        return {row[0] for row in result}
 | 
			
		||||
 | 
			
		||||
    async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]:
 | 
			
		||||
        """Fetch users and their profiles from the `user_directory` table.
 | 
			
		||||
| 
						 | 
				
			
			@ -107,16 +109,17 @@ class GetUserDirectoryTables:
 | 
			
		|||
        It's almost the entire contents of the `user_directory` table: the only
 | 
			
		||||
        thing missing is an unused room_id column.
 | 
			
		||||
        """
 | 
			
		||||
        rows = await self.store.db_pool.simple_select_list(
 | 
			
		||||
            "user_directory",
 | 
			
		||||
            None,
 | 
			
		||||
            ("user_id", "display_name", "avatar_url"),
 | 
			
		||||
        rows = cast(
 | 
			
		||||
            List[Tuple[str, Optional[str], Optional[str]]],
 | 
			
		||||
            await self.store.db_pool.simple_select_list(
 | 
			
		||||
                "user_directory",
 | 
			
		||||
                None,
 | 
			
		||||
                ("user_id", "display_name", "avatar_url"),
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return {
 | 
			
		||||
            row["user_id"]: ProfileInfo(
 | 
			
		||||
                display_name=row["display_name"], avatar_url=row["avatar_url"]
 | 
			
		||||
            )
 | 
			
		||||
            for row in rows
 | 
			
		||||
            user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url)
 | 
			
		||||
            for user_id, display_name, avatar_url in rows
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    async def get_tables(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue