Convert calls of async database methods to async (#8166)
							parent
							
								
									c9fa696ea2
								
							
						
					
					
						commit
						9b7ac03af3
					
				|  | @ -0,0 +1 @@ | |||
| Convert various parts of the codebase to async/await. | ||||
|  | @ -21,7 +21,9 @@ These actions are mostly only used by the :py:mod:`.replication` module. | |||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| from synapse.federation.units import Transaction | ||||
| from synapse.logging.utils import log_function | ||||
| from synapse.types import JsonDict | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|  | @ -49,15 +51,15 @@ class TransactionActions(object): | |||
|         return self.store.get_received_txn_response(transaction.transaction_id, origin) | ||||
| 
 | ||||
|     @log_function | ||||
|     def set_response(self, origin, transaction, code, response): | ||||
|     async def set_response( | ||||
|         self, origin: str, transaction: Transaction, code: int, response: JsonDict | ||||
|     ) -> None: | ||||
|         """ Persist how we responded to a transaction. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred | ||||
|         """ | ||||
|         if not transaction.transaction_id: | ||||
|         transaction_id = transaction.transaction_id  # type: ignore | ||||
|         if not transaction_id: | ||||
|             raise RuntimeError("Cannot persist a transaction with no transaction_id") | ||||
| 
 | ||||
|         return self.store.set_received_txn_response( | ||||
|             transaction.transaction_id, origin, code, response | ||||
|         await self.store.set_received_txn_response( | ||||
|             transaction_id, origin, code, response | ||||
|         ) | ||||
|  |  | |||
|  | @ -107,9 +107,7 @@ class Transaction(JsonEncodedObject): | |||
|         if "edus" in kwargs and not kwargs["edus"]: | ||||
|             del kwargs["edus"] | ||||
| 
 | ||||
|         super(Transaction, self).__init__( | ||||
|             transaction_id=transaction_id, pdus=pdus, **kwargs | ||||
|         ) | ||||
|         super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def create_new(pdus, **kwargs): | ||||
|  |  | |||
|  | @ -161,16 +161,14 @@ class ApplicationServiceTransactionWorkerStore( | |||
|             return result.get("state") | ||||
|         return None | ||||
| 
 | ||||
|     def set_appservice_state(self, service, state): | ||||
|     async def set_appservice_state(self, service, state) -> None: | ||||
|         """Set the application service state. | ||||
| 
 | ||||
|         Args: | ||||
|             service(ApplicationService): The service whose state to set. | ||||
|             state(ApplicationServiceState): The connectivity state to apply. | ||||
|         Returns: | ||||
|             An Awaitable which resolves when the state was set successfully. | ||||
|         """ | ||||
|         return self.db_pool.simple_upsert( | ||||
|         await self.db_pool.simple_upsert( | ||||
|             "application_services_state", {"as_id": service.id}, {"state": state} | ||||
|         ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -716,11 +716,11 @@ class DeviceWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|         return {row["user_id"] for row in rows} | ||||
| 
 | ||||
|     def mark_remote_user_device_cache_as_stale(self, user_id: str): | ||||
|     async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: | ||||
|         """Records that the server has reason to believe the cache of the devices | ||||
|         for the remote users is out of date. | ||||
|         """ | ||||
|         return self.db_pool.simple_upsert( | ||||
|         await self.db_pool.simple_upsert( | ||||
|             table="device_lists_remote_resync", | ||||
|             keyvalues={"user_id": user_id}, | ||||
|             values={}, | ||||
|  |  | |||
|  | @ -742,7 +742,13 @@ class GroupServerStore(GroupServerWorkerStore): | |||
|             desc="remove_room_from_summary", | ||||
|         ) | ||||
| 
 | ||||
|     def upsert_group_category(self, group_id, category_id, profile, is_public): | ||||
|     async def upsert_group_category( | ||||
|         self, | ||||
|         group_id: str, | ||||
|         category_id: str, | ||||
|         profile: Optional[JsonDict], | ||||
|         is_public: Optional[bool], | ||||
|     ) -> None: | ||||
|         """Add/update room category for group | ||||
|         """ | ||||
|         insertion_values = {} | ||||
|  | @ -758,7 +764,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
|         else: | ||||
|             update_values["is_public"] = is_public | ||||
| 
 | ||||
|         return self.db_pool.simple_upsert( | ||||
|         await self.db_pool.simple_upsert( | ||||
|             table="group_room_categories", | ||||
|             keyvalues={"group_id": group_id, "category_id": category_id}, | ||||
|             values=update_values, | ||||
|  | @ -773,7 +779,13 @@ class GroupServerStore(GroupServerWorkerStore): | |||
|             desc="remove_group_category", | ||||
|         ) | ||||
| 
 | ||||
|     def upsert_group_role(self, group_id, role_id, profile, is_public): | ||||
|     async def upsert_group_role( | ||||
|         self, | ||||
|         group_id: str, | ||||
|         role_id: str, | ||||
|         profile: Optional[JsonDict], | ||||
|         is_public: Optional[bool], | ||||
|     ) -> None: | ||||
|         """Add/remove user role | ||||
|         """ | ||||
|         insertion_values = {} | ||||
|  | @ -789,7 +801,7 @@ class GroupServerStore(GroupServerWorkerStore): | |||
|         else: | ||||
|             update_values["is_public"] = is_public | ||||
| 
 | ||||
|         return self.db_pool.simple_upsert( | ||||
|         await self.db_pool.simple_upsert( | ||||
|             table="group_roles", | ||||
|             keyvalues={"group_id": group_id, "role_id": role_id}, | ||||
|             values=update_values, | ||||
|  | @ -938,10 +950,10 @@ class GroupServerStore(GroupServerWorkerStore): | |||
|             desc="remove_user_from_summary", | ||||
|         ) | ||||
| 
 | ||||
|     def add_group_invite(self, group_id, user_id): | ||||
|     async def add_group_invite(self, group_id: str, user_id: str) -> None: | ||||
|         """Record that the group server has invited a user | ||||
|         """ | ||||
|         return self.db_pool.simple_insert( | ||||
|         await self.db_pool.simple_insert( | ||||
|             table="group_invites", | ||||
|             values={"group_id": group_id, "user_id": user_id}, | ||||
|             desc="add_group_invite", | ||||
|  | @ -1044,8 +1056,10 @@ class GroupServerStore(GroupServerWorkerStore): | |||
|             "remove_user_from_group", _remove_user_from_group_txn | ||||
|         ) | ||||
| 
 | ||||
|     def add_room_to_group(self, group_id, room_id, is_public): | ||||
|         return self.db_pool.simple_insert( | ||||
|     async def add_room_to_group( | ||||
|         self, group_id: str, room_id: str, is_public: bool | ||||
|     ) -> None: | ||||
|         await self.db_pool.simple_insert( | ||||
|             table="group_rooms", | ||||
|             values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, | ||||
|             desc="add_room_to_group", | ||||
|  |  | |||
|  | @ -140,22 +140,28 @@ class KeyStore(SQLBaseStore): | |||
|         for i in invalidations: | ||||
|             invalidate((i,)) | ||||
| 
 | ||||
|     def store_server_keys_json( | ||||
|         self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes | ||||
|     ): | ||||
|     async def store_server_keys_json( | ||||
|         self, | ||||
|         server_name: str, | ||||
|         key_id: str, | ||||
|         from_server: str, | ||||
|         ts_now_ms: int, | ||||
|         ts_expires_ms: int, | ||||
|         key_json_bytes: bytes, | ||||
|     ) -> None: | ||||
|         """Stores the JSON bytes for a set of keys from a server | ||||
|         The JSON should be signed by the originating server, the intermediate | ||||
|         server, and by this server. Updates the value for the | ||||
|         (server_name, key_id, from_server) triplet if one already existed. | ||||
|         Args: | ||||
|             server_name (str): The name of the server. | ||||
|             key_id (str): The identifer of the key this JSON is for. | ||||
|             from_server (str): The server this JSON was fetched from. | ||||
|             ts_now_ms (int): The time now in milliseconds. | ||||
|             ts_valid_until_ms (int): The time when this json stops being valid. | ||||
|             key_json (bytes): The encoded JSON. | ||||
|             server_name: The name of the server. | ||||
|             key_id: The identifer of the key this JSON is for. | ||||
|             from_server: The server this JSON was fetched from. | ||||
|             ts_now_ms: The time now in milliseconds. | ||||
|             ts_valid_until_ms: The time when this json stops being valid. | ||||
|             key_json_bytes: The encoded JSON. | ||||
|         """ | ||||
|         return self.db_pool.simple_upsert( | ||||
|         await self.db_pool.simple_upsert( | ||||
|             table="server_keys_json", | ||||
|             keyvalues={ | ||||
|                 "server_name": server_name, | ||||
|  |  | |||
|  | @ -60,7 +60,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
|             desc="get_local_media", | ||||
|         ) | ||||
| 
 | ||||
|     def store_local_media( | ||||
|     async def store_local_media( | ||||
|         self, | ||||
|         media_id, | ||||
|         media_type, | ||||
|  | @ -69,8 +69,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
|         media_length, | ||||
|         user_id, | ||||
|         url_cache=None, | ||||
|     ): | ||||
|         return self.db_pool.simple_insert( | ||||
|     ) -> None: | ||||
|         await self.db_pool.simple_insert( | ||||
|             "local_media_repository", | ||||
|             { | ||||
|                 "media_id": media_id, | ||||
|  | @ -141,10 +141,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
| 
 | ||||
|         return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) | ||||
| 
 | ||||
|     def store_url_cache( | ||||
|     async def store_url_cache( | ||||
|         self, url, response_code, etag, expires_ts, og, media_id, download_ts | ||||
|     ): | ||||
|         return self.db_pool.simple_insert( | ||||
|         await self.db_pool.simple_insert( | ||||
|             "local_media_repository_url_cache", | ||||
|             { | ||||
|                 "url": url, | ||||
|  | @ -172,7 +172,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
|             desc="get_local_media_thumbnails", | ||||
|         ) | ||||
| 
 | ||||
|     def store_local_thumbnail( | ||||
|     async def store_local_thumbnail( | ||||
|         self, | ||||
|         media_id, | ||||
|         thumbnail_width, | ||||
|  | @ -181,7 +181,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
|         thumbnail_method, | ||||
|         thumbnail_length, | ||||
|     ): | ||||
|         return self.db_pool.simple_insert( | ||||
|         await self.db_pool.simple_insert( | ||||
|             "local_media_repository_thumbnails", | ||||
|             { | ||||
|                 "media_id": media_id, | ||||
|  | @ -212,7 +212,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
|             desc="get_cached_remote_media", | ||||
|         ) | ||||
| 
 | ||||
|     def store_cached_remote_media( | ||||
|     async def store_cached_remote_media( | ||||
|         self, | ||||
|         origin, | ||||
|         media_id, | ||||
|  | @ -222,7 +222,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
|         upload_name, | ||||
|         filesystem_id, | ||||
|     ): | ||||
|         return self.db_pool.simple_insert( | ||||
|         await self.db_pool.simple_insert( | ||||
|             "remote_media_cache", | ||||
|             { | ||||
|                 "media_origin": origin, | ||||
|  | @ -288,7 +288,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
|             desc="get_remote_media_thumbnails", | ||||
|         ) | ||||
| 
 | ||||
|     def store_remote_media_thumbnail( | ||||
|     async def store_remote_media_thumbnail( | ||||
|         self, | ||||
|         origin, | ||||
|         media_id, | ||||
|  | @ -299,7 +299,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
|         thumbnail_method, | ||||
|         thumbnail_length, | ||||
|     ): | ||||
|         return self.db_pool.simple_insert( | ||||
|         await self.db_pool.simple_insert( | ||||
|             "remote_media_cache_thumbnails", | ||||
|             { | ||||
|                 "media_origin": origin, | ||||
|  |  | |||
|  | @ -2,8 +2,10 @@ from synapse.storage._base import SQLBaseStore | |||
| 
 | ||||
| 
 | ||||
| class OpenIdStore(SQLBaseStore): | ||||
|     def insert_open_id_token(self, token, ts_valid_until_ms, user_id): | ||||
|         return self.db_pool.simple_insert( | ||||
|     async def insert_open_id_token( | ||||
|         self, token: str, ts_valid_until_ms: int, user_id: str | ||||
|     ) -> None: | ||||
|         await self.db_pool.simple_insert( | ||||
|             table="open_id_tokens", | ||||
|             values={ | ||||
|                 "token": token, | ||||
|  |  | |||
|  | @ -66,8 +66,8 @@ class ProfileWorkerStore(SQLBaseStore): | |||
|             desc="get_from_remote_profile_cache", | ||||
|         ) | ||||
| 
 | ||||
|     def create_profile(self, user_localpart): | ||||
|         return self.db_pool.simple_insert( | ||||
|     async def create_profile(self, user_localpart: str) -> None: | ||||
|         await self.db_pool.simple_insert( | ||||
|             table="profiles", values={"user_id": user_localpart}, desc="create_profile" | ||||
|         ) | ||||
| 
 | ||||
|  | @ -93,13 +93,15 @@ class ProfileWorkerStore(SQLBaseStore): | |||
| 
 | ||||
| 
 | ||||
| class ProfileStore(ProfileWorkerStore): | ||||
|     def add_remote_profile_cache(self, user_id, displayname, avatar_url): | ||||
|     async def add_remote_profile_cache( | ||||
|         self, user_id: str, displayname: str, avatar_url: str | ||||
|     ) -> None: | ||||
|         """Ensure we are caching the remote user's profiles. | ||||
| 
 | ||||
|         This should only be called when `is_subscribed_remote_profile_for_user` | ||||
|         would return true for the user. | ||||
|         """ | ||||
|         return self.db_pool.simple_upsert( | ||||
|         await self.db_pool.simple_upsert( | ||||
|             table="remote_profile_cache", | ||||
|             keyvalues={"user_id": user_id}, | ||||
|             values={ | ||||
|  |  | |||
|  | @ -17,7 +17,7 @@ | |||
| 
 | ||||
| import logging | ||||
| import re | ||||
| from typing import Any, Awaitable, Dict, List, Optional | ||||
| from typing import Any, Dict, List, Optional | ||||
| 
 | ||||
| from synapse.api.constants import UserTypes | ||||
| from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError | ||||
|  | @ -549,23 +549,22 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
|             desc="user_delete_threepids", | ||||
|         ) | ||||
| 
 | ||||
|     def add_user_bound_threepid(self, user_id, medium, address, id_server): | ||||
|     async def add_user_bound_threepid( | ||||
|         self, user_id: str, medium: str, address: str, id_server: str | ||||
|     ): | ||||
|         """The server proxied a bind request to the given identity server on | ||||
|         behalf of the given user. We need to remember this in case the user | ||||
|         asks us to unbind the threepid. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id (str) | ||||
|             medium (str) | ||||
|             address (str) | ||||
|             id_server (str) | ||||
| 
 | ||||
|         Returns: | ||||
|             Awaitable | ||||
|             user_id | ||||
|             medium | ||||
|             address | ||||
|             id_server | ||||
|         """ | ||||
|         # We need to use an upsert, in case they user had already bound the | ||||
|         # threepid | ||||
|         return self.db_pool.simple_upsert( | ||||
|         await self.db_pool.simple_upsert( | ||||
|             table="user_threepid_id_server", | ||||
|             keyvalues={ | ||||
|                 "user_id": user_id, | ||||
|  | @ -1083,9 +1082,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
| 
 | ||||
|         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) | ||||
| 
 | ||||
|     def record_user_external_id( | ||||
|     async def record_user_external_id( | ||||
|         self, auth_provider: str, external_id: str, user_id: str | ||||
|     ) -> Awaitable: | ||||
|     ) -> None: | ||||
|         """Record a mapping from an external user id to a mxid | ||||
| 
 | ||||
|         Args: | ||||
|  | @ -1093,7 +1092,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
|             external_id: id on that system | ||||
|             user_id: complete mxid that it is mapped to | ||||
|         """ | ||||
|         return self.db_pool.simple_insert( | ||||
|         await self.db_pool.simple_insert( | ||||
|             table="user_external_ids", | ||||
|             values={ | ||||
|                 "auth_provider": auth_provider, | ||||
|  | @ -1237,12 +1236,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
| 
 | ||||
|         return res if res else False | ||||
| 
 | ||||
|     def add_user_pending_deactivation(self, user_id): | ||||
|     async def add_user_pending_deactivation(self, user_id: str) -> None: | ||||
|         """ | ||||
|         Adds a user to the table of users who need to be parted from all the rooms they're | ||||
|         in | ||||
|         """ | ||||
|         return self.db_pool.simple_insert( | ||||
|         await self.db_pool.simple_insert( | ||||
|             "users_pending_deactivation", | ||||
|             values={"user_id": user_id}, | ||||
|             desc="add_user_pending_deactivation", | ||||
|  |  | |||
|  | @ -27,7 +27,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions | |||
| from synapse.storage._base import SQLBaseStore, db_to_json | ||||
| from synapse.storage.database import DatabasePool, LoggingTransaction | ||||
| from synapse.storage.databases.main.search import SearchStore | ||||
| from synapse.types import ThirdPartyInstanceID | ||||
| from synapse.types import JsonDict, ThirdPartyInstanceID | ||||
| from synapse.util import json_encoder | ||||
| from synapse.util.caches.descriptors import cached | ||||
| 
 | ||||
|  | @ -1296,11 +1296,17 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
| 
 | ||||
|         return self.db_pool.runInteraction("get_rooms", f) | ||||
| 
 | ||||
|     def add_event_report( | ||||
|         self, room_id, event_id, user_id, reason, content, received_ts | ||||
|     ): | ||||
|     async def add_event_report( | ||||
|         self, | ||||
|         room_id: str, | ||||
|         event_id: str, | ||||
|         user_id: str, | ||||
|         reason: str, | ||||
|         content: JsonDict, | ||||
|         received_ts: int, | ||||
|     ) -> None: | ||||
|         next_id = self._event_reports_id_gen.get_next() | ||||
|         return self.db_pool.simple_insert( | ||||
|         await self.db_pool.simple_insert( | ||||
|             table="event_reports", | ||||
|             values={ | ||||
|                 "id": next_id, | ||||
|  |  | |||
|  | @ -16,7 +16,7 @@ | |||
| 
 | ||||
| import logging | ||||
| from itertools import chain | ||||
| from typing import Tuple | ||||
| from typing import Any, Dict, Tuple | ||||
| 
 | ||||
| from twisted.internet.defer import DeferredLock | ||||
| 
 | ||||
|  | @ -222,11 +222,11 @@ class StatsStore(StateDeltasStore): | |||
|             desc="stats_incremental_position", | ||||
|         ) | ||||
| 
 | ||||
|     def update_room_state(self, room_id, fields): | ||||
|     async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None: | ||||
|         """ | ||||
|         Args: | ||||
|             room_id (str) | ||||
|             fields (dict[str:Any]) | ||||
|             room_id | ||||
|             fields | ||||
|         """ | ||||
| 
 | ||||
|         # For whatever reason some of the fields may contain null bytes, which | ||||
|  | @ -244,7 +244,7 @@ class StatsStore(StateDeltasStore): | |||
|             if field and "\0" in field: | ||||
|                 fields[col] = None | ||||
| 
 | ||||
|         return self.db_pool.simple_upsert( | ||||
|         await self.db_pool.simple_upsert( | ||||
|             table="room_stats_state", | ||||
|             keyvalues={"room_id": room_id}, | ||||
|             values=fields, | ||||
|  |  | |||
|  | @ -21,6 +21,7 @@ from canonicaljson import encode_canonical_json | |||
| from synapse.metrics.background_process_metrics import run_as_background_process | ||||
| from synapse.storage._base import SQLBaseStore, db_to_json | ||||
| from synapse.storage.database import DatabasePool | ||||
| from synapse.types import JsonDict | ||||
| from synapse.util.caches.expiringcache import ExpiringCache | ||||
| 
 | ||||
| db_binary_type = memoryview | ||||
|  | @ -98,20 +99,21 @@ class TransactionStore(SQLBaseStore): | |||
|         else: | ||||
|             return None | ||||
| 
 | ||||
|     def set_received_txn_response(self, transaction_id, origin, code, response_dict): | ||||
|         """Persist the response we returened for an incoming transaction, and | ||||
|     async def set_received_txn_response( | ||||
|         self, transaction_id: str, origin: str, code: int, response_dict: JsonDict | ||||
|     ) -> None: | ||||
|         """Persist the response we returned for an incoming transaction, and | ||||
|         should return for subsequent transactions with the same transaction_id | ||||
|         and origin. | ||||
| 
 | ||||
|         Args: | ||||
|             txn | ||||
|             transaction_id (str) | ||||
|             origin (str) | ||||
|             code (int) | ||||
|             response_json (str) | ||||
|             transaction_id: The incoming transaction ID. | ||||
|             origin: The origin server. | ||||
|             code: The response code. | ||||
|             response_dict: The response, to be encoded into JSON. | ||||
|         """ | ||||
| 
 | ||||
|         return self.db_pool.simple_insert( | ||||
|         await self.db_pool.simple_insert( | ||||
|             table="received_transactions", | ||||
|             values={ | ||||
|                 "transaction_id": transaction_id, | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Patrick Cloke
						Patrick Cloke