Remove remaining usage of cursor_to_dict. (#16564)

pull/16589/head
Patrick Cloke 2023-10-31 13:13:28 -04:00 committed by GitHub
parent c0ba319b22
commit cfb6d38c47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 300 additions and 157 deletions

1
changelog.d/16564.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -283,7 +283,7 @@ class AdminHandler:
start, limit, user_id start, limit, user_id
) )
for media in media_ids: for media in media_ids:
writer.write_media_id(media["media_id"], media) writer.write_media_id(media.media_id, attr.asdict(media))
logger.info( logger.info(
"[%s] Written %d media_ids of %s", "[%s] Written %d media_ids of %s",

View File

@ -33,6 +33,7 @@ from synapse.api.errors import (
RequestSendFailed, RequestSendFailed,
SynapseError, SynapseError,
) )
from synapse.storage.databases.main.room import LargestRoomStats
from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID
from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -170,26 +171,24 @@ class RoomListHandler:
ignore_non_federatable=from_federation, ignore_non_federatable=from_federation,
) )
def build_room_entry(room: JsonDict) -> JsonDict: def build_room_entry(room: LargestRoomStats) -> JsonDict:
entry = { entry = {
"room_id": room["room_id"], "room_id": room.room_id,
"name": room["name"], "name": room.name,
"topic": room["topic"], "topic": room.topic,
"canonical_alias": room["canonical_alias"], "canonical_alias": room.canonical_alias,
"num_joined_members": room["joined_members"], "num_joined_members": room.joined_members,
"avatar_url": room["avatar"], "avatar_url": room.avatar,
"world_readable": room["history_visibility"] "world_readable": room.history_visibility
== HistoryVisibility.WORLD_READABLE, == HistoryVisibility.WORLD_READABLE,
"guest_can_join": room["guest_access"] == "can_join", "guest_can_join": room.guest_access == "can_join",
"join_rule": room["join_rules"], "join_rule": room.join_rules,
"room_type": room["room_type"], "room_type": room.room_type,
} }
# Filter out Nones rather omit the field altogether # Filter out Nones rather omit the field altogether
return {k: v for k, v in entry.items() if v is not None} return {k: v for k, v in entry.items() if v is not None}
results = [build_room_entry(r) for r in results]
response: JsonDict = {} response: JsonDict = {}
num_results = len(results) num_results = len(results)
if limit is not None: if limit is not None:
@ -212,33 +211,33 @@ class RoomListHandler:
# If there was a token given then we assume that there # If there was a token given then we assume that there
# must be previous results. # must be previous results.
response["prev_batch"] = RoomListNextBatch( response["prev_batch"] = RoomListNextBatch(
last_joined_members=initial_entry["num_joined_members"], last_joined_members=initial_entry.joined_members,
last_room_id=initial_entry["room_id"], last_room_id=initial_entry.room_id,
direction_is_forward=False, direction_is_forward=False,
).to_token() ).to_token()
if more_to_come: if more_to_come:
response["next_batch"] = RoomListNextBatch( response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"], last_joined_members=final_entry.joined_members,
last_room_id=final_entry["room_id"], last_room_id=final_entry.room_id,
direction_is_forward=True, direction_is_forward=True,
).to_token() ).to_token()
else: else:
if has_batch_token: if has_batch_token:
response["next_batch"] = RoomListNextBatch( response["next_batch"] = RoomListNextBatch(
last_joined_members=final_entry["num_joined_members"], last_joined_members=final_entry.joined_members,
last_room_id=final_entry["room_id"], last_room_id=final_entry.room_id,
direction_is_forward=True, direction_is_forward=True,
).to_token() ).to_token()
if more_to_come: if more_to_come:
response["prev_batch"] = RoomListNextBatch( response["prev_batch"] = RoomListNextBatch(
last_joined_members=initial_entry["num_joined_members"], last_joined_members=initial_entry.joined_members,
last_room_id=initial_entry["room_id"], last_room_id=initial_entry.room_id,
direction_is_forward=False, direction_is_forward=False,
).to_token() ).to_token()
response["chunk"] = results response["chunk"] = [build_room_entry(r) for r in results]
response["total_room_count_estimate"] = await self.store.count_public_rooms( response["total_room_count_estimate"] = await self.store.count_public_rooms(
network_tuple, network_tuple,

View File

@ -703,24 +703,24 @@ class RoomSummaryHandler:
# there should always be an entry # there should always be an entry
assert stats is not None, "unable to retrieve stats for %s" % (room_id,) assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
entry = { entry: JsonDict = {
"room_id": stats["room_id"], "room_id": stats.room_id,
"name": stats["name"], "name": stats.name,
"topic": stats["topic"], "topic": stats.topic,
"canonical_alias": stats["canonical_alias"], "canonical_alias": stats.canonical_alias,
"num_joined_members": stats["joined_members"], "num_joined_members": stats.joined_members,
"avatar_url": stats["avatar"], "avatar_url": stats.avatar,
"join_rule": stats["join_rules"], "join_rule": stats.join_rules,
"world_readable": ( "world_readable": (
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE stats.history_visibility == HistoryVisibility.WORLD_READABLE
), ),
"guest_can_join": stats["guest_access"] == "can_join", "guest_can_join": stats.guest_access == "can_join",
"room_type": stats["room_type"], "room_type": stats.room_type,
} }
if self._msc3266_enabled: if self._msc3266_enabled:
entry["im.nheko.summary.version"] = stats["version"] entry["im.nheko.summary.version"] = stats.version
entry["im.nheko.summary.encryption"] = stats["encryption"] entry["im.nheko.summary.encryption"] = stats.encryption
# Federation requests need to provide additional information so the # Federation requests need to provide additional information so the
# requested server is able to filter the response appropriately. # requested server is able to filter the response appropriately.

View File

@ -17,6 +17,8 @@ import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Optional, Tuple
import attr
from synapse.api.constants import Direction from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
@ -418,7 +420,7 @@ class UserMediaRestServlet(RestServlet):
start, limit, user_id, order_by, direction start, limit, user_id, order_by, direction
) )
ret = {"media": media, "total": total} ret = {"media": [attr.asdict(m) for m in media], "total": total}
if (start + limit) < total: if (start + limit) < total:
ret["next_token"] = start + len(media) ret["next_token"] = start + len(media)
@ -477,7 +479,7 @@ class UserMediaRestServlet(RestServlet):
) )
deleted_media, total = await self.media_repository.delete_local_media_ids( deleted_media, total = await self.media_repository.delete_local_media_ids(
[row["media_id"] for row in media] [m.media_id for m in media]
) )
return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}

View File

@ -77,7 +77,18 @@ class ListRegistrationTokensRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
valid = parse_boolean(request, "valid") valid = parse_boolean(request, "valid")
token_list = await self.store.get_registration_tokens(valid) token_list = await self.store.get_registration_tokens(valid)
return HTTPStatus.OK, {"registration_tokens": token_list} return HTTPStatus.OK, {
"registration_tokens": [
{
"token": t[0],
"uses_allowed": t[1],
"pending": t[2],
"completed": t[3],
"expiry_time": t[4],
}
for t in token_list
]
}
class NewRegistrationTokenRestServlet(RestServlet): class NewRegistrationTokenRestServlet(RestServlet):

View File

@ -16,6 +16,8 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, cast from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from urllib import parse as urlparse from urllib import parse as urlparse
import attr
from synapse.api.constants import Direction, EventTypes, JoinRules, Membership from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
@ -306,10 +308,13 @@ class RoomRestServlet(RestServlet):
raise NotFoundError("Room not found") raise NotFoundError("Room not found")
members = await self.store.get_users_in_room(room_id) members = await self.store.get_users_in_room(room_id)
ret["joined_local_devices"] = await self.store.count_devices_by_users(members) result = attr.asdict(ret)
ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id) result["joined_local_devices"] = await self.store.count_devices_by_users(
members
)
result["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
return HTTPStatus.OK, ret return HTTPStatus.OK, result
async def on_DELETE( async def on_DELETE(
self, request: SynapseRequest, room_id: str self, request: SynapseRequest, room_id: str

View File

@ -18,6 +18,8 @@ import secrets
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import attr
from synapse.api.constants import Direction, UserTypes from synapse.api.constants import Direction, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -161,11 +163,13 @@ class UsersRestServletV2(RestServlet):
) )
# If support for MSC3866 is not enabled, don't show the approval flag. # If support for MSC3866 is not enabled, don't show the approval flag.
filter = None
if not self._msc3866_enabled: if not self._msc3866_enabled:
for user in users:
del user["approved"]
ret = {"users": users, "total": total} def _filter(a: attr.Attribute) -> bool:
return a.name != "approved"
ret = {"users": [attr.asdict(u, filter=filter) for u in users], "total": total}
if (start + limit) < total: if (start + limit) < total:
ret["next_token"] = str(start + len(users)) ret["next_token"] = str(start + len(users))

View File

@ -28,6 +28,7 @@ from typing import (
Sequence, Sequence,
Tuple, Tuple,
Type, Type,
cast,
) )
import attr import attr
@ -488,14 +489,14 @@ class BackgroundUpdater:
True if we have finished running all the background updates, otherwise False True if we have finished running all the background updates, otherwise False
""" """
def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]: def get_background_updates_txn(txn: Cursor) -> List[Tuple[str, Optional[str]]]:
txn.execute( txn.execute(
""" """
SELECT update_name, depends_on FROM background_updates SELECT update_name, depends_on FROM background_updates
ORDER BY ordering, update_name ORDER BY ordering, update_name
""" """
) )
return self.db_pool.cursor_to_dict(txn) return cast(List[Tuple[str, Optional[str]]], txn.fetchall())
if not self._current_background_update: if not self._current_background_update:
all_pending_updates = await self.db_pool.runInteraction( all_pending_updates = await self.db_pool.runInteraction(
@ -507,14 +508,13 @@ class BackgroundUpdater:
return True return True
# find the first update which isn't dependent on another one in the queue. # find the first update which isn't dependent on another one in the queue.
pending = {update["update_name"] for update in all_pending_updates} pending = {update_name for update_name, depends_on in all_pending_updates}
for upd in all_pending_updates: for update_name, depends_on in all_pending_updates:
depends_on = upd["depends_on"]
if not depends_on or depends_on not in pending: if not depends_on or depends_on not in pending:
break break
logger.info( logger.info(
"Not starting on bg update %s until %s is done", "Not starting on bg update %s until %s is done",
upd["update_name"], update_name,
depends_on, depends_on,
) )
else: else:
@ -524,7 +524,7 @@ class BackgroundUpdater:
"another: dependency cycle?" "another: dependency cycle?"
) )
self._current_background_update = upd["update_name"] self._current_background_update = update_name
# We have a background update to run, otherwise we would have returned # We have a background update to run, otherwise we would have returned
# early. # early.

View File

@ -18,7 +18,6 @@ import logging
import time import time
import types import types
from collections import defaultdict from collections import defaultdict
from sys import intern
from time import monotonic as monotonic_time from time import monotonic as monotonic_time
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -1042,20 +1041,6 @@ class DatabasePool:
self._db_pool.runWithConnection(inner_func, *args, **kwargs) self._db_pool.runWithConnection(inner_func, *args, **kwargs)
) )
@staticmethod
def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
"""Converts a SQL cursor into an list of dicts.
Args:
cursor: The DBAPI cursor which has executed a query.
Returns:
A list of dicts where the key is the column header.
"""
assert cursor.description is not None, "cursor.description was None"
col_headers = [intern(str(column[0])) for column in cursor.description]
results = [dict(zip(col_headers, row)) for row in cursor]
return results
async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]: async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
"""Runs a single query for a result set. """Runs a single query for a result set.

View File

@ -17,6 +17,8 @@
import logging import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
import attr
from synapse.api.constants import Direction from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.storage._base import make_in_list_sql_clause from synapse.storage._base import make_in_list_sql_clause
@ -28,7 +30,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.types import JsonDict, get_domain_from_id from synapse.types import get_domain_from_id
from .account_data import AccountDataStore from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
@ -82,6 +84,25 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class UserPaginateResponse:
"""This is very similar to UserInfo, but not quite the same."""
name: str
user_type: Optional[str]
is_guest: bool
admin: bool
deactivated: bool
shadow_banned: bool
displayname: Optional[str]
avatar_url: Optional[str]
creation_ts: Optional[int]
approved: bool
erased: bool
last_seen_ts: int
locked: bool
class DataStore( class DataStore(
EventsBackgroundUpdatesStore, EventsBackgroundUpdatesStore,
ExperimentalFeaturesStore, ExperimentalFeaturesStore,
@ -156,7 +177,7 @@ class DataStore(
approved: bool = True, approved: bool = True,
not_user_types: Optional[List[str]] = None, not_user_types: Optional[List[str]] = None,
locked: bool = False, locked: bool = False,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[UserPaginateResponse], int]:
"""Function to retrieve a paginated list of users from """Function to retrieve a paginated list of users from
users list. This will return a json list of users and the users list. This will return a json list of users and the
total number of users matching the filter criteria. total number of users matching the filter criteria.
@ -182,7 +203,7 @@ class DataStore(
def get_users_paginate_txn( def get_users_paginate_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[UserPaginateResponse], int]:
filters = [] filters = []
args: list = [] args: list = []
@ -282,13 +303,24 @@ class DataStore(
""" """
args += [limit, start] args += [limit, start]
txn.execute(sql, args) txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn) users = [
UserPaginateResponse(
# some of those boolean values are returned as integers when we're on SQLite name=row[0],
columns_to_boolify = ["erased"] user_type=row[1],
for user in users: is_guest=bool(row[2]),
for column in columns_to_boolify: admin=bool(row[3]),
user[column] = bool(user[column]) deactivated=bool(row[4]),
shadow_banned=bool(row[5]),
displayname=row[6],
avatar_url=row[7],
creation_ts=row[8],
approved=bool(row[9]),
erased=bool(row[10]),
last_seen_ts=row[11],
locked=bool(row[12]),
)
for row in txn
]
return users, count return users, count

View File

@ -1620,7 +1620,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
# #
# For each duplicate, we delete all the existing rows and put one back. # For each duplicate, we delete all the existing rows and put one back.
KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
last_row = progress.get( last_row = progress.get(
"last_row", "last_row",
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
@ -1628,44 +1627,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
def _txn(txn: LoggingTransaction) -> int: def _txn(txn: LoggingTransaction) -> int:
clause, args = make_tuple_comparison_clause( clause, args = make_tuple_comparison_clause(
[(x, last_row[x]) for x in KEY_COLS] [
("stream_id", last_row["stream_id"]),
("destination", last_row["destination"]),
("user_id", last_row["user_id"]),
("device_id", last_row["device_id"]),
]
) )
sql = """ sql = f"""
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
FROM device_lists_outbound_pokes FROM device_lists_outbound_pokes
WHERE %s WHERE {clause}
GROUP BY %s GROUP BY stream_id, destination, user_id, device_id
HAVING count(*) > 1 HAVING count(*) > 1
ORDER BY %s ORDER BY stream_id, destination, user_id, device_id
LIMIT ? LIMIT ?
""" % ( """
clause, # WHERE
",".join(KEY_COLS), # GROUP BY
",".join(KEY_COLS), # ORDER BY
)
txn.execute(sql, args + [batch_size]) txn.execute(sql, args + [batch_size])
rows = self.db_pool.cursor_to_dict(txn) rows = txn.fetchall()
row = None stream_id, destination, user_id, device_id = None, None, None, None
for row in rows: for stream_id, destination, user_id, device_id, _ in rows:
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
"device_lists_outbound_pokes", "device_lists_outbound_pokes",
{x: row[x] for x in KEY_COLS}, {
"stream_id": stream_id,
"destination": destination,
"user_id": user_id,
"device_id": device_id,
},
) )
row["sent"] = False
self.db_pool.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
"device_lists_outbound_pokes", "device_lists_outbound_pokes",
row, {
"stream_id": stream_id,
"destination": destination,
"user_id": user_id,
"device_id": device_id,
"sent": False,
},
) )
if row: if rows:
self.db_pool.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, txn,
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
{"last_row": row}, {
"last_row": {
"stream_id": stream_id,
"destination": destination,
"user_id": user_id,
"device_id": device_id,
}
},
) )
return len(rows) return len(rows)

View File

@ -26,6 +26,8 @@ from typing import (
cast, cast,
) )
import attr
from synapse.api.constants import Direction from synapse.api.constants import Direction
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace
from synapse.media._base import ThumbnailInfo from synapse.media._base import ThumbnailInfo
@ -45,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
) )
@attr.s(slots=True, frozen=True, auto_attribs=True)
class LocalMedia:
media_id: str
media_type: str
media_length: int
upload_name: str
created_ts: int
last_access_ts: int
quarantined_by: Optional[str]
safe_from_quarantine: bool
class MediaSortOrder(Enum): class MediaSortOrder(Enum):
""" """
Enum to define the sorting method used when returning media with Enum to define the sorting method used when returning media with
@ -180,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
user_id: str, user_id: str,
order_by: str = MediaSortOrder.CREATED_TS.value, order_by: str = MediaSortOrder.CREATED_TS.value,
direction: Direction = Direction.FORWARDS, direction: Direction = Direction.FORWARDS,
) -> Tuple[List[Dict[str, Any]], int]: ) -> Tuple[List[LocalMedia], int]:
"""Get a paginated list of metadata for a local piece of media """Get a paginated list of metadata for a local piece of media
which an user_id has uploaded which an user_id has uploaded
@ -197,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn( def get_local_media_by_user_paginate_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], int]: ) -> Tuple[List[LocalMedia], int]:
# Set ordering # Set ordering
order_by_column = MediaSortOrder(order_by).value order_by_column = MediaSortOrder(order_by).value
@ -217,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = """ sql = """
SELECT SELECT
"media_id", media_id,
"media_type", media_type,
"media_length", media_length,
"upload_name", upload_name,
"created_ts", created_ts,
"last_access_ts", last_access_ts,
"quarantined_by", quarantined_by,
"safe_from_quarantine" safe_from_quarantine
FROM local_media_repository FROM local_media_repository
WHERE user_id = ? WHERE user_id = ?
ORDER BY {order_by_column} {order}, media_id ASC ORDER BY {order_by_column} {order}, media_id ASC
@ -236,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
args += [limit, start] args += [limit, start]
txn.execute(sql, args) txn.execute(sql, args)
media = self.db_pool.cursor_to_dict(txn) media = [
LocalMedia(
media_id=row[0],
media_type=row[1],
media_length=row[2],
upload_name=row[3],
created_ts=row[4],
last_access_ts=row[5],
quarantined_by=row[6],
safe_from_quarantine=bool(row[7]),
)
for row in txn
]
return media, count return media, count
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(

View File

@ -1517,7 +1517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_registration_tokens( async def get_registration_tokens(
self, valid: Optional[bool] = None self, valid: Optional[bool] = None
) -> List[Dict[str, Any]]: ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
"""List all registration tokens. Used by the admin API. """List all registration tokens. Used by the admin API.
Args: Args:
@ -1526,34 +1526,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Default is None: return all tokens regardless of validity. Default is None: return all tokens regardless of validity.
Returns: Returns:
A list of dicts, each containing details of a token. A list of tuples containing:
* The token
* The number of users allowed (or None)
* Whether it is pending
* Whether it has been completed
* An expiry time (or None if no expiry)
""" """
def select_registration_tokens_txn( def select_registration_tokens_txn(
txn: LoggingTransaction, now: int, valid: Optional[bool] txn: LoggingTransaction, now: int, valid: Optional[bool]
) -> List[Dict[str, Any]]: ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
if valid is None: if valid is None:
# Return all tokens regardless of validity # Return all tokens regardless of validity
txn.execute("SELECT * FROM registration_tokens") txn.execute(
"""
SELECT token, uses_allowed, pending, completed, expiry_time
FROM registration_tokens
"""
)
elif valid: elif valid:
# Select valid tokens only # Select valid tokens only
sql = ( sql = """
"SELECT * FROM registration_tokens WHERE " SELECT token, uses_allowed, pending, completed, expiry_time
"(uses_allowed > pending + completed OR uses_allowed IS NULL) " FROM registration_tokens
"AND (expiry_time > ? OR expiry_time IS NULL)" WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL)
) AND (expiry_time > ? OR expiry_time IS NULL)
"""
txn.execute(sql, [now]) txn.execute(sql, [now])
else: else:
# Select invalid tokens only # Select invalid tokens only
sql = ( sql = """
"SELECT * FROM registration_tokens WHERE " SELECT token, uses_allowed, pending, completed, expiry_time
"uses_allowed <= pending + completed OR expiry_time <= ?" FROM registration_tokens
) WHERE uses_allowed <= pending + completed OR expiry_time <= ?
"""
txn.execute(sql, [now]) txn.execute(sql, [now])
return self.db_pool.cursor_to_dict(txn) return cast(
List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall()
)
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"select_registration_tokens", "select_registration_tokens",

View File

@ -78,6 +78,31 @@ class RatelimitOverride:
burst_count: int burst_count: int
@attr.s(slots=True, frozen=True, auto_attribs=True)
class LargestRoomStats:
room_id: str
name: Optional[str]
canonical_alias: Optional[str]
joined_members: int
join_rules: Optional[str]
guest_access: Optional[str]
history_visibility: Optional[str]
state_events: int
avatar: Optional[str]
topic: Optional[str]
room_type: Optional[str]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomStats(LargestRoomStats):
joined_local_members: int
version: Optional[str]
creator: Optional[str]
encryption: Optional[str]
federatable: bool
public: bool
class RoomSortOrder(Enum): class RoomSortOrder(Enum):
""" """
Enum to define the sorting method used when returning rooms with get_rooms_paginate Enum to define the sorting method used when returning rooms with get_rooms_paginate
@ -204,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
allow_none=True, allow_none=True,
) )
async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
"""Retrieve room with statistics. """Retrieve room with statistics.
Args: Args:
@ -215,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_room_with_stats_txn( def get_room_with_stats_txn(
txn: LoggingTransaction, room_id: str txn: LoggingTransaction, room_id: str
) -> Optional[Dict[str, Any]]: ) -> Optional[RoomStats]:
sql = """ sql = """
SELECT room_id, state.name, state.canonical_alias, curr.joined_members, SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@ -229,15 +254,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
WHERE room_id = ? WHERE room_id = ?
""" """
txn.execute(sql, [room_id]) txn.execute(sql, [room_id])
# Catch error if sql returns empty result to return "None" instead of an error row = txn.fetchone()
try: if not row:
res = self.db_pool.cursor_to_dict(txn)[0]
except IndexError:
return None return None
return RoomStats(
res["federatable"] = bool(res["federatable"]) room_id=row[0],
res["public"] = bool(res["public"]) name=row[1],
return res canonical_alias=row[2],
joined_members=row[3],
joined_local_members=row[4],
version=row[5],
creator=row[6],
encryption=row[7],
federatable=bool(row[8]),
public=bool(row[9]),
join_rules=row[10],
guest_access=row[11],
history_visibility=row[12],
state_events=row[13],
avatar=row[14],
topic=row[15],
room_type=row[16],
)
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id "get_room_with_stats", get_room_with_stats_txn, room_id
@ -368,7 +406,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
bounds: Optional[Tuple[int, str]], bounds: Optional[Tuple[int, str]],
forwards: bool, forwards: bool,
ignore_non_federatable: bool = False, ignore_non_federatable: bool = False,
) -> List[Dict[str, Any]]: ) -> List[LargestRoomStats]:
"""Gets the largest public rooms (where largest is in terms of joined """Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table). members, as tracked in the statistics table).
@ -505,20 +543,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _get_largest_public_rooms_txn( def _get_largest_public_rooms_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Dict[str, Any]]: ) -> List[LargestRoomStats]:
txn.execute(sql, query_args) txn.execute(sql, query_args)
results = self.db_pool.cursor_to_dict(txn) results = [
LargestRoomStats(
room_id=r[0],
name=r[1],
canonical_alias=r[3],
joined_members=r[4],
join_rules=r[8],
guest_access=r[7],
history_visibility=r[6],
state_events=0,
avatar=r[5],
topic=r[2],
room_type=r[9],
)
for r in txn
]
if not forwards: if not forwards:
results.reverse() results.reverse()
return results return results
ret_val = await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn "get_largest_public_rooms", _get_largest_public_rooms_txn
) )
return ret_val
@cached(max_entries=10000) @cached(max_entries=10000)
async def is_room_blocked(self, room_id: str) -> Optional[bool]: async def is_room_blocked(self, room_id: str) -> Optional[bool]:

View File

@ -342,10 +342,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly not federated. # Ensure the room is properly not federated.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None assert room is not None
self.assertFalse(room["federatable"]) self.assertFalse(room.federatable)
self.assertFalse(room["public"]) self.assertFalse(room.public)
self.assertEqual(room["join_rules"], "public") self.assertEqual(room.join_rules, "public")
self.assertIsNone(room["guest_access"]) self.assertIsNone(room.guest_access)
# The user should be in the room. # The user should be in the room.
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@ -372,7 +372,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a public room. # Ensure the room is properly a public room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None assert room is not None
self.assertEqual(room["join_rules"], "public") self.assertEqual(room.join_rules, "public")
# Both users should be in the room. # Both users should be in the room.
rooms = self.get_success(self.store.get_rooms_for_user(inviter)) rooms = self.get_success(self.store.get_rooms_for_user(inviter))
@ -411,9 +411,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room. # Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None assert room is not None
self.assertFalse(room["public"]) self.assertFalse(room.public)
self.assertEqual(room["join_rules"], "invite") self.assertEqual(room.join_rules, "invite")
self.assertEqual(room["guest_access"], "can_join") self.assertEqual(room.guest_access, "can_join")
# Both users should be in the room. # Both users should be in the room.
rooms = self.get_success(self.store.get_rooms_for_user(inviter)) rooms = self.get_success(self.store.get_rooms_for_user(inviter))
@ -455,9 +455,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room. # Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None assert room is not None
self.assertFalse(room["public"]) self.assertFalse(room.public)
self.assertEqual(room["join_rules"], "invite") self.assertEqual(room.join_rules, "invite")
self.assertEqual(room["guest_access"], "can_join") self.assertEqual(room.guest_access, "can_join")
# Both users should be in the room. # Both users should be in the room.
rooms = self.get_success(self.store.get_rooms_for_user(inviter)) rooms = self.get_success(self.store.get_rooms_for_user(inviter))

View File

@ -39,11 +39,11 @@ class DataStoreTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(1, total) self.assertEqual(1, total)
self.assertEqual(self.displayname, users.pop()["displayname"]) self.assertEqual(self.displayname, users.pop().displayname)
users, total = self.get_success( users, total = self.get_success(
self.store.get_users_paginate(0, 10, name="BC", guests=False) self.store.get_users_paginate(0, 10, name="BC", guests=False)
) )
self.assertEqual(1, total) self.assertEqual(1, total)
self.assertEqual(self.displayname, users.pop()["displayname"]) self.assertEqual(self.displayname, users.pop().displayname)

View File

@ -59,14 +59,9 @@ class RoomStoreTestCase(HomeserverTestCase):
def test_get_room_with_stats(self) -> None: def test_get_room_with_stats(self) -> None:
res = self.get_success(self.store.get_room_with_stats(self.room.to_string())) res = self.get_success(self.store.get_room_with_stats(self.room.to_string()))
assert res is not None assert res is not None
self.assertLessEqual( self.assertEqual(res.room_id, self.room.to_string())
{ self.assertEqual(res.creator, self.u_creator.to_string())
"room_id": self.room.to_string(), self.assertTrue(res.public)
"creator": self.u_creator.to_string(),
"public": True,
}.items(),
res.items(),
)
def test_get_room_with_stats_unknown_room(self) -> None: def test_get_room_with_stats_unknown_room(self) -> None:
self.assertIsNone( self.assertIsNone(