Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

matrix-org-hotfixes
Patrick Cloke 2023-11-09 11:14:57 -05:00
commit 8c2d3d0b4c
48 changed files with 677 additions and 414 deletions

Cargo.lock (generated)

@@ -352,9 +352,9 @@ dependencies = [
 [[package]]
 name = "serde_json"
-version = "1.0.107"
+version = "1.0.108"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65"
+checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b"
 dependencies = [
  "itoa",
  "ryu",

changelog.d/16611.misc (new file)

@@ -0,0 +1 @@
+Improve type hints.

changelog.d/16612.misc (new file)

@@ -0,0 +1 @@
+Improve type hints.

----

@@ -0,0 +1 @@
+Improve the performance of claiming encryption keys in multi-worker deployments.

poetry.lock (generated)

@@ -2012,12 +2012,12 @@ plugins = ["importlib-metadata"]
 [[package]]
 name = "pyicu"
-version = "2.11"
+version = "2.12"
 description = "Python extension wrapping the ICU C++ API"
 optional = true
 python-versions = "*"
 files = [
-    {file = "PyICU-2.11.tar.gz", hash = "sha256:3ab531264cfe9132b3d2ac5d708da9a4649d25f6e6813730ac88cf040a08a844"},
+    {file = "PyICU-2.12.tar.gz", hash = "sha256:bd7ab5efa93ad692e6daa29cd249364e521218329221726a113ca3cb281c8611"},
 ]

 [[package]]

----

@@ -348,8 +348,7 @@ class Porter:
             backward_chunk = 0
             already_ported = 0
         else:
-            forward_chunk = row["forward_rowid"]
-            backward_chunk = row["backward_rowid"]
+            forward_chunk, backward_chunk = row

         if total_to_port is None:
             already_ported, total_to_port = await self._get_total_count_to_port(

----

@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union

 from synapse.api.errors import (
     AuthError,
@@ -23,6 +23,7 @@ from synapse.api.errors import (
     StoreError,
     SynapseError,
 )
+from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
 from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util.caches.descriptors import cached
 from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -306,7 +307,9 @@ class ProfileHandler:
             server_name = host

         if self._is_mine_server_name(server_name):
-            media_info = await self.store.get_local_media(media_id)
+            media_info: Optional[
+                Union[LocalMedia, RemoteMedia]
+            ] = await self.store.get_local_media(media_id)
         else:
             media_info = await self.store.get_cached_remote_media(server_name, media_id)
@@ -322,12 +325,12 @@ class ProfileHandler:
         if self.max_avatar_size:
             # Ensure avatar does not exceed max allowed avatar size
-            if media_info["media_length"] > self.max_avatar_size:
+            if media_info.media_length > self.max_avatar_size:
                 logger.warning(
                     "Forbidding avatar change to %s: %d bytes is above the allowed size "
                     "limit",
                     mxc,
-                    media_info["media_length"],
+                    media_info.media_length,
                 )
                 return False
@@ -335,12 +338,12 @@ class ProfileHandler:
         # Ensure the avatar's file type is allowed
         if (
             self.allowed_avatar_mimetypes
-            and media_info["media_type"] not in self.allowed_avatar_mimetypes
+            and media_info.media_type not in self.allowed_avatar_mimetypes
         ):
             logger.warning(
                 "Forbidding avatar change to %s: mimetype %s not allowed",
                 mxc,
-                media_info["media_type"],
+                media_info.media_type,
             )
             return False

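An aside on the `media_info` annotation above: `get_local_media` and `get_cached_remote_media` now return different attrs types, and both branches assign to the same variable, so the union has to be declared on the first assignment. A minimal self-contained sketch of why (class and function names here are stand-ins, not the Synapse API):

    from typing import Optional, Union

    class LocalMedia:
        pass

    class RemoteMedia:
        pass

    def get_local() -> Optional[LocalMedia]:
        return LocalMedia()

    def get_remote() -> Optional[RemoteMedia]:
        return RemoteMedia()

    def fetch(is_mine: bool) -> Optional[Union[LocalMedia, RemoteMedia]]:
        if is_mine:
            # Declaring the union here widens the variable; without it mypy
            # infers Optional[LocalMedia] and rejects the assignment below.
            media: Optional[Union[LocalMedia, RemoteMedia]] = get_local()
        else:
            media = get_remote()
        return media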
----

@@ -269,7 +269,7 @@ class RoomCreationHandler:
         self,
         requester: Requester,
         old_room_id: str,
-        old_room: Dict[str, Any],
+        old_room: Tuple[bool, str, bool],
         new_room_id: str,
         new_version: RoomVersion,
         tombstone_event: EventBase,
@@ -279,7 +279,7 @@ class RoomCreationHandler:
         Args:
             requester: the user requesting the upgrade
             old_room_id: the id of the room to be replaced
-            old_room: a dict containing room information for the room to be replaced,
+            old_room: a tuple containing room information for the room to be replaced,
                 as returned by `RoomWorkerStore.get_room`.
             new_room_id: the id of the replacement room
             new_version: the version to upgrade the room to
@@ -299,7 +299,7 @@ class RoomCreationHandler:
         await self.store.store_room(
             room_id=new_room_id,
             room_creator_user_id=user_id,
-            is_public=old_room["is_public"],
+            is_public=old_room[0],
             room_version=new_version,
         )

----

@@ -1274,7 +1274,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # Add new room to the room directory if the old room was there
         # Remove old room from the room directory
         old_room = await self.store.get_room(old_room_id)
-        if old_room is not None and old_room["is_public"]:
+        # If the old room exists and is public.
+        if old_room is not None and old_room[0]:
             await self.store.set_room_is_public(old_room_id, False)
             await self.store.set_room_is_public(room_id, True)

----

@@ -806,7 +806,7 @@ class SsoHandler:
             media_id = profile["avatar_url"].split("/")[-1]
             if self._is_mine_server_name(server_name):
                 media = await self._media_repo.store.get_local_media(media_id)
-                if media is not None and upload_name == media["upload_name"]:
+                if media is not None and upload_name == media.upload_name:
                     logger.info("skipping saving the user avatar")
                     return True

----

@@ -19,6 +19,7 @@ import shutil
 from io import BytesIO
 from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple

+import attr
 from matrix_common.types.mxc_uri import MXCUri

 import twisted.internet.error
@@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
 from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
 from synapse.media.url_previewer import UrlPreviewer
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.media_repository import RemoteMedia
 from synapse.types import UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
@@ -245,18 +247,18 @@ class MediaRepository:
             Resolves once a response has successfully been written to request
         """
         media_info = await self.store.get_local_media(media_id)
-        if not media_info or media_info["quarantined_by"]:
+        if not media_info or media_info.quarantined_by:
             respond_404(request)
             return

         self.mark_recently_accessed(None, media_id)

-        media_type = media_info["media_type"]
+        media_type = media_info.media_type
         if not media_type:
             media_type = "application/octet-stream"
-        media_length = media_info["media_length"]
-        upload_name = name if name else media_info["upload_name"]
-        url_cache = media_info["url_cache"]
+        media_length = media_info.media_length
+        upload_name = name if name else media_info.upload_name
+        url_cache = media_info.url_cache

         file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
@@ -310,16 +312,20 @@ class MediaRepository:
         # We deliberately stream the file outside the lock
         if responder:
-            media_type = media_info["media_type"]
-            media_length = media_info["media_length"]
-            upload_name = name if name else media_info["upload_name"]
+            upload_name = name if name else media_info.upload_name
             await respond_with_responder(
-                request, responder, media_type, media_length, upload_name
+                request,
+                responder,
+                media_info.media_type,
+                media_info.media_length,
+                upload_name,
             )
         else:
             respond_404(request)

-    async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
+    async def get_remote_media_info(
+        self, server_name: str, media_id: str
+    ) -> RemoteMedia:
         """Gets the media info associated with the remote file, downloading
         if necessary.
@@ -353,7 +359,7 @@ class MediaRepository:
     async def _get_remote_media_impl(
         self, server_name: str, media_id: str
-    ) -> Tuple[Optional[Responder], dict]:
+    ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
@@ -373,15 +379,17 @@ class MediaRepository:
         # If we have an entry in the DB, try and look for it
         if media_info:
-            file_id = media_info["filesystem_id"]
+            file_id = media_info.filesystem_id
             file_info = FileInfo(server_name, file_id)

-            if media_info["quarantined_by"]:
+            if media_info.quarantined_by:
                 logger.info("Media is quarantined")
                 raise NotFoundError()

-            if not media_info["media_type"]:
-                media_info["media_type"] = "application/octet-stream"
+            if not media_info.media_type:
+                media_info = attr.evolve(
+                    media_info, media_type="application/octet-stream"
+                )

             responder = await self.media_storage.fetch_media(file_info)
             if responder:
@@ -403,9 +411,9 @@ class MediaRepository:
             if not media_info:
                 raise e

-            file_id = media_info["filesystem_id"]
-            if not media_info["media_type"]:
-                media_info["media_type"] = "application/octet-stream"
+            file_id = media_info.filesystem_id
+            if not media_info.media_type:
+                media_info = attr.evolve(media_info, media_type="application/octet-stream")
             file_info = FileInfo(server_name, file_id)

             # We generate thumbnails even if another process downloaded the media
@@ -415,7 +423,7 @@ class MediaRepository:
             # otherwise they'll request thumbnails and get a 404 if they're not
             # ready yet.
             await self._generate_thumbnails(
-                server_name, media_id, file_id, media_info["media_type"]
+                server_name, media_id, file_id, media_info.media_type
             )

             responder = await self.media_storage.fetch_media(file_info)
@@ -425,7 +433,7 @@ class MediaRepository:
         self,
         server_name: str,
         media_id: str,
-    ) -> dict:
+    ) -> RemoteMedia:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
@@ -518,7 +526,7 @@ class MediaRepository:
             origin=server_name,
             media_id=media_id,
             media_type=media_type,
-            time_now_ms=self.clock.time_msec(),
+            time_now_ms=time_now_ms,
             upload_name=upload_name,
             media_length=length,
             filesystem_id=file_id,
@@ -526,15 +534,17 @@ class MediaRepository:
         logger.info("Stored remote media in file %r", fname)

-        media_info = {
-            "media_type": media_type,
-            "media_length": length,
-            "upload_name": upload_name,
-            "created_ts": time_now_ms,
-            "filesystem_id": file_id,
-        }
-
-        return media_info
+        return RemoteMedia(
+            media_origin=server_name,
+            media_id=media_id,
+            media_type=media_type,
+            media_length=length,
+            upload_name=upload_name,
+            created_ts=time_now_ms,
+            filesystem_id=file_id,
+            last_access_ts=time_now_ms,
+            quarantined_by=None,
+        )

     def _get_thumbnail_requirements(
         self, media_type: str

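One detail worth calling out in the hunks above: the new media-info containers are frozen attrs classes, so the old in-place fix-up `media_info["media_type"] = ...` becomes `attr.evolve`, which builds a copy with the field replaced. A standalone sketch of that pattern (the `Media` class here is illustrative, not the real one):

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class Media:
        media_id: str
        media_type: str

    m = Media(media_id="abc", media_type="")
    if not m.media_type:
        # Assigning to a frozen instance raises FrozenInstanceError, so
        # rebind the name to a copy with just this field changed.
        m = attr.evolve(m, media_type="application/octet-stream")
    assert m.media_type == "application/octet-stream"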
----

@@ -240,15 +240,14 @@ class UrlPreviewer:
         cache_result = await self.store.get_url_cache(url, ts)
         if (
             cache_result
-            and cache_result["expires_ts"] > ts
-            and cache_result["response_code"] / 100 == 2
+            and cache_result.expires_ts > ts
+            and cache_result.response_code // 100 == 2
         ):
             # It may be stored as text in the database, not as bytes (such as
             # PostgreSQL). If so, encode it back before handing it on.
-            og = cache_result["og"]
-            if isinstance(og, str):
-                og = og.encode("utf8")
-            return og
+            if isinstance(cache_result.og, str):
+                return cache_result.og.encode("utf8")
+            return cache_result.og

         # If this URL can be accessed via an allowed oEmbed, use that instead.
         url_to_download = url

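Note that the `/` to `//` change in the cache check is a behaviour fix, not just tidying: under Python 3's true division only an exact 200 satisfied the old comparison, whereas floor division matches the whole 2xx class. A quick demonstration:

    # Python 3 semantics
    assert 200 / 100 == 2        # 2.0 == 2, so a plain 200 still matched
    assert not (206 / 100 == 2)  # 2.06 != 2: other 2xx codes fell through
    assert 206 // 100 == 2       # floor division matches any 2xx status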
----

@@ -1860,7 +1860,8 @@ class PublicRoomListManager:
         if not room:
             return False

-        return room.get("is_public", False)
+        # The first item is whether the room is public.
+        return room[0]

     async def add_room_to_public_room_list(self, room_id: str) -> None:
         """Publishes a room to the public room list.

----

@@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

-        ret = await self.store.get_room(room_id)
-        if not ret:
+        room = await self.store.get_room(room_id)
+        if not room:
             raise NotFoundError("Room not found")

         members = await self.store.get_users_in_room(room_id)
@@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

-        ret = await self.store.get_room(room_id)
-        if not ret:
+        room = await self.store.get_room(room_id)
+        if not room:
             raise NotFoundError("Room not found")

         event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)

----

@@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet):
         if room is None:
             raise NotFoundError("Unknown room")

-        return 200, {"visibility": "public" if room["is_public"] else "private"}
+        return 200, {"visibility": "public" if room[0] else "private"}


 class PutBody(RequestBodyModel):
     visibility: Literal["public", "private"] = "public"

----

@@ -119,7 +119,7 @@ class ThumbnailResource(RestServlet):
             if not media_info:
                 respond_404(request)
                 return
-            if media_info["quarantined_by"]:
+            if media_info.quarantined_by:
                 logger.info("Media is quarantined")
                 respond_404(request)
                 return
@@ -134,7 +134,7 @@ class ThumbnailResource(RestServlet):
                 thumbnail_infos,
                 media_id,
                 media_id,
-                url_cache=bool(media_info["url_cache"]),
+                url_cache=bool(media_info.url_cache),
                 server_name=None,
             )
@@ -152,7 +152,7 @@ class ThumbnailResource(RestServlet):
         if not media_info:
             respond_404(request)
             return
-        if media_info["quarantined_by"]:
+        if media_info.quarantined_by:
             logger.info("Media is quarantined")
             respond_404(request)
             return
@@ -168,7 +168,7 @@ class ThumbnailResource(RestServlet):
             file_info = FileInfo(
                 server_name=None,
                 file_id=media_id,
-                url_cache=media_info["url_cache"],
+                url_cache=bool(media_info.url_cache),
                 thumbnail=info,
             )
@@ -188,7 +188,7 @@ class ThumbnailResource(RestServlet):
             desired_height,
             desired_method,
             desired_type,
-            url_cache=bool(media_info["url_cache"]),
+            url_cache=bool(media_info.url_cache),
         )

         if file_path:
@@ -213,7 +213,7 @@ class ThumbnailResource(RestServlet):
             server_name, media_id
         )

-        file_id = media_info["filesystem_id"]
+        file_id = media_info.filesystem_id

         for info in thumbnail_infos:
             t_w = info.width == desired_width
@@ -224,7 +224,7 @@ class ThumbnailResource(RestServlet):
             if t_w and t_h and t_method and t_type:
                 file_info = FileInfo(
                     server_name=server_name,
-                    file_id=media_info["filesystem_id"],
+                    file_id=file_id,
                     thumbnail=info,
                 )
@@ -280,7 +280,7 @@ class ThumbnailResource(RestServlet):
             m_type,
             thumbnail_infos,
             media_id,
-            media_info["filesystem_id"],
+            media_info.filesystem_id,
             url_cache=False,
             server_name=server_name,
         )

----

@@ -1116,7 +1116,7 @@ class DatabasePool:
     def simple_insert_many_txn(
         txn: LoggingTransaction,
         table: str,
-        keys: Collection[str],
+        keys: Sequence[str],
         values: Collection[Iterable[Any]],
     ) -> None:
         """Executes an INSERT query on the named table.
@@ -1597,7 +1597,7 @@ class DatabasePool:
         retcols: Collection[str],
         allow_none: Literal[False] = False,
         desc: str = "simple_select_one",
-    ) -> Dict[str, Any]:
+    ) -> Tuple[Any, ...]:
         ...

     @overload
@@ -1608,7 +1608,7 @@ class DatabasePool:
         retcols: Collection[str],
         allow_none: Literal[True] = True,
         desc: str = "simple_select_one",
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[Tuple[Any, ...]]:
         ...

     async def simple_select_one(
@@ -1618,7 +1618,7 @@ class DatabasePool:
         retcols: Collection[str],
         allow_none: bool = False,
         desc: str = "simple_select_one",
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[Tuple[Any, ...]]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning multiple columns from it.
@@ -2127,7 +2127,7 @@ class DatabasePool:
         keyvalues: Dict[str, Any],
         retcols: Collection[str],
         allow_none: bool = False,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[Tuple[Any, ...]]:
         select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)

         if keyvalues:
@@ -2145,7 +2145,7 @@ class DatabasePool:
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))

-        return dict(zip(retcols, row))
+        return row

     async def simple_delete_one(
         self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"

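This hunk is the root of most of the churn in this commit: `simple_select_one` (and its `_txn` variant) now return a plain tuple in `retcols` order instead of a dict. Since the runtime value is untyped, callers pin the shape with `typing.cast` and unpack positionally, as the storage hunks below do. A sketch of the calling convention (the table and columns are examples, not a specific Synapse query):

    from typing import Optional, Tuple, cast

    async def get_display_name(store, user_id: str, device_id: str) -> Optional[str]:
        # The row comes back as a tuple ordered exactly like retcols.
        row = cast(
            Optional[Tuple[str, str, Optional[str]]],
            await store.db_pool.simple_select_one(
                table="devices",
                keyvalues={"user_id": user_id, "device_id": device_id},
                retcols=("user_id", "device_id", "display_name"),
                allow_none=True,
            ),
        )
        if row is None:
            return None
        _user_id, _device_id, display_name = row
        return display_name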
----

@@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         txn.call_after(cache_func.invalidate, keys)
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)

+    def _invalidate_cache_and_stream_bulk(
+        self,
+        txn: LoggingTransaction,
+        cache_func: CachedFunction,
+        key_tuples: Collection[Tuple[Any, ...]],
+    ) -> None:
+        """A bulk version of `_invalidate_cache_and_stream`.
+
+        Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
+        for each key-tuple over replication.
+
+        This implementation is more efficient than a loop which repeatedly calls the
+        non-bulk version.
+        """
+        if not key_tuples:
+            return
+
+        for keys in key_tuples:
+            txn.call_after(cache_func.invalidate, keys)
+
+        self._send_invalidation_to_replication_bulk(
+            txn, cache_func.__name__, key_tuples
+        )
+
     def _invalidate_all_cache_and_stream(
         self, txn: LoggingTransaction, cache_func: CachedFunction
     ) -> None:
@@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if isinstance(self.database_engine, PostgresEngine):
             assert self._cache_id_gen is not None

-            # get_next() returns a context manager which is designed to wrap
-            # the transaction. However, we want to only get an ID when we want
-            # to use it, here, so we need to call __enter__ manually, and have
-            # __exit__ called after the transaction finishes.
             stream_id = self._cache_id_gen.get_next_txn(txn)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
@@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 },
             )

+    def _send_invalidation_to_replication_bulk(
+        self,
+        txn: LoggingTransaction,
+        cache_name: str,
+        key_tuples: Collection[Tuple[Any, ...]],
+    ) -> None:
+        """Announce the invalidation of multiple (but not all) cache entries.
+
+        This is more efficient than repeated calls to the non-bulk version. It should
+        NOT be used to invalidate the entire cache: use
+        `_send_invalidation_to_replication` with keys=None.
+
+        Note that this does *not* invalidate the cache locally.
+
+        Args:
+            txn
+            cache_name
+            key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            assert self._cache_id_gen is not None
+
+            stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
+            ts = self._clock.time_msec()
+            txn.call_after(self.hs.get_notifier().on_new_replication_data)
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="cache_invalidation_stream_by_instance",
+                keys=(
+                    "stream_id",
+                    "instance_name",
+                    "cache_func",
+                    "keys",
+                    "invalidation_ts",
+                ),
+                values=[
+                    # We convert key_tuples to a list here because psycopg2 serialises
+                    # lists as pq arrays, but serialises tuples as "composite types".
+                    # (We need an array because the `keys` column has type `text[]`.)
+                    # See:
+                    #  * https://www.psycopg.org/docs/usage.html#adapt-list
+                    #  * https://www.psycopg.org/docs/usage.html#adapt-tuple
+                    (stream_id, self._instance_name, cache_name, list(key_tuple), ts)
+                    for stream_id, key_tuple in zip(stream_ids, key_tuples)
+                ],
+            )
+
     def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
         if self._cache_id_gen:
             return self._cache_id_gen.get_current_token_for_writer(instance_name)

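The psycopg2 comment in `_send_invalidation_to_replication_bulk` is the subtle part: the same sequence adapts very differently depending on whether it is a list or a tuple. A standalone demonstration (requires psycopg2; output shown is what psycopg2 2.9 produces):

    from psycopg2.extensions import adapt

    # Lists are adapted to a Postgres ARRAY literal...
    print(adapt([1, 2, 3]).getquoted())  # b'ARRAY[1,2,3]'

    # ...while tuples become a parenthesised row value (a "composite type"),
    # which would not fit the text[] `keys` column.
    print(adapt((1, 2, 3)).getquoted())  # b'(1, 2, 3)'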
----

@@ -257,33 +257,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             A dict containing the device information, or `None` if the device does not
             exist.
         """
-        return await self.db_pool.simple_select_one(
-            table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
-            retcols=("user_id", "device_id", "display_name"),
-            desc="get_device",
-            allow_none=True,
-        )
-
-    async def get_device_opt(
-        self, user_id: str, device_id: str
-    ) -> Optional[Dict[str, Any]]:
-        """Retrieve a device. Only returns devices that are not marked as
-        hidden.
-
-        Args:
-            user_id: The ID of the user which owns the device
-            device_id: The ID of the device to retrieve
-        Returns:
-            A dict containing the device information, or None if the device does not exist.
-        """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
             desc="get_device",
             allow_none=True,
         )
+        if row is None:
+            return None
+        return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}

     async def get_devices_by_user(
         self, user_id: str
@@ -1223,9 +1206,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             retcols=["device_id", "device_data"],
             allow_none=True,
         )
-        return (
-            (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
-        )
+        return (row[0], json_decoder.decode(row[1])) if row else None

     def _store_dehydrated_device_txn(
         self,
@@ -2328,13 +2309,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             `FALSE` have not been converted.
         """
-        row = await self.db_pool.simple_select_one(
-            table="device_lists_changes_converted_stream_position",
-            keyvalues={},
-            retcols=["stream_id", "room_id"],
-            desc="get_device_change_last_converted_pos",
+        return cast(
+            Tuple[int, str],
+            await self.db_pool.simple_select_one(
+                table="device_lists_changes_converted_stream_position",
+                keyvalues={},
+                retcols=["stream_id", "room_id"],
+                desc="get_device_change_last_converted_pos",
+            ),
         )
-        return row["stream_id"], row["room_id"]

     async def set_device_change_last_converted_pos(
         self,

----

@@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
                 # it isn't there.
                 raise StoreError(404, "No backup with that version exists")

-            result = self.db_pool.simple_select_one_txn(
-                txn,
-                table="e2e_room_keys_versions",
-                keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
-                retcols=("version", "algorithm", "auth_data", "etag"),
-                allow_none=False,
+            row = cast(
+                Tuple[int, str, str, Optional[int]],
+                self.db_pool.simple_select_one_txn(
+                    txn,
+                    table="e2e_room_keys_versions",
+                    keyvalues={
+                        "user_id": user_id,
+                        "version": this_version,
+                        "deleted": 0,
+                    },
+                    retcols=("version", "algorithm", "auth_data", "etag"),
+                    allow_none=False,
+                ),
             )
-            assert result is not None  # see comment on `simple_select_one_txn`
-            result["auth_data"] = db_to_json(result["auth_data"])
-            result["version"] = str(result["version"])
-            if result["etag"] is None:
-                result["etag"] = 0
-            return result
+            return {
+                "auth_data": db_to_json(row[2]),
+                "version": str(row[0]),
+                "algorithm": row[1],
+                "etag": 0 if row[3] is None else row[3],
+            }

         return await self.db_pool.runInteraction(
             "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn

----

@@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
         for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
             device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
             device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)

-            if (user_id, device_id) in seen_user_device:
-                continue
             seen_user_device.add((user_id, device_id))
-            self._invalidate_cache_and_stream(
-                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
-            )
+
+        self._invalidate_cache_and_stream_bulk(
+            txn, self.get_e2e_unused_fallback_key_types, seen_user_device
+        )

         return results
@@ -1268,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
             if row is None:
                 continue

-            key_id = row["key_id"]
-            key_json = row["key_json"]
-            used = row["used"]
+            key_id, key_json, used = row

             # Mark fallback key as used if not already.
             if not used and mark_as_used:
@@ -1376,14 +1372,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
             List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
         )

-        seen_user_device: Set[Tuple[str, str]] = set()
-        for user_id, device_id, _, _, _ in otk_rows:
-            if (user_id, device_id) in seen_user_device:
-                continue
-            seen_user_device.add((user_id, device_id))
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
+        seen_user_device = {
+            (user_id, device_id) for user_id, device_id, _, _, _ in otk_rows
+        }
+        self._invalidate_cache_and_stream_bulk(
+            txn,
+            self.count_e2e_one_time_keys,
+            seen_user_device,
+        )

         return otk_rows

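The shape of the change above: instead of deduplicating inside the loop and invalidating one pair at a time, a set (or set comprehension) dedupes up front and a single bulk call does the invalidation. In miniature, with made-up rows:

    rows = [("@a:hs", "dev1", "key1"), ("@a:hs", "dev1", "key2"), ("@b:hs", "dev2", "key3")]

    # The comprehension collapses duplicate (user, device) pairs for free...
    seen_user_device = {(user_id, device_id) for user_id, device_id, _ in rows}
    assert seen_user_device == {("@a:hs", "dev1"), ("@b:hs", "dev2")}
    # ...followed by one _invalidate_cache_and_stream_bulk(txn, cache_func, seen_user_device)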
----

@@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
         room = await self.get_room(room_id)  # type: ignore[attr-defined]
-        if room["has_auth_chain_index"]:
+        # If the room has an auth chain index.
+        if room[1]:
             try:
                 return await self.db_pool.runInteraction(
                     "get_auth_chain_ids_chains",
@@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
         room = await self.get_room(room_id)  # type: ignore[attr-defined]
-        if room["has_auth_chain_index"]:
+        # If the room has an auth chain index.
+        if room[1]:
             try:
                 return await self.db_pool.runInteraction(
                     "get_auth_chain_difference_chains",
@@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
             )

             if event_lookup_result is not None:
+                event_type, depth, stream_ordering = event_lookup_result
                 logger.debug(
                     "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
                     room_id,
                     seed_event_id,
-                    event_lookup_result["depth"],
-                    event_lookup_result["stream_ordering"],
-                    event_lookup_result["type"],
+                    depth,
+                    stream_ordering,
+                    event_type,
                 )

-                if event_lookup_result["depth"]:
-                    queue.put(
-                        (
-                            -event_lookup_result["depth"],
-                            -event_lookup_result["stream_ordering"],
-                            seed_event_id,
-                            event_lookup_result["type"],
-                        )
-                    )
+                if depth:
+                    queue.put((-depth, -stream_ordering, seed_event_id, event_type))

         while not queue.empty() and len(event_id_results) < limit:
             try:

----

@@ -1934,8 +1934,7 @@ class PersistEventsStore:
         if row is None:
             return

-        redacted_relates_to = row["relates_to_id"]
-        rel_type = row["relation_type"]
+        redacted_relates_to, rel_type = row

         self.db_pool.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )

----

@@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore):
         if not res:
             raise SynapseError(404, "Could not find event %s" % (event_id,))

-        return int(res["topological_ordering"]), int(res["stream_ordering"])
+        return int(res[0]), int(res[1])

     async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
         """Retrieve the entry with the lowest expiry timestamp in the event_expiry

----

@@ -15,9 +15,7 @@
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
-    Any,
     Collection,
-    Dict,
     Iterable,
     List,
     Optional,
@@ -54,11 +52,32 @@ class LocalMedia:
     media_length: int
     upload_name: str
     created_ts: int
+    url_cache: Optional[str]
     last_access_ts: int
     quarantined_by: Optional[str]
     safe_from_quarantine: bool


+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RemoteMedia:
+    media_origin: str
+    media_id: str
+    media_type: str
+    media_length: int
+    upload_name: Optional[str]
+    filesystem_id: str
+    created_ts: int
+    last_access_ts: int
+    quarantined_by: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UrlCache:
+    response_code: int
+    expires_ts: int
+    og: Union[str, bytes]
+
+
 class MediaSortOrder(Enum):
     """
     Enum to define the sorting method used when returning media with
@@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         super().__init__(database, db_conn, hs)
         self.server_name: str = hs.hostname

-    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
+    async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
         """Get the metadata for a local piece of media

         Returns:
             None if the media_id doesn't exist.
         """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -181,11 +200,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "created_ts",
                 "quarantined_by",
                 "url_cache",
+                "last_access_ts",
                 "safe_from_quarantine",
             ),
             allow_none=True,
             desc="get_local_media",
         )
+        if row is None:
+            return None
+        return LocalMedia(
+            media_id=media_id,
+            media_type=row[0],
+            media_length=row[1],
+            upload_name=row[2],
+            created_ts=row[3],
+            quarantined_by=row[4],
+            url_cache=row[5],
+            last_access_ts=row[6],
+            safe_from_quarantine=row[7],
+        )

     async def get_local_media_by_user_paginate(
         self,
@@ -236,6 +269,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 media_length,
                 upload_name,
                 created_ts,
+                url_cache,
                 last_access_ts,
                 quarantined_by,
                 safe_from_quarantine
@@ -257,9 +291,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     media_length=row[2],
                     upload_name=row[3],
                     created_ts=row[4],
-                    last_access_ts=row[5],
-                    quarantined_by=row[6],
-                    safe_from_quarantine=bool(row[7]),
+                    url_cache=row[5],
+                    last_access_ts=row[6],
+                    quarantined_by=row[7],
+                    safe_from_quarantine=bool(row[8]),
                 )
                 for row in txn
             ]
@@ -390,51 +425,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="mark_local_media_as_safe",
         )

-    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
+    async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
             None if the URL isn't cached.
         """

-        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
             # get the most recently cached result (relative to the given ts)
-            sql = (
-                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                " FROM local_media_repository_url_cache"
-                " WHERE url = ? AND download_ts <= ?"
-                " ORDER BY download_ts DESC LIMIT 1"
-            )
+            sql = """
+                SELECT response_code, expires_ts, og
+                FROM local_media_repository_url_cache
+                WHERE url = ? AND download_ts <= ?
+                ORDER BY download_ts DESC LIMIT 1
+            """
             txn.execute(sql, (url, ts))
             row = txn.fetchone()

             if not row:
                 # ...or if we've requested a timestamp older than the oldest
                 # copy in the cache, return the oldest copy (if any)
-                sql = (
-                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                    " FROM local_media_repository_url_cache"
-                    " WHERE url = ? AND download_ts > ?"
-                    " ORDER BY download_ts ASC LIMIT 1"
-                )
+                sql = """
+                    SELECT response_code, expires_ts, og
+                    FROM local_media_repository_url_cache
+                    WHERE url = ? AND download_ts > ?
+                    ORDER BY download_ts ASC LIMIT 1
+                """
                 txn.execute(sql, (url, ts))
                 row = txn.fetchone()

             if not row:
                 return None

-            return dict(
-                zip(
-                    (
-                        "response_code",
-                        "etag",
-                        "expires_ts",
-                        "og",
-                        "media_id",
-                        "download_ts",
-                    ),
-                    row,
-                )
-            )
+            return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])

         return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
@@ -444,7 +467,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         response_code: int,
         etag: Optional[str],
         expires_ts: int,
-        og: Optional[str],
+        og: str,
         media_id: str,
         download_ts: int,
     ) -> None:
@@ -510,8 +533,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):

     async def get_cached_remote_media(
         self, origin: str, media_id: str
-    ) -> Optional[Dict[str, Any]]:
-        return await self.db_pool.simple_select_one(
+    ) -> Optional[RemoteMedia]:
+        row = await self.db_pool.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -520,11 +543,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "upload_name",
                 "created_ts",
                 "filesystem_id",
+                "last_access_ts",
                 "quarantined_by",
             ),
             allow_none=True,
             desc="get_cached_remote_media",
         )
+        if row is None:
+            return row
+        return RemoteMedia(
+            media_origin=origin,
+            media_id=media_id,
+            media_type=row[0],
+            media_length=row[1],
+            upload_name=row[2],
+            created_ts=row[3],
+            filesystem_id=row[4],
+            last_access_ts=row[5],
+            quarantined_by=row[6],
+        )

     async def store_cached_remote_media(
         self,
@@ -623,10 +660,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         t_width: int,
         t_height: int,
         t_type: str,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[ThumbnailInfo]:
         """Fetch the thumbnail info of given width, height and type."""

-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             table="remote_media_cache_thumbnails",
             keyvalues={
                 "media_origin": origin,
@@ -641,11 +678,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "thumbnail_method",
                 "thumbnail_type",
                 "thumbnail_length",
-                "filesystem_id",
             ),
             allow_none=True,
             desc="get_remote_media_thumbnail",
         )
+        if row is None:
+            return None
+        return ThumbnailInfo(
+            width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
+        )

     @trace
     async def store_remote_media_thumbnail(

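These frozen, slotted attrs classes are what replace the dict rows throughout this commit. Beyond documentation value, they fail loudly where dicts fail silently. A toy illustration (the row value is made up):

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class UrlCache:
        response_code: int
        expires_ts: int
        og: str

    row = (200, 1_700_000_000_000, "<html>...</html>")  # positional DB row
    cache = UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])

    try:
        cache.respones_code  # note the typo
    except AttributeError:
        pass  # slots=True: misspelled attributes raise instead of returning None

    try:
        cache.og = "tampered"  # type: ignore[misc]
    except attr.exceptions.FrozenInstanceError:
        pass  # frozen=True: rows are immutable once hydrated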
----

@@ -13,7 +13,6 @@
 # limitations under the License.
 from typing import TYPE_CHECKING, Optional

-from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore):
         return 50

     async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
-        try:
-            profile = await self.db_pool.simple_select_one(
-                table="profiles",
-                keyvalues={"full_user_id": user_id.to_string()},
-                retcols=("displayname", "avatar_url"),
-                desc="get_profileinfo",
-            )
-        except StoreError as e:
-            if e.code == 404:
-                # no match
-                return ProfileInfo(None, None)
-            else:
-                raise
-
-        return ProfileInfo(
-            avatar_url=profile["avatar_url"], display_name=profile["displayname"]
+        profile = await self.db_pool.simple_select_one(
+            table="profiles",
+            keyvalues={"full_user_id": user_id.to_string()},
+            retcols=("displayname", "avatar_url"),
+            desc="get_profileinfo",
+            allow_none=True,
         )
+        if profile is None:
+            # no match
+            return ProfileInfo(None, None)
+
+        return ProfileInfo(avatar_url=profile[1], display_name=profile[0])

     async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
         return await self.db_pool.simple_select_one_onecol(

----

@@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 "before/after rule not found: %s" % (relative_to_rule,)
             )

-        base_priority_class = res["priority_class"]
-        base_rule_priority = res["priority"]
+        base_priority_class, base_rule_priority = res

         if base_priority_class != priority_class:
             raise InconsistentRuleException(

----

@@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
             allow_none=True,
         )

-        stream_ordering = int(res["stream_ordering"]) if res else None
-        rx_ts = res["received_ts"] if res else 0
+        stream_ordering = int(res[0]) if res else None
+        rx_ts = res[1] if res else 0

         # We don't want to clobber receipts for more recent events, so we
         # have to compare orderings of existing receipts

----

@@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             account timestamp as milliseconds since the epoch. None if the account
             has not been renewed using the current token yet.
         """
-        ret_dict = await self.db_pool.simple_select_one(
-            table="account_validity",
-            keyvalues={"renewal_token": renewal_token},
-            retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
-            desc="get_user_from_renewal_token",
-        )
-
-        return (
-            ret_dict["user_id"],
-            ret_dict["expiration_ts_ms"],
-            ret_dict["token_used_ts_ms"],
+        return cast(
+            Tuple[str, int, Optional[int]],
+            await self.db_pool.simple_select_one(
+                table="account_validity",
+                keyvalues={"renewal_token": renewal_token},
+                retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
+                desc="get_user_from_renewal_token",
+            ),
         )

     async def get_renewal_token_for_user(self, user_id: str) -> str:
@@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         Returns:
             user id, or None if no user id/threepid mapping exists
         """
-        ret = self.db_pool.simple_select_one_txn(
+        return self.db_pool.simple_select_one_onecol_txn(
             txn,
             "user_threepids",
             {"medium": medium, "address": address},
-            ["user_id"],
+            "user_id",
             True,
         )
-        if ret:
-            return ret["user_id"]
-        return None

     async def user_add_threepid(
         self,
@@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             if res is None:
                 return False

+            uses_allowed, pending, completed, expiry_time = res
+
             # Check if the token has expired
             now = self._clock.time_msec()
-            if res["expiry_time"] and res["expiry_time"] < now:
+            if expiry_time and expiry_time < now:
                 return False

             # Check if the token has been used up
-            if (
-                res["uses_allowed"]
-                and res["pending"] + res["completed"] >= res["uses_allowed"]
-            ):
+            if uses_allowed and pending + completed >= uses_allowed:
                 return False

             # Otherwise, the token is valid
@@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             # Override type because the return type is only optional if
             # allow_none is True, and we don't want mypy throwing errors
             # about None not being indexable.
-            res = cast(
-                Dict[str, Any],
+            pending, completed = cast(
+                Tuple[int, int],
                 self.db_pool.simple_select_one_txn(
                     txn,
                     "registration_tokens",
@@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 "registration_tokens",
                 keyvalues={"token": token},
                 updatevalues={
-                    "completed": res["completed"] + 1,
-                    "pending": res["pending"] - 1,
+                    "completed": completed + 1,
+                    "pending": pending - 1,
                 },
             )
@@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         Returns:
             A dict, or None if token doesn't exist.
         """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             "registration_tokens",
             keyvalues={"token": token},
             retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
             allow_none=True,
             desc="get_one_registration_token",
         )
+        if row is None:
+            return None
+        return {
+            "token": row[0],
+            "uses_allowed": row[1],
+            "pending": row[2],
+            "completed": row[3],
+            "expiry_time": row[4],
+        }

     async def generate_registration_token(
         self, length: int, chars: str
@@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 return None

             # Get all info about the token so it can be sent in the response
-            return self.db_pool.simple_select_one_txn(
+            result = self.db_pool.simple_select_one_txn(
                 txn,
                 "registration_tokens",
                 keyvalues={"token": token},
@@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 allow_none=True,
             )

+            if result is None:
+                return result
+
+            return {
+                "token": result[0],
+                "uses_allowed": result[1],
+                "pending": result[2],
+                "completed": result[3],
+                "expiry_time": result[4],
+            }
+
         return await self.db_pool.runInteraction(
             "update_registration_token", _update_registration_token_txn
         )
@@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 keyvalues={"token": token},
                 updatevalues={"used_ts": ts},
             )
-            user_id = values["user_id"]
-            expiry_ts = values["expiry_ts"]
-            used_ts = values["used_ts"]
-            auth_provider_id = values["auth_provider_id"]
-            auth_provider_session_id = values["auth_provider_session_id"]
+            (
+                user_id,
+                expiry_ts,
+                used_ts,
+                auth_provider_id,
+                auth_provider_session_id,
+            ) = values

             # Token was already used
             if used_ts is not None:
@@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 # reason, the next check is on the client secret, which is NOT NULL,
                 # so we don't have to worry about the client secret matching by
                 # accident.
-                row = {"client_secret": None, "validated_at": None}
+                row = None, None
             else:
                 raise ThreepidValidationError("Unknown session_id")

-            retrieved_client_secret = row["client_secret"]
-            validated_at = row["validated_at"]
+            retrieved_client_secret, validated_at = row

             row = self.db_pool.simple_select_one_txn(
                 txn,
@@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 raise ThreepidValidationError(
                     "Validation token not found or has expired"
                 )

-            expires = row["expires"]
-            next_link = row["next_link"]
+            expires, next_link = row

             if retrieved_client_secret != client_secret:
                 raise ThreepidValidationError(

----

@@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")

-    async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
+    async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
         """Retrieve a room.

         Args:
             room_id: The ID of the room to retrieve.
         Returns:
-            A dict containing the room information, or None if the room is unknown.
+            A tuple containing the room information:
+                * True if the room is public
+                * True if the room has an auth chain index
+            or None if the room is unknown.
         """
-        return await self.db_pool.simple_select_one(
-            table="rooms",
-            keyvalues={"room_id": room_id},
-            retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
-            desc="get_room",
-            allow_none=True,
+        row = cast(
+            Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
+            await self.db_pool.simple_select_one(
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                retcols=("is_public", "has_auth_chain_index"),
+                desc="get_room",
+                allow_none=True,
+            ),
         )
+        if row is None:
+            return row
+        return bool(row[0]), bool(row[1])

     async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
         """Retrieve room with statistics.
@@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         )

         if row:
-            return RatelimitOverride(
-                messages_per_second=row["messages_per_second"],
-                burst_count=row["burst_count"],
-            )
+            return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
         else:
             return None
@@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             join.
         """

-        result = await self.db_pool.simple_select_one(
-            table="partial_state_rooms",
-            keyvalues={"room_id": room_id},
-            retcols=("join_event_id", "device_lists_stream_id"),
-            desc="get_join_event_id_for_partial_state",
+        return cast(
+            Tuple[str, int],
+            await self.db_pool.simple_select_one(
+                table="partial_state_rooms",
+                keyvalues={"room_id": room_id},
+                retcols=("join_event_id", "device_lists_stream_id"),
+                desc="get_join_event_id_for_partial_state",
+            ),
         )
-        return result["join_event_id"], result["device_lists_stream_id"]

     def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
         return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(

----

@@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
                 "non-local user %s" % (user_id,),
             )

-        results_dict = await self.db_pool.simple_select_one(
-            "local_current_membership",
-            {"room_id": room_id, "user_id": user_id},
-            ("membership", "event_id"),
-            allow_none=True,
-            desc="get_local_current_membership_for_user_in_room",
+        results = cast(
+            Optional[Tuple[str, str]],
+            await self.db_pool.simple_select_one(
+                "local_current_membership",
+                {"room_id": room_id, "user_id": user_id},
+                ("membership", "event_id"),
+                allow_none=True,
+                desc="get_local_current_membership_for_user_in_room",
+            ),
         )
-        if not results_dict:
+        if not results:
             return None, None

-        return results_dict.get("membership"), results_dict.get("event_id")
+        return results

     @cached(max_entries=500000, iterable=True)
     async def get_rooms_for_user_with_stream_ordering(
@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_position_for_event", desc="get_position_for_event",
) )
return PersistedEventPosition( return PersistedEventPosition(row[1] or "master", row[0])
row["instance_name"] or "master", row["stream_ordering"]
)
async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event """The stream token for an event
@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"), retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event", desc="get_topological_token_for_event",
) )
return RoomStreamToken( return RoomStreamToken(topological=row[1], stream=row[0])
topological=row["topological_ordering"], stream=row["stream_ordering"]
)
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream """Gets the topological token in a room after or at the given stream
@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict dict
""" """
results = self.db_pool.simple_select_one_txn( stream_ordering, topological_ordering = cast(
txn, Tuple[int, int],
"events", self.db_pool.simple_select_one_txn(
keyvalues={"event_id": event_id, "room_id": room_id}, txn,
retcols=["stream_ordering", "topological_ordering"], "events",
keyvalues={"event_id": event_id, "room_id": room_id},
retcols=["stream_ordering", "topological_ordering"],
),
) )
# This cannot happen as `allow_none=False`.
assert results is not None
# Paginating backwards includes the event at the token, but paginating # Paginating backwards includes the event at the token, but paginating
# forward doesn't. # forward doesn't.
before_token = RoomStreamToken( before_token = RoomStreamToken(
topological=results["topological_ordering"] - 1, topological=topological_ordering - 1, stream=stream_ordering
stream=results["stream_ordering"],
) )
after_token = RoomStreamToken( after_token = RoomStreamToken(
topological=results["topological_ordering"], topological=topological_ordering, stream=stream_ordering
stream=results["stream_ordering"],
) )
rows, start_token = self._paginate_room_events_txn( rows, start_token = self._paginate_room_events_txn(
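An illustrative expansion of the token arithmetic above, with hypothetical orderings; it restates the pagination-direction comment in concrete values.

    # For an event stored at topological_ordering=10, stream_ordering=250:
    before_token = RoomStreamToken(topological=9, stream=250)
    # Backwards pagination includes the event at the token, so stepping back
    # one topological position keeps the anchor event out of the page.
    after_token = RoomStreamToken(topological=10, stream=250)
    # Forwards pagination already excludes the event at the token.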
@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: the task if available, `None` otherwise Returns: the task if available, `None` otherwise
""" """
row = await self.db_pool.simple_select_one( row = cast(
table="scheduled_tasks", Optional[ScheduledTaskRow],
keyvalues={"id": id}, await self.db_pool.simple_select_one(
retcols=( table="scheduled_tasks",
"id", keyvalues={"id": id},
"action", retcols=(
"status", "id",
"timestamp", "action",
"resource_id", "status",
"params", "timestamp",
"result", "resource_id",
"error", "params",
"result",
"error",
),
allow_none=True,
desc="get_scheduled_task",
), ),
allow_none=True,
desc="get_scheduled_task",
) )
return ( return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
TaskSchedulerWorkerStore._convert_row_to_task(
(
row["id"],
row["action"],
row["status"],
row["timestamp"],
row["resource_id"],
row["params"],
row["result"],
row["error"],
)
)
if row
else None
)
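`ScheduledTaskRow` is presumably a tuple alias matching the `retcols` order above; a sketch of what such an alias would look like (the exact element types are assumptions):

    ScheduledTaskRow = Tuple[
        str,            # id
        str,            # action
        str,            # status
        int,            # timestamp
        Optional[str],  # resource_id
        Optional[str],  # params (JSON)
        Optional[str],  # result (JSON)
        Optional[str],  # error
    ]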
async def delete_scheduled_task(self, id: str) -> None: async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id. """Delete a specific task from its id.
@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn, txn,
table="received_transactions", table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin}, keyvalues={"transaction_id": transaction_id, "origin": origin},
retcols=( retcols=("response_code", "response_json"),
"transaction_id",
"origin",
"ts",
"response_code",
"response_json",
"has_been_referenced",
),
allow_none=True, allow_none=True,
) )
if result and result["response_code"]: # If the result exists and the response code is non-0.
return result["response_code"], db_to_json(result["response_json"]) if result and result[0]:
return result[0], db_to_json(result[1])
else: else:
return None return None
@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
# check we have a row and retry_last_ts is not null or zero # check we have a row and retry_last_ts is not null or zero
# (retry_last_ts can't be negative) # (retry_last_ts can't be negative)
if result and result["retry_last_ts"]: if result and result[1]:
return DestinationRetryTimings(**result) return DestinationRetryTimings(
failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2]
)
else: else:
return None return None
@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session", desc="get_ui_auth_session",
) )
result["clientdict"] = db_to_json(result["clientdict"]) return UIAuthSessionData(
session_id,
return UIAuthSessionData(session_id, **result) clientdict=db_to_json(result[0]),
uri=result[1],
method=result[2],
description=result[3],
)
async def mark_ui_auth_stage_complete( async def mark_ui_auth_stage_complete(
self, self,
@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore):
self, txn: LoggingTransaction, session_id: str, key: str, value: Any self, txn: LoggingTransaction, session_id: str, key: str, value: Any
) -> None: ) -> None:
# Get the current value. # Get the current value.
result = cast( result = self.db_pool.simple_select_one_onecol_txn(
Dict[str, Any], txn,
self.db_pool.simple_select_one_txn( table="ui_auth_sessions",
txn, keyvalues={"session_id": session_id},
table="ui_auth_sessions", retcol="serverdict",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
),
) )
# Update it and add it back to the database. # Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"]) serverdict = db_to_json(result)
serverdict[key] = value serverdict[key] = value
self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises: Raises:
StoreError if the session cannot be found. StoreError if the session cannot be found.
""" """
result = await self.db_pool.simple_select_one( result = await self.db_pool.simple_select_one_onecol(
table="ui_auth_sessions", table="ui_auth_sessions",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
retcols=("serverdict",), retcol="serverdict",
desc="get_ui_auth_session_data", desc="get_ui_auth_session_data",
) )
serverdict = db_to_json(result["serverdict"]) serverdict = db_to_json(result)
return serverdict.get(key, default) return serverdict.get(key, default)
@ -20,7 +20,6 @@ from typing import (
Collection, Collection,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"delete_all_from_user_dir", _delete_all_from_user_dir_txn "delete_all_from_user_dir", _delete_all_from_user_dir_txn
) )
async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: async def _get_user_in_directory(
return await self.db_pool.simple_select_one( self, user_id: str
table="user_directory", ) -> Optional[Tuple[Optional[str], Optional[str]]]:
keyvalues={"user_id": user_id}, """
retcols=("display_name", "avatar_url"), Fetch the user information in the user directory.
allow_none=True,
desc="get_user_in_directory", Returns:
None if the user is unknown, otherwise a tuple of display name and
avatar URL (both of which may be None).
"""
return cast(
Optional[Tuple[Optional[str], Optional[str]]],
await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
),
) )
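A hedged caller sketch for the new tuple shape (the store handle and user ID are hypothetical):

    profile = await store._get_user_in_directory("@alice:example.com")
    if profile is not None:
        display_name, avatar_url = profile  # either element may be None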
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None: async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
next_id = self._load_next_id_txn(txn) next_id = self._load_next_id_txn(txn)
txn.call_after(self._mark_id_as_finished, next_id) txn.call_after(self._mark_ids_as_finished, [next_id])
txn.call_on_exception(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_ids_as_finished, [next_id])
txn.call_after(self._notifier.notify_replication) txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream # Update the `stream_positions` table with newly updated stream
@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return self._return_factor * next_id return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int) -> None: def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
"""The ID has finished being processed so we should advance the """
Usage:
stream_ids = stream_id_gen.get_next_mult_txn(txn, n)
# ... persist n events ...
"""
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
raise Exception("Tried to allocate stream ID on non-writer")
next_ids = self._load_next_mult_id_txn(txn, n)
txn.call_after(self._mark_ids_as_finished, next_ids)
txn.call_on_exception(self._mark_ids_as_finished, next_ids)
txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't
# bother, as nothing will read it).
#
# We only do this on the success path so that the persisted current
# position points to a persisted row with the correct instance name.
if self._writers:
txn.call_after(
run_as_background_process,
"MultiWriterIdGenerator._update_table",
self._db.runInteraction,
"MultiWriterIdGenerator._update_table",
self._update_stream_positions_table_txn,
)
return [self._return_factor * next_id for next_id in next_ids]
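A usage sketch for the batch allocator; the persist helper and its arguments are hypothetical, but the call shape follows the docstring above:

    def _persist_events_txn(
        txn: LoggingTransaction,
        id_gen: MultiWriterIdGenerator,
        events: List[JsonDict],
    ) -> None:
        # Allocate one stream ID per event within a single transaction.
        stream_ids = id_gen.get_next_mult_txn(txn, len(events))
        for stream_id, event in zip(stream_ids, events):
            ...  # write each row with its allocated stream ID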
def _mark_ids_as_finished(self, next_ids: List[int]) -> None:
"""These IDs have finished being processed so we should advance the
current position if possible. current position if possible.
""" """
with self._lock: with self._lock:
self._unfinished_ids.discard(next_id) self._unfinished_ids.difference_update(next_ids)
self._finished_ids.add(next_id) self._finished_ids.update(next_ids)
new_cur: Optional[int] = None new_cur: Optional[int] = None
@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
curr, new_cur, self._max_position_of_local_instance curr, new_cur, self._max_position_of_local_instance
) )
self._add_persisted_position(next_id) # TODO Can we call this for just the last position or somehow batch
# _add_persisted_position.
for next_id in next_ids:
self._add_persisted_position(next_id)
def get_current_token(self) -> int: def get_current_token(self) -> int:
return self.get_persisted_upto_position() return self.get_persisted_upto_position()
@ -933,8 +972,7 @@ class _MultiWriterCtxManager:
exc: Optional[BaseException], exc: Optional[BaseException],
tb: Optional[TracebackType], tb: Optional[TracebackType],
) -> bool: ) -> bool:
for i in self.stream_ids: self.id_gen._mark_ids_as_finished(self.stream_ids)
self.id_gen._mark_id_as_finished(i)
self.notifier.notify_replication() self.notifier.notify_replication()
@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
return self.get_success( row = self.get_success(
self.store.db_pool.simple_select_one( self.store.db_pool.simple_select_one(
table + "_current", table + "_current",
{id_col: stat_id}, {id_col: stat_id},
@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
) )
) )
return None if row is None else dict(zip(cols, row))
def _perform_background_initial_update(self) -> None: def _perform_background_initial_update(self) -> None:
# Do the initial population of the stats via the background update # Do the initial population of the stats via the background update
self._add_background_updates() self._add_background_updates()
@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
) )
profile = self.get_success(self.store._get_user_in_directory(regular_user_id)) profile = self.get_success(self.store._get_user_in_directory(regular_user_id))
assert profile is not None assert profile is not None
self.assertTrue(profile["display_name"] == display_name) self.assertTrue(profile[0] == display_name)
def test_handle_local_profile_change_with_deactivated_user(self) -> None: def test_handle_local_profile_change_with_deactivated_user(self) -> None:
# create user # create user
@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is in directory # profile is in directory
profile = self.get_success(self.store._get_user_in_directory(r_user_id)) profile = self.get_success(self.store._get_user_in_directory(r_user_id))
assert profile is not None assert profile is not None
self.assertTrue(profile["display_name"] == display_name) self.assertEqual(profile[0], display_name)
# deactivate user # deactivate user
self.get_success(self.store.set_user_deactivated_status(r_user_id, True)) self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
origin, media_id = self.media_id.split("/") origin, media_id = self.media_id.split("/")
info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
assert info is not None assert info is not None
file_id = info["filesystem_id"] file_id = info.filesystem_id
thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
origin, file_id origin, file_id
@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertFalse(media_info["quarantined_by"]) self.assertFalse(media_info.quarantined_by)
# quarantining # quarantining
channel = self.make_request( channel = self.make_request(
@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertTrue(media_info["quarantined_by"]) self.assertTrue(media_info.quarantined_by)
# remove from quarantine # remove from quarantine
channel = self.make_request( channel = self.make_request(
@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertFalse(media_info["quarantined_by"]) self.assertFalse(media_info.quarantined_by)
def test_quarantine_protected_media(self) -> None: def test_quarantine_protected_media(self) -> None:
""" """
@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
# verify protection # verify protection
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"]) self.assertTrue(media_info.safe_from_quarantine)
# quarantining # quarantining
channel = self.make_request( channel = self.make_request(
@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
# verify that is not in quarantine # verify that is not in quarantine
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertFalse(media_info["quarantined_by"]) self.assertFalse(media_info.quarantined_by)
class ProtectMediaByIDTestCase(_AdminMediaTests): class ProtectMediaByIDTestCase(_AdminMediaTests):
@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"]) self.assertFalse(media_info.safe_from_quarantine)
# protect # protect
channel = self.make_request( channel = self.make_request(
@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"]) self.assertTrue(media_info.safe_from_quarantine)
# unprotect # unprotect
channel = self.make_request( channel = self.make_request(
@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"]) self.assertFalse(media_info.safe_from_quarantine)

@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# is in user directory # is in user directory
profile = self.get_success(self.store._get_user_in_directory(self.other_user)) profile = self.get_success(self.store._get_user_in_directory(self.other_user))
assert profile is not None assert profile is not None
self.assertTrue(profile["display_name"] == "User") self.assertEqual(profile[0], "User")
# Deactivate user # Deactivate user
channel = self.make_request( channel = self.make_request(
@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# #
# Note that we don't have the UI Auth session ID, so just pull out the single # Note that we don't have the UI Auth session ID, so just pull out the single
# row. # row.
ui_auth_data = self.get_success( result = self.get_success(
self.store.db_pool.simple_select_one( self.store.db_pool.simple_select_one_onecol(
"ui_auth_sessions", keyvalues={}, retcols=("clientdict",) "ui_auth_sessions", keyvalues={}, retcol="clientdict"
) )
) )
client_dict = db_to_json(ui_auth_data["clientdict"]) client_dict = db_to_json(result)
self.assertNotIn("new_password", client_dict) self.assertNotIn("new_password", client_dict)
@override_config({"rc_3pid_validation": {"burst_count": 3}}) @override_config({"rc_3pid_validation": {"burst_count": 3}})
@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertLessEqual(det_data.items(), channel.json_body.items()) self.assertLessEqual(det_data.items(), channel.json_body.items())
# Check the `completed` counter has been incremented and pending is 0 # Check the `completed` counter has been incremented and pending is 0
res = self.get_success( pending, completed = self.get_success(
store.db_pool.simple_select_one( store.db_pool.simple_select_one(
"registration_tokens", "registration_tokens",
keyvalues={"token": token}, keyvalues={"token": token},
retcols=["pending", "completed"], retcols=["pending", "completed"],
) )
) )
self.assertEqual(res["completed"], 1) self.assertEqual(completed, 1)
self.assertEqual(res["pending"], 0) self.assertEqual(pending, 0)
@override_config({"registration_requires_token": True}) @override_config({"registration_requires_token": True})
def test_POST_registration_token_invalid(self) -> None: def test_POST_registration_token_invalid(self) -> None:
@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
params1["auth"]["type"] = LoginType.DUMMY params1["auth"]["type"] = LoginType.DUMMY
self.make_request(b"POST", self.url, params1) self.make_request(b"POST", self.url, params1)
# Check pending=0 and completed=1 # Check pending=0 and completed=1
res = self.get_success( pending, completed = self.get_success(
store.db_pool.simple_select_one( store.db_pool.simple_select_one(
"registration_tokens", "registration_tokens",
keyvalues={"token": token}, keyvalues={"token": token},
retcols=["pending", "completed"], retcols=["pending", "completed"],
) )
) )
self.assertEqual(res["pending"], 0) self.assertEqual(pending, 0)
self.assertEqual(res["completed"], 1) self.assertEqual(completed, 1)
# Check auth still fails when using token with session2 # Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, params2) channel = self.make_request(b"POST", self.url, params2)
@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None: def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
"""Given an MXC URI, assert whether it has been purged or not.""" """Given an MXC URI, assert whether it has been purged or not."""
if mxc_uri.server_name == self.hs.config.server.server_name: if mxc_uri.server_name == self.hs.config.server.server_name:
found_media_dict = self.get_success( found_media = bool(
self.store.get_local_media(mxc_uri.media_id) self.get_success(self.store.get_local_media(mxc_uri.media_id))
) )
else: else:
found_media_dict = self.get_success( found_media = bool(
self.store.get_cached_remote_media( self.get_success(
mxc_uri.server_name, mxc_uri.media_id self.store.get_cached_remote_media(
mxc_uri.server_name, mxc_uri.media_id
)
) )
) )
if expect_purged: if expect_purged:
self.assertIsNone( self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged")
found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
)
else: else:
self.assertIsNotNone( self.assertTrue(
found_media_dict, found_media,
msg=f"{mxc_uri} unexpectedly purged", msg=f"{mxc_uri} unexpectedly purged",
) )
@ -0,0 +1,117 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock, call
from synapse.storage.database import LoggingTransaction
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.unittest import HomeserverTestCase
class CacheInvalidationTestCase(HomeserverTestCase):
def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main
def test_bulk_invalidation(self) -> None:
master_invalidate = Mock()
self.store._get_cached_user_device.invalidate = master_invalidate
keys_to_invalidate = [
("a", "b"),
("c", "d"),
("e", "f"),
("g", "h"),
]
def test_txn(txn: LoggingTransaction) -> None:
self.store._invalidate_cache_and_stream_bulk(
txn,
# This is an arbitrarily chosen cached store function. It was chosen
# because it takes more than one argument. We'll use this later to
# check that the invalidation was actioned over replication.
cache_func=self.store._get_cached_user_device,
key_tuples=keys_to_invalidate,
)
self.get_success(
self.store.db_pool.runInteraction(
"test_invalidate_cache_and_stream_bulk", test_txn
)
)
master_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)
class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase):
def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main
def test_bulk_invalidation_replicates(self) -> None:
"""Like test_bulk_invalidation, but also checks the invalidations replicate."""
master_invalidate = Mock()
worker_invalidate = Mock()
self.store._get_cached_user_device.invalidate = master_invalidate
worker = self.make_worker_hs("synapse.app.generic_worker")
worker_ds = worker.get_datastores().main
worker_ds._get_cached_user_device.invalidate = worker_invalidate
keys_to_invalidate = [
("a", "b"),
("c", "d"),
("e", "f"),
("g", "h"),
]
def test_txn(txn: LoggingTransaction) -> None:
self.store._invalidate_cache_and_stream_bulk(
txn,
# This is an arbitrarily chosen cached store function. It was chosen
# because it takes more than one argument. We'll use this later to
# check that the invalidation was actioned over replication.
cache_func=self.store._get_cached_user_device,
key_tuples=keys_to_invalidate,
)
assert self.store._cache_id_gen is not None
initial_token = self.store._cache_id_gen.get_current_token()
self.get_success(
self.database_pool.runInteraction(
"test_invalidate_cache_and_stream_bulk", test_txn
)
)
second_token = self.store._cache_id_gen.get_current_token()
self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate))
self.get_success(
worker.get_replication_data_handler().wait_for_stream_position(
"master", "caches", second_token
)
)
master_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)
worker_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)
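A hedged sketch of how a store method might call the bulk helper, with the signature inferred from the tests above (the method name and key list are hypothetical; `_invalidate_cache_and_stream_bulk` itself is not shown in this diff):

    def _forget_devices_txn(
        self, txn: LoggingTransaction, keys: List[Tuple[str, str]]
    ) -> None:
        self._invalidate_cache_and_stream_bulk(
            txn,
            cache_func=self._get_cached_user_device,
            key_tuples=keys,
        )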
@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
) )
self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) self.assertEqual((1, 2, 3), ret)
self.mock_txn.execute.assert_called_once_with( self.mock_txn.execute.assert_called_once_with(
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
) )
@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
) )
self.assertFalse(ret) self.assertIsNone(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
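The call-site migration these tests exercise, in miniature (the table and columns are the test fixtures' own):

    row = await self.db_pool.simple_select_one(
        table="tablename",
        keyvalues={"keycol": "TheKey"},
        retcols=("colA", "colB", "colC"),
        allow_none=True,
    )
    if row is not None:
        colA, colB, colC = row  # previously: row["colA"], row["colB"], row["colC"]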
@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase):
) )
def test_get_room(self) -> None: def test_get_room(self) -> None:
res = self.get_success(self.store.get_room(self.room.to_string())) room = self.get_success(self.store.get_room(self.room.to_string()))
assert res is not None assert room is not None
self.assertLessEqual( self.assertTrue(room[0])
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True,
}.items(),
res.items(),
)
def test_get_room_unknown_room(self) -> None: def test_get_room_unknown_room(self) -> None:
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))