Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

matrix-org-hotfixes
Patrick Cloke 2023-11-09 11:14:57 -05:00
commit 8c2d3d0b4c
48 changed files with 677 additions and 414 deletions

4
Cargo.lock generated
View File

@ -352,9 +352,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.107"
version = "1.0.108"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65"
checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b"
dependencies = [
"itoa",
"ryu",

1
changelog.d/16611.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

1
changelog.d/16612.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -0,0 +1 @@
Improve the performance of claiming encryption keys in multi-worker deployments.

4
poetry.lock generated
View File

@ -2012,12 +2012,12 @@ plugins = ["importlib-metadata"]
[[package]]
name = "pyicu"
version = "2.11"
version = "2.12"
description = "Python extension wrapping the ICU C++ API"
optional = true
python-versions = "*"
files = [
{file = "PyICU-2.11.tar.gz", hash = "sha256:3ab531264cfe9132b3d2ac5d708da9a4649d25f6e6813730ac88cf040a08a844"},
{file = "PyICU-2.12.tar.gz", hash = "sha256:bd7ab5efa93ad692e6daa29cd249364e521218329221726a113ca3cb281c8611"},
]
[[package]]

View File

@ -348,8 +348,7 @@ class Porter:
backward_chunk = 0
already_ported = 0
else:
forward_chunk = row["forward_rowid"]
backward_chunk = row["backward_rowid"]
forward_chunk, backward_chunk = row
if total_to_port is None:
already_ported, total_to_port = await self._get_total_count_to_port(

View File

@ -13,7 +13,7 @@
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
from synapse.api.errors import (
AuthError,
@ -23,6 +23,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
@ -306,7 +307,9 @@ class ProfileHandler:
server_name = host
if self._is_mine_server_name(server_name):
media_info = await self.store.get_local_media(media_id)
media_info: Optional[
Union[LocalMedia, RemoteMedia]
] = await self.store.get_local_media(media_id)
else:
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@ -322,12 +325,12 @@ class ProfileHandler:
if self.max_avatar_size:
# Ensure avatar does not exceed max allowed avatar size
if media_info["media_length"] > self.max_avatar_size:
if media_info.media_length > self.max_avatar_size:
logger.warning(
"Forbidding avatar change to %s: %d bytes is above the allowed size "
"limit",
mxc,
media_info["media_length"],
media_info.media_length,
)
return False
@ -335,12 +338,12 @@ class ProfileHandler:
# Ensure the avatar's file type is allowed
if (
self.allowed_avatar_mimetypes
and media_info["media_type"] not in self.allowed_avatar_mimetypes
and media_info.media_type not in self.allowed_avatar_mimetypes
):
logger.warning(
"Forbidding avatar change to %s: mimetype %s not allowed",
mxc,
media_info["media_type"],
media_info.media_type,
)
return False

View File

@ -269,7 +269,7 @@ class RoomCreationHandler:
self,
requester: Requester,
old_room_id: str,
old_room: Dict[str, Any],
old_room: Tuple[bool, str, bool],
new_room_id: str,
new_version: RoomVersion,
tombstone_event: EventBase,
@ -279,7 +279,7 @@ class RoomCreationHandler:
Args:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
old_room: a dict containing room information for the room to be replaced,
old_room: a tuple containing room information for the room to be replaced,
as returned by `RoomWorkerStore.get_room`.
new_room_id: the id of the replacement room
new_version: the version to upgrade the room to
@ -299,7 +299,7 @@ class RoomCreationHandler:
await self.store.store_room(
room_id=new_room_id,
room_creator_user_id=user_id,
is_public=old_room["is_public"],
is_public=old_room[0],
room_version=new_version,
)

View File

@ -1274,7 +1274,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
old_room = await self.store.get_room(old_room_id)
if old_room is not None and old_room["is_public"]:
# If the old room exists and is public.
if old_room is not None and old_room[0]:
await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True)

View File

@ -806,7 +806,7 @@ class SsoHandler:
media_id = profile["avatar_url"].split("/")[-1]
if self._is_mine_server_name(server_name):
media = await self._media_repo.store.get_local_media(media_id)
if media is not None and upload_name == media["upload_name"]:
if media is not None and upload_name == media.upload_name:
logger.info("skipping saving the user avatar")
return True

View File

@ -19,6 +19,7 @@ import shutil
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import attr
from matrix_common.types.mxc_uri import MXCUri
import twisted.internet.error
@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.media_repository import RemoteMedia
from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
@ -245,18 +247,18 @@ class MediaRepository:
Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
if not media_info or media_info.quarantined_by:
respond_404(request)
return
self.mark_recently_accessed(None, media_id)
media_type = media_info["media_type"]
media_type = media_info.media_type
if not media_type:
media_type = "application/octet-stream"
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
media_length = media_info.media_length
upload_name = name if name else media_info.upload_name
url_cache = media_info.url_cache
file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
@ -310,16 +312,20 @@ class MediaRepository:
# We deliberately stream the file outside the lock
if responder:
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
upload_name = name if name else media_info.upload_name
await respond_with_responder(
request, responder, media_type, media_length, upload_name
request,
responder,
media_info.media_type,
media_info.media_length,
upload_name,
)
else:
respond_404(request)
async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
async def get_remote_media_info(
self, server_name: str, media_id: str
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
@ -353,7 +359,7 @@ class MediaRepository:
async def _get_remote_media_impl(
self, server_name: str, media_id: str
) -> Tuple[Optional[Responder], dict]:
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@ -373,15 +379,17 @@ class MediaRepository:
# If we have an entry in the DB, try and look for it
if media_info:
file_id = media_info["filesystem_id"]
file_id = media_info.filesystem_id
file_info = FileInfo(server_name, file_id)
if media_info["quarantined_by"]:
if media_info.quarantined_by:
logger.info("Media is quarantined")
raise NotFoundError()
if not media_info["media_type"]:
media_info["media_type"] = "application/octet-stream"
if not media_info.media_type:
media_info = attr.evolve(
media_info, media_type="application/octet-stream"
)
responder = await self.media_storage.fetch_media(file_info)
if responder:
@ -403,9 +411,9 @@ class MediaRepository:
if not media_info:
raise e
file_id = media_info["filesystem_id"]
if not media_info["media_type"]:
media_info["media_type"] = "application/octet-stream"
file_id = media_info.filesystem_id
if not media_info.media_type:
media_info = attr.evolve(media_info, media_type="application/octet-stream")
file_info = FileInfo(server_name, file_id)
# We generate thumbnails even if another process downloaded the media
@ -415,7 +423,7 @@ class MediaRepository:
# otherwise they'll request thumbnails and get a 404 if they're not
# ready yet.
await self._generate_thumbnails(
server_name, media_id, file_id, media_info["media_type"]
server_name, media_id, file_id, media_info.media_type
)
responder = await self.media_storage.fetch_media(file_info)
@ -425,7 +433,7 @@ class MediaRepository:
self,
server_name: str,
media_id: str,
) -> dict:
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@ -518,7 +526,7 @@ class MediaRepository:
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
@ -526,15 +534,17 @@ class MediaRepository:
logger.info("Stored remote media in file %r", fname)
media_info = {
"media_type": media_type,
"media_length": length,
"upload_name": upload_name,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}
return media_info
return RemoteMedia(
media_origin=server_name,
media_id=media_id,
media_type=media_type,
media_length=length,
upload_name=upload_name,
created_ts=time_now_ms,
filesystem_id=file_id,
last_access_ts=time_now_ms,
quarantined_by=None,
)
def _get_thumbnail_requirements(
self, media_type: str

View File

@ -240,15 +240,14 @@ class UrlPreviewer:
cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
and cache_result["expires_ts"] > ts
and cache_result["response_code"] / 100 == 2
and cache_result.expires_ts > ts
and cache_result.response_code // 100 == 2
):
# It may be stored as text in the database, not as bytes (such as
# PostgreSQL). If so, encode it back before handing it on.
og = cache_result["og"]
if isinstance(og, str):
og = og.encode("utf8")
return og
if isinstance(cache_result.og, str):
return cache_result.og.encode("utf8")
return cache_result.og
# If this URL can be accessed via an allowed oEmbed, use that instead.
url_to_download = url

View File

@ -1860,7 +1860,8 @@ class PublicRoomListManager:
if not room:
return False
return room.get("is_public", False)
# The first item is whether the room is public.
return room[0]
async def add_room_to_public_room_list(self, room_id: str) -> None:
"""Publishes a room to the public room list.

View File

@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")
members = await self.store.get_users_in_room(room_id)
@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)

View File

@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet):
if room is None:
raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"}
return 200, {"visibility": "public" if room[0] else "private"}
class PutBody(RequestBodyModel):
visibility: Literal["public", "private"] = "public"

View File

@ -119,7 +119,7 @@ class ThumbnailResource(RestServlet):
if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return
@ -134,7 +134,7 @@ class ThumbnailResource(RestServlet):
thumbnail_infos,
media_id,
media_id,
url_cache=bool(media_info["url_cache"]),
url_cache=bool(media_info.url_cache),
server_name=None,
)
@ -152,7 +152,7 @@ class ThumbnailResource(RestServlet):
if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return
@ -168,7 +168,7 @@ class ThumbnailResource(RestServlet):
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=media_info["url_cache"],
url_cache=bool(media_info.url_cache),
thumbnail=info,
)
@ -188,7 +188,7 @@ class ThumbnailResource(RestServlet):
desired_height,
desired_method,
desired_type,
url_cache=bool(media_info["url_cache"]),
url_cache=bool(media_info.url_cache),
)
if file_path:
@ -213,7 +213,7 @@ class ThumbnailResource(RestServlet):
server_name, media_id
)
file_id = media_info["filesystem_id"]
file_id = media_info.filesystem_id
for info in thumbnail_infos:
t_w = info.width == desired_width
@ -224,7 +224,7 @@ class ThumbnailResource(RestServlet):
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=server_name,
file_id=media_info["filesystem_id"],
file_id=file_id,
thumbnail=info,
)
@ -280,7 +280,7 @@ class ThumbnailResource(RestServlet):
m_type,
thumbnail_infos,
media_id,
media_info["filesystem_id"],
media_info.filesystem_id,
url_cache=False,
server_name=server_name,
)

View File

@ -1116,7 +1116,7 @@ class DatabasePool:
def simple_insert_many_txn(
txn: LoggingTransaction,
table: str,
keys: Collection[str],
keys: Sequence[str],
values: Collection[Iterable[Any]],
) -> None:
"""Executes an INSERT query on the named table.
@ -1597,7 +1597,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
) -> Tuple[Any, ...]:
...
@overload
@ -1608,7 +1608,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
...
async def simple_select_one(
@ -1618,7 +1618,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@ -2127,7 +2127,7 @@ class DatabasePool:
keyvalues: Dict[str, Any],
retcols: Collection[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
if keyvalues:
@ -2145,7 +2145,7 @@ class DatabasePool:
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
return dict(zip(retcols, row))
return row
async def simple_delete_one(
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"

View File

@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_cache_and_stream_bulk(
self,
txn: LoggingTransaction,
cache_func: CachedFunction,
key_tuples: Collection[Tuple[Any, ...]],
) -> None:
"""A bulk version of _invalidate_cache_and_stream.
Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
for each key-tuple over replication.
This implementation is more efficient than a loop which repeatedly calls the
non-bulk version.
"""
if not key_tuples:
return
for keys in key_tuples:
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication_bulk(
txn, cache_func.__name__, key_tuples
)
def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None
# get_next() returns a context manager which is designed to wrap
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
},
)
def _send_invalidation_to_replication_bulk(
self,
txn: LoggingTransaction,
cache_name: str,
key_tuples: Collection[Tuple[Any, ...]],
) -> None:
"""Announce the invalidation of multiple (but not all) cache entries.
This is more efficient than repeated calls to the non-bulk version. It should
NOT be used to invalidating the entire cache: use
`_send_invalidation_to_replication` with keys=None.
Note that this does *not* invalidate the cache locally.
Args:
txn
cache_name
key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
"""
if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None
stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
ts = self._clock.time_msec()
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self.db_pool.simple_insert_many_txn(
txn,
table="cache_invalidation_stream_by_instance",
keys=(
"stream_id",
"instance_name",
"cache_func",
"keys",
"invalidation_ts",
),
values=[
# We convert key_tuples to a list here because psycopg2 serialises
# lists as pq arrrays, but serialises tuples as "composite types".
# (We need an array because the `keys` column has type `[]text`.)
# See:
# https://www.psycopg.org/docs/usage.html#adapt-list
# https://www.psycopg.org/docs/usage.html#adapt-tuple
(stream_id, self._instance_name, cache_name, list(key_tuple), ts)
for stream_id, key_tuple in zip(stream_ids, key_tuples)
],
)
def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen:
return self._cache_id_gen.get_current_token_for_writer(instance_name)

View File

@ -257,33 +257,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
A dict containing the device information, or `None` if the device does not
exist.
"""
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)
async def get_device_opt(
self, user_id: str, device_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
A dict containing the device information, or None if the device does not exist.
"""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)
if row is None:
return None
return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}
async def get_devices_by_user(
self, user_id: str
@ -1223,9 +1206,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
retcols=["device_id", "device_data"],
allow_none=True,
)
return (
(row["device_id"], json_decoder.decode(row["device_data"])) if row else None
)
return (row[0], json_decoder.decode(row[1])) if row else None
def _store_dehydrated_device_txn(
self,
@ -2328,13 +2309,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
`FALSE` have not been converted.
"""
row = await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos",
return cast(
Tuple[int, str],
await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos",
),
)
return row["stream_id"], row["room_id"]
async def set_device_change_last_converted_pos(
self,

View File

@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
# it isn't there.
raise StoreError(404, "No backup with that version exists")
result = self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False,
row = cast(
Tuple[int, str, str, Optional[int]],
self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={
"user_id": user_id,
"version": this_version,
"deleted": 0,
},
retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False,
),
)
assert result is not None # see comment on `simple_select_one_txn`
result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
return result
return {
"auth_data": db_to_json(row[2]),
"version": str(row[0]),
"algorithm": row[1],
"etag": 0 if row[3] is None else row[3],
}
return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn

View File

@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)
self._invalidate_cache_and_stream_bulk(
txn, self.get_e2e_unused_fallback_key_types, seen_user_device
)
return results
@ -1268,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if row is None:
continue
key_id = row["key_id"]
key_json = row["key_json"]
used = row["used"]
key_id, key_json, used = row
# Mark fallback key as used if not already.
if not used and mark_as_used:
@ -1376,14 +1372,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
)
seen_user_device: Set[Tuple[str, str]] = set()
for user_id, device_id, _, _, _ in otk_rows:
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
seen_user_device = {
(user_id, device_id) for user_id, device_id, _, _, _ in otk_rows
}
self._invalidate_cache_and_stream_bulk(
txn,
self.count_e2e_one_time_keys,
seen_user_device,
)
return otk_rows

View File

@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains",
@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_difference_chains",
@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
if event_lookup_result is not None:
event_type, depth, stream_ordering = event_lookup_result
logger.debug(
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
room_id,
seed_event_id,
event_lookup_result["depth"],
event_lookup_result["stream_ordering"],
event_lookup_result["type"],
depth,
stream_ordering,
event_type,
)
if event_lookup_result["depth"]:
queue.put(
(
-event_lookup_result["depth"],
-event_lookup_result["stream_ordering"],
seed_event_id,
event_lookup_result["type"],
)
)
if depth:
queue.put((-depth, -stream_ordering, seed_event_id, event_type))
while not queue.empty() and len(event_id_results) < limit:
try:

View File

@ -1934,8 +1934,7 @@ class PersistEventsStore:
if row is None:
return
redacted_relates_to = row["relates_to_id"]
rel_type = row["relation_type"]
redacted_relates_to, rel_type = row
self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)

View File

@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore):
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
return int(res["topological_ordering"]), int(res["stream_ordering"])
return int(res[0]), int(res[1])
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry

View File

@ -15,9 +15,7 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Optional,
@ -54,11 +52,32 @@ class LocalMedia:
media_length: int
upload_name: str
created_ts: int
url_cache: Optional[str]
last_access_ts: int
quarantined_by: Optional[str]
safe_from_quarantine: bool
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RemoteMedia:
media_origin: str
media_id: str
media_type: str
media_length: int
upload_name: Optional[str]
filesystem_id: str
created_ts: int
last_access_ts: int
quarantined_by: Optional[str]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class UrlCache:
response_code: int
expires_ts: int
og: Union[str, bytes]
class MediaSortOrder(Enum):
"""
Enum to define the sorting method used when returning media with
@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
super().__init__(database, db_conn, hs)
self.server_name: str = hs.hostname
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
"""Get the metadata for a local piece of media
Returns:
None if the media_id doesn't exist.
"""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@ -181,11 +200,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"created_ts",
"quarantined_by",
"url_cache",
"last_access_ts",
"safe_from_quarantine",
),
allow_none=True,
desc="get_local_media",
)
if row is None:
return None
return LocalMedia(
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
quarantined_by=row[4],
url_cache=row[5],
last_access_ts=row[6],
safe_from_quarantine=row[7],
)
async def get_local_media_by_user_paginate(
self,
@ -236,6 +269,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length,
upload_name,
created_ts,
url_cache,
last_access_ts,
quarantined_by,
safe_from_quarantine
@ -257,9 +291,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length=row[2],
upload_name=row[3],
created_ts=row[4],
last_access_ts=row[5],
quarantined_by=row[6],
safe_from_quarantine=bool(row[7]),
url_cache=row[5],
last_access_ts=row[6],
quarantined_by=row[7],
safe_from_quarantine=bool(row[8]),
)
for row in txn
]
@ -390,51 +425,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe",
)
async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
"""
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
# get the most recently cached result (relative to the given ts)
sql = (
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts <= ?"
" ORDER BY download_ts DESC LIMIT 1"
)
sql = """
SELECT response_code, expires_ts, og
FROM local_media_repository_url_cache
WHERE url = ? AND download_ts <= ?
ORDER BY download_ts DESC LIMIT 1
"""
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
# ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any)
sql = (
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts > ?"
" ORDER BY download_ts ASC LIMIT 1"
)
sql = """
SELECT response_code, expires_ts, og
FROM local_media_repository_url_cache
WHERE url = ? AND download_ts > ?
ORDER BY download_ts ASC LIMIT 1
"""
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
return None
return dict(
zip(
(
"response_code",
"etag",
"expires_ts",
"og",
"media_id",
"download_ts",
),
row,
)
)
return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
@ -444,7 +467,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
response_code: int,
etag: Optional[str],
expires_ts: int,
og: Optional[str],
og: str,
media_id: str,
download_ts: int,
) -> None:
@ -510,8 +533,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_cached_remote_media(
self, origin: str, media_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
) -> Optional[RemoteMedia]:
row = await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@ -520,11 +543,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"upload_name",
"created_ts",
"filesystem_id",
"last_access_ts",
"quarantined_by",
),
allow_none=True,
desc="get_cached_remote_media",
)
if row is None:
return row
return RemoteMedia(
media_origin=origin,
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
filesystem_id=row[4],
last_access_ts=row[5],
quarantined_by=row[6],
)
async def store_cached_remote_media(
self,
@ -623,10 +660,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
t_width: int,
t_height: int,
t_type: str,
) -> Optional[Dict[str, Any]]:
) -> Optional[ThumbnailInfo]:
"""Fetch the thumbnail info of given width, height and type."""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
table="remote_media_cache_thumbnails",
keyvalues={
"media_origin": origin,
@ -641,11 +678,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
"filesystem_id",
),
allow_none=True,
desc="get_remote_media_thumbnail",
)
if row is None:
return None
return ThumbnailInfo(
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
@trace
async def store_remote_media_thumbnail(

View File

@ -13,7 +13,6 @@
# limitations under the License.
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore):
return 50
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
keyvalues={"full_user_id": user_id.to_string()},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
)
except StoreError as e:
if e.code == 404:
# no match
return ProfileInfo(None, None)
else:
raise
return ProfileInfo(
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
profile = await self.db_pool.simple_select_one(
table="profiles",
keyvalues={"full_user_id": user_id.to_string()},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
allow_none=True,
)
if profile is None:
# no match
return ProfileInfo(None, None)
return ProfileInfo(avatar_url=profile[1], display_name=profile[0])
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(

View File

@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore):
"before/after rule not found: %s" % (relative_to_rule,)
)
base_priority_class = res["priority_class"]
base_rule_priority = res["priority"]
base_priority_class, base_rule_priority = res
if base_priority_class != priority_class:
raise InconsistentRuleException(

View File

@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
stream_ordering = int(res["stream_ordering"]) if res else None
rx_ts = res["received_ts"] if res else 0
stream_ordering = int(res[0]) if res else None
rx_ts = res[1] if res else 0
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts

View File

@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
account timestamp as milliseconds since the epoch. None if the account
has not been renewed using the current token yet.
"""
ret_dict = await self.db_pool.simple_select_one(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
desc="get_user_from_renewal_token",
)
return (
ret_dict["user_id"],
ret_dict["expiration_ts_ms"],
ret_dict["token_used_ts_ms"],
return cast(
Tuple[str, int, Optional[int]],
await self.db_pool.simple_select_one(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
desc="get_user_from_renewal_token",
),
)
async def get_renewal_token_for_user(self, user_id: str) -> str:
@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
user id, or None if no user id/threepid mapping exists
"""
ret = self.db_pool.simple_select_one_txn(
return self.db_pool.simple_select_one_onecol_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
["user_id"],
"user_id",
True,
)
if ret:
return ret["user_id"]
return None
async def user_add_threepid(
self,
@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
if res is None:
return False
uses_allowed, pending, completed, expiry_time = res
# Check if the token has expired
now = self._clock.time_msec()
if res["expiry_time"] and res["expiry_time"] < now:
if expiry_time and expiry_time < now:
return False
# Check if the token has been used up
if (
res["uses_allowed"]
and res["pending"] + res["completed"] >= res["uses_allowed"]
):
if uses_allowed and pending + completed >= uses_allowed:
return False
# Otherwise, the token is valid
@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
res = cast(
Dict[str, Any],
pending, completed = cast(
Tuple[int, int],
self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"registration_tokens",
keyvalues={"token": token},
updatevalues={
"completed": res["completed"] + 1,
"pending": res["pending"] - 1,
"completed": completed + 1,
"pending": pending - 1,
},
)
@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
A dict, or None if token doesn't exist.
"""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True,
desc="get_one_registration_token",
)
if row is None:
return None
return {
"token": row[0],
"uses_allowed": row[1],
"pending": row[2],
"completed": row[3],
"expiry_time": row[4],
}
async def generate_registration_token(
self, length: int, chars: str
@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return None
# Get all info about the token so it can be sent in the response
return self.db_pool.simple_select_one_txn(
result = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
allow_none=True,
)
if result is None:
return result
return {
"token": result[0],
"uses_allowed": result[1],
"pending": result[2],
"completed": result[3],
"expiry_time": result[4],
}
return await self.db_pool.runInteraction(
"update_registration_token", _update_registration_token_txn
)
@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
keyvalues={"token": token},
updatevalues={"used_ts": ts},
)
user_id = values["user_id"]
expiry_ts = values["expiry_ts"]
used_ts = values["used_ts"]
auth_provider_id = values["auth_provider_id"]
auth_provider_session_id = values["auth_provider_session_id"]
(
user_id,
expiry_ts,
used_ts,
auth_provider_id,
auth_provider_session_id,
) = values
# Token was already used
if used_ts is not None:
@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# reason, the next check is on the client secret, which is NOT NULL,
# so we don't have to worry about the client secret matching by
# accident.
row = {"client_secret": None, "validated_at": None}
row = None, None
else:
raise ThreepidValidationError("Unknown session_id")
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
retrieved_client_secret, validated_at = row
row = self.db_pool.simple_select_one_txn(
txn,
@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
raise ThreepidValidationError(
"Validation token not found or has expired"
)
expires = row["expires"]
next_link = row["next_link"]
expires, next_link = row
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(

View File

@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
"""Retrieve a room.
Args:
room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
A tuple containing the room information:
* True if the room is public
* True if the room has an auth chain index
or None if the room is unknown.
"""
return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
desc="get_room",
allow_none=True,
row = cast(
Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("is_public", "has_auth_chain_index"),
desc="get_room",
allow_none=True,
),
)
if row is None:
return row
return bool(row[0]), bool(row[1])
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
"""Retrieve room with statistics.
@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
if row:
return RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
)
return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
else:
return None
@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
join.
"""
result = await self.db_pool.simple_select_one(
table="partial_state_rooms",
keyvalues={"room_id": room_id},
retcols=("join_event_id", "device_lists_stream_id"),
desc="get_join_event_id_for_partial_state",
return cast(
Tuple[str, int],
await self.db_pool.simple_select_one(
table="partial_state_rooms",
keyvalues={"room_id": room_id},
retcols=("join_event_id", "device_lists_stream_id"),
desc="get_join_event_id_for_partial_state",
),
)
return result["join_event_id"], result["device_lists_stream_id"]
def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(

View File

@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
"non-local user %s" % (user_id,),
)
results_dict = await self.db_pool.simple_select_one(
"local_current_membership",
{"room_id": room_id, "user_id": user_id},
("membership", "event_id"),
allow_none=True,
desc="get_local_current_membership_for_user_in_room",
results = cast(
Optional[Tuple[str, str]],
await self.db_pool.simple_select_one(
"local_current_membership",
{"room_id": room_id, "user_id": user_id},
("membership", "event_id"),
allow_none=True,
desc="get_local_current_membership_for_user_in_room",
),
)
if not results_dict:
if not results:
return None, None
return results_dict.get("membership"), results_dict.get("event_id")
return results
@cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(

View File

@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_position_for_event",
)
return PersistedEventPosition(
row["instance_name"] or "master", row["stream_ordering"]
)
return PersistedEventPosition(row[1] or "master", row[0])
async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
)
return RoomStreamToken(
topological=row["topological_ordering"], stream=row["stream_ordering"]
)
return RoomStreamToken(topological=row[1], stream=row[0])
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream
@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
results = self.db_pool.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
retcols=["stream_ordering", "topological_ordering"],
stream_ordering, topological_ordering = cast(
Tuple[int, int],
self.db_pool.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
retcols=["stream_ordering", "topological_ordering"],
),
)
# This cannot happen as `allow_none=False`.
assert results is not None
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
topological=results["topological_ordering"] - 1,
stream=results["stream_ordering"],
topological=topological_ordering - 1, stream=stream_ordering
)
after_token = RoomStreamToken(
topological=results["topological_ordering"],
stream=results["stream_ordering"],
topological=topological_ordering, stream=stream_ordering
)
rows, start_token = self._paginate_room_events_txn(

View File

@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: the task if available, `None` otherwise
"""
row = await self.db_pool.simple_select_one(
table="scheduled_tasks",
keyvalues={"id": id},
retcols=(
"id",
"action",
"status",
"timestamp",
"resource_id",
"params",
"result",
"error",
row = cast(
Optional[ScheduledTaskRow],
await self.db_pool.simple_select_one(
table="scheduled_tasks",
keyvalues={"id": id},
retcols=(
"id",
"action",
"status",
"timestamp",
"resource_id",
"params",
"result",
"error",
),
allow_none=True,
desc="get_scheduled_task",
),
allow_none=True,
desc="get_scheduled_task",
)
return (
TaskSchedulerWorkerStore._convert_row_to_task(
(
row["id"],
row["action"],
row["status"],
row["timestamp"],
row["resource_id"],
row["params"],
row["result"],
row["error"],
)
)
if row
else None
)
return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id.

View File

@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn,
table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin},
retcols=(
"transaction_id",
"origin",
"ts",
"response_code",
"response_json",
"has_been_referenced",
),
retcols=("response_code", "response_json"),
allow_none=True,
)
if result and result["response_code"]:
return result["response_code"], db_to_json(result["response_json"])
# If the result exists and the response code is non-0.
if result and result[0]:
return result[0], db_to_json(result[1])
else:
return None
@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
# check we have a row and retry_last_ts is not null or zero
# (retry_last_ts can't be negative)
if result and result["retry_last_ts"]:
return DestinationRetryTimings(**result)
if result and result[1]:
return DestinationRetryTimings(
failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2]
)
else:
return None

View File

@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session",
)
result["clientdict"] = db_to_json(result["clientdict"])
return UIAuthSessionData(session_id, **result)
return UIAuthSessionData(
session_id,
clientdict=db_to_json(result[0]),
uri=result[1],
method=result[2],
description=result[3],
)
async def mark_ui_auth_stage_complete(
self,
@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore):
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
) -> None:
# Get the current value.
result = cast(
Dict[str, Any],
self.db_pool.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
),
result = self.db_pool.simple_select_one_onecol_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcol="serverdict",
)
# Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"])
serverdict = db_to_json(result)
serverdict[key] = value
self.db_pool.simple_update_one_txn(
@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises:
StoreError if the session cannot be found.
"""
result = await self.db_pool.simple_select_one(
result = await self.db_pool.simple_select_one_onecol(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
retcol="serverdict",
desc="get_ui_auth_session_data",
)
serverdict = db_to_json(result["serverdict"])
serverdict = db_to_json(result)
return serverdict.get(key, default)

View File

@ -20,7 +20,6 @@ from typing import (
Collection,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
async def _get_user_in_directory(
self, user_id: str
) -> Optional[Tuple[Optional[str], Optional[str]]]:
"""
Fetch the user information in the user directory.
Returns:
None if the user is unknown, otherwise a tuple of display name and
avatar URL (both of which may be None).
"""
return cast(
Optional[Tuple[Optional[str], Optional[str]]],
await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
),
)
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:

View File

@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
next_id = self._load_next_id_txn(txn)
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
txn.call_after(self._mark_ids_as_finished, [next_id])
txn.call_on_exception(self._mark_ids_as_finished, [next_id])
txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream
@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int) -> None:
"""The ID has finished being processed so we should advance the
def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
"""
Usage:
stream_id = stream_id_gen.get_next_txn(txn)
# ... persist event ...
"""
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
raise Exception("Tried to allocate stream ID on non-writer")
next_ids = self._load_next_mult_id_txn(txn, n)
txn.call_after(self._mark_ids_as_finished, next_ids)
txn.call_on_exception(self._mark_ids_as_finished, next_ids)
txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't
# bother, as nothing will read it).
#
# We only do this on the success path so that the persisted current
# position points to a persisted row with the correct instance name.
if self._writers:
txn.call_after(
run_as_background_process,
"MultiWriterIdGenerator._update_table",
self._db.runInteraction,
"MultiWriterIdGenerator._update_table",
self._update_stream_positions_table_txn,
)
return [self._return_factor * next_id for next_id in next_ids]
def _mark_ids_as_finished(self, next_ids: List[int]) -> None:
"""These IDs have finished being processed so we should advance the
current position if possible.
"""
with self._lock:
self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id)
self._unfinished_ids.difference_update(next_ids)
self._finished_ids.update(next_ids)
new_cur: Optional[int] = None
@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
curr, new_cur, self._max_position_of_local_instance
)
self._add_persisted_position(next_id)
# TODO Can we call this for just the last position or somehow batch
# _add_persisted_position.
for next_id in next_ids:
self._add_persisted_position(next_id)
def get_current_token(self) -> int:
return self.get_persisted_upto_position()
@ -933,8 +972,7 @@ class _MultiWriterCtxManager:
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)
self.id_gen._mark_ids_as_finished(self.stream_ids)
self.notifier.notify_replication()

View File

@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
return self.get_success(
row = self.get_success(
self.store.db_pool.simple_select_one(
table + "_current",
{id_col: stat_id},
@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
return None if row is None else dict(zip(cols, row))
def _perform_background_initial_update(self) -> None:
# Do the initial population of the stats via the background update
self._add_background_updates()

View File

@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
profile = self.get_success(self.store._get_user_in_directory(regular_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name)
self.assertTrue(profile[0] == display_name)
def test_handle_local_profile_change_with_deactivated_user(self) -> None:
# create user
@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is in directory
profile = self.get_success(self.store._get_user_in_directory(r_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name)
self.assertEqual(profile[0], display_name)
# deactivate user
self.get_success(self.store.set_user_deactivated_status(r_user_id, True))

View File

@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
origin, media_id = self.media_id.split("/")
info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
assert info is not None
file_id = info["filesystem_id"]
file_id = info.filesystem_id
thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
origin, file_id

View File

@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])
self.assertFalse(media_info.quarantined_by)
# quarantining
channel = self.make_request(
@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["quarantined_by"])
self.assertTrue(media_info.quarantined_by)
# remove from quarantine
channel = self.make_request(
@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])
self.assertFalse(media_info.quarantined_by)
def test_quarantine_protected_media(self) -> None:
"""
@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
# verify protection
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"])
self.assertTrue(media_info.safe_from_quarantine)
# quarantining
channel = self.make_request(
@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
# verify that is not in quarantine
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])
self.assertFalse(media_info.quarantined_by)
class ProtectMediaByIDTestCase(_AdminMediaTests):
@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"])
self.assertFalse(media_info.safe_from_quarantine)
# protect
channel = self.make_request(
@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"])
self.assertTrue(media_info.safe_from_quarantine)
# unprotect
channel = self.make_request(
@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"])
self.assertFalse(media_info.safe_from_quarantine)
class PurgeMediaCacheTestCase(_AdminMediaTests):

View File

@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# is in user directory
profile = self.get_success(self.store._get_user_in_directory(self.other_user))
assert profile is not None
self.assertTrue(profile["display_name"] == "User")
self.assertEqual(profile[0], "User")
# Deactivate user
channel = self.make_request(

View File

@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
#
# Note that we don't have the UI Auth session ID, so just pull out the single
# row.
ui_auth_data = self.get_success(
self.store.db_pool.simple_select_one(
"ui_auth_sessions", keyvalues={}, retcols=("clientdict",)
result = self.get_success(
self.store.db_pool.simple_select_one_onecol(
"ui_auth_sessions", keyvalues={}, retcol="clientdict"
)
)
client_dict = db_to_json(ui_auth_data["clientdict"])
client_dict = db_to_json(result)
self.assertNotIn("new_password", client_dict)
@override_config({"rc_3pid_validation": {"burst_count": 3}})

View File

@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertLessEqual(det_data.items(), channel.json_body.items())
# Check the `completed` counter has been incremented and pending is 0
res = self.get_success(
pending, completed = self.get_success(
store.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
)
)
self.assertEqual(res["completed"], 1)
self.assertEqual(res["pending"], 0)
self.assertEqual(completed, 1)
self.assertEqual(pending, 0)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_invalid(self) -> None:
@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
params1["auth"]["type"] = LoginType.DUMMY
self.make_request(b"POST", self.url, params1)
# Check pending=0 and completed=1
res = self.get_success(
pending, completed = self.get_success(
store.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
)
)
self.assertEqual(res["pending"], 0)
self.assertEqual(res["completed"], 1)
self.assertEqual(pending, 0)
self.assertEqual(completed, 1)
# Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, params2)

View File

@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
"""Given an MXC URI, assert whether it has been purged or not."""
if mxc_uri.server_name == self.hs.config.server.server_name:
found_media_dict = self.get_success(
self.store.get_local_media(mxc_uri.media_id)
found_media = bool(
self.get_success(self.store.get_local_media(mxc_uri.media_id))
)
else:
found_media_dict = self.get_success(
self.store.get_cached_remote_media(
mxc_uri.server_name, mxc_uri.media_id
found_media = bool(
self.get_success(
self.store.get_cached_remote_media(
mxc_uri.server_name, mxc_uri.media_id
)
)
)
if expect_purged:
self.assertIsNone(
found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
)
self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged")
else:
self.assertIsNotNone(
found_media_dict,
self.assertTrue(
found_media,
msg=f"{mxc_uri} unexpectedly purged",
)

View File

@ -0,0 +1,117 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock, call
from synapse.storage.database import LoggingTransaction
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.unittest import HomeserverTestCase
class CacheInvalidationTestCase(HomeserverTestCase):
def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main
def test_bulk_invalidation(self) -> None:
master_invalidate = Mock()
self.store._get_cached_user_device.invalidate = master_invalidate
keys_to_invalidate = [
("a", "b"),
("c", "d"),
("e", "f"),
("g", "h"),
]
def test_txn(txn: LoggingTransaction) -> None:
self.store._invalidate_cache_and_stream_bulk(
txn,
# This is an arbitrarily chosen cached store function. It was chosen
# because it takes more than one argument. We'll use this later to
# check that the invalidation was actioned over replication.
cache_func=self.store._get_cached_user_device,
key_tuples=keys_to_invalidate,
)
self.get_success(
self.store.db_pool.runInteraction(
"test_invalidate_cache_and_stream_bulk", test_txn
)
)
master_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)
class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase):
def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main
def test_bulk_invalidation_replicates(self) -> None:
"""Like test_bulk_invalidation, but also checks the invalidations replicate."""
master_invalidate = Mock()
worker_invalidate = Mock()
self.store._get_cached_user_device.invalidate = master_invalidate
worker = self.make_worker_hs("synapse.app.generic_worker")
worker_ds = worker.get_datastores().main
worker_ds._get_cached_user_device.invalidate = worker_invalidate
keys_to_invalidate = [
("a", "b"),
("c", "d"),
("e", "f"),
("g", "h"),
]
def test_txn(txn: LoggingTransaction) -> None:
self.store._invalidate_cache_and_stream_bulk(
txn,
# This is an arbitrarily chosen cached store function. It was chosen
# because it takes more than one argument. We'll use this later to
# check that the invalidation was actioned over replication.
cache_func=self.store._get_cached_user_device,
key_tuples=keys_to_invalidate,
)
assert self.store._cache_id_gen is not None
initial_token = self.store._cache_id_gen.get_current_token()
self.get_success(
self.database_pool.runInteraction(
"test_invalidate_cache_and_stream_bulk", test_txn
)
)
second_token = self.store._cache_id_gen.get_current_token()
self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate))
self.get_success(
worker.get_replication_data_handler().wait_for_stream_position(
"master", "caches", second_token
)
)
master_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)
worker_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)

View File

@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
)
self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret)
self.assertEqual((1, 2, 3), ret)
self.mock_txn.execute.assert_called_once_with(
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
)
@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
)
self.assertFalse(ret)
self.assertIsNone(ret)
@defer.inlineCallbacks
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:

View File

@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase):
)
def test_get_room(self) -> None:
res = self.get_success(self.store.get_room(self.room.to_string()))
assert res is not None
self.assertLessEqual(
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True,
}.items(),
res.items(),
)
room = self.get_success(self.store.get_room(self.room.to_string()))
assert room is not None
self.assertTrue(room[0])
def test_get_room_unknown_room(self) -> None:
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))