Convert `event_push_actions`, `registration`, and `roommember` datastores to async (#8197)
							parent
							
								
									22b926c284
								
							
						
					
					
						commit
						d58fda99ff
					
				|  | @ -0,0 +1 @@ | |||
| Convert various parts of the codebase to async/await. | ||||
|  | @ -15,7 +15,7 @@ | |||
| # limitations under the License. | ||||
| 
 | ||||
| import logging | ||||
| from typing import List | ||||
| from typing import Dict, List, Union | ||||
| 
 | ||||
| from synapse.metrics.background_process_metrics import run_as_background_process | ||||
| from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json | ||||
|  | @ -383,19 +383,20 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
|         # Now return the first `limit` | ||||
|         return notifs[:limit] | ||||
| 
 | ||||
|     def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): | ||||
|     async def get_if_maybe_push_in_range_for_user( | ||||
|         self, user_id: str, min_stream_ordering: int | ||||
|     ) -> bool: | ||||
|         """A fast check to see if there might be something to push for the | ||||
|         user since the given stream ordering. May return false positives. | ||||
| 
 | ||||
|         Useful to know whether to bother starting a pusher on start up or not. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id (str) | ||||
|             min_stream_ordering (int) | ||||
|             user_id | ||||
|             min_stream_ordering | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[bool]: True if there may be push to process, False if | ||||
|             there definitely isn't. | ||||
|             True if there may be push to process, False if there definitely isn't. | ||||
|         """ | ||||
| 
 | ||||
|         def _get_if_maybe_push_in_range_for_user_txn(txn): | ||||
|  | @ -408,22 +409,20 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
|             txn.execute(sql, (user_id, min_stream_ordering)) | ||||
|             return bool(txn.fetchone()) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_if_maybe_push_in_range_for_user", | ||||
|             _get_if_maybe_push_in_range_for_user_txn, | ||||
|         ) | ||||
| 
 | ||||
|     async def add_push_actions_to_staging(self, event_id, user_id_actions): | ||||
|     async def add_push_actions_to_staging( | ||||
|         self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]] | ||||
|     ) -> None: | ||||
|         """Add the push actions for the event to the push action staging area. | ||||
| 
 | ||||
|         Args: | ||||
|             event_id (str) | ||||
|             user_id_actions (dict[str, list[dict|str])]): A dictionary mapping | ||||
|                 user_id to list of push actions, where an action can either be | ||||
|                 a string or dict. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred | ||||
|             event_id | ||||
|             user_id_actions: A mapping of user_id to list of push actions, where | ||||
|                 an action can either be a string or dict. | ||||
|         """ | ||||
| 
 | ||||
|         if not user_id_actions: | ||||
|  | @ -507,7 +506,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
|             "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago | ||||
|         ) | ||||
| 
 | ||||
|     def find_first_stream_ordering_after_ts(self, ts): | ||||
|     async def find_first_stream_ordering_after_ts(self, ts: int) -> int: | ||||
|         """Gets the stream ordering corresponding to a given timestamp. | ||||
| 
 | ||||
|         Specifically, finds the stream_ordering of the first event that was | ||||
|  | @ -516,13 +515,12 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
|         relatively slow. | ||||
| 
 | ||||
|         Args: | ||||
|             ts (int): timestamp in millis | ||||
|             ts: timestamp in millis | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[int]: stream ordering of the first event received on/after | ||||
|                 the timestamp | ||||
|             stream ordering of the first event received on/after the timestamp | ||||
|         """ | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "_find_first_stream_ordering_after_ts_txn", | ||||
|             self._find_first_stream_ordering_after_ts_txn, | ||||
|             ts, | ||||
|  |  | |||
|  | @ -17,7 +17,7 @@ | |||
| 
 | ||||
| import logging | ||||
| import re | ||||
| from typing import Any, Dict, List, Optional | ||||
| from typing import Any, Dict, List, Optional, Tuple | ||||
| 
 | ||||
| from synapse.api.constants import UserTypes | ||||
| from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError | ||||
|  | @ -84,17 +84,17 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
|         return is_trial | ||||
| 
 | ||||
|     @cached() | ||||
|     def get_user_by_access_token(self, token): | ||||
|     async def get_user_by_access_token(self, token: str) -> Optional[dict]: | ||||
|         """Get a user from the given access token. | ||||
| 
 | ||||
|         Args: | ||||
|             token (str): The access token of a user. | ||||
|             token: The access token of a user. | ||||
|         Returns: | ||||
|             defer.Deferred: None, if the token did not match, otherwise dict | ||||
|                 including the keys `name`, `is_guest`, `device_id`, `token_id`, | ||||
|                 `valid_until_ms`. | ||||
|             None, if the token did not match, otherwise dict | ||||
|             including the keys `name`, `is_guest`, `device_id`, `token_id`, | ||||
|             `valid_until_ms`. | ||||
|         """ | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_user_by_access_token", self._query_for_auth, token | ||||
|         ) | ||||
| 
 | ||||
|  | @ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|         return bool(res) if res else False | ||||
| 
 | ||||
|     def set_server_admin(self, user, admin): | ||||
|     async def set_server_admin(self, user: UserID, admin: bool) -> None: | ||||
|         """Sets whether a user is an admin of this homeserver. | ||||
| 
 | ||||
|         Args: | ||||
|             user (UserID): user ID of the user to test | ||||
|             admin (bool): true iff the user is to be a server admin, | ||||
|                 false otherwise. | ||||
|             user: user ID of the user to test | ||||
|             admin: true iff the user is to be a server admin, false otherwise. | ||||
|         """ | ||||
| 
 | ||||
|         def set_server_admin_txn(txn): | ||||
|  | @ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
|                 txn, self.get_user_by_id, (user.to_string(),) | ||||
|             ) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) | ||||
|         await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) | ||||
| 
 | ||||
|     def _query_for_auth(self, txn, token): | ||||
|         sql = ( | ||||
|  | @ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
|         ) | ||||
|         return True if res == UserTypes.SUPPORT else False | ||||
| 
 | ||||
|     def get_users_by_id_case_insensitive(self, user_id): | ||||
|     async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]: | ||||
|         """Gets users that match user_id case insensitively. | ||||
|         Returns a mapping of user_id -> password_hash. | ||||
| 
 | ||||
|         Returns: | ||||
|              A mapping of user_id -> password_hash. | ||||
|         """ | ||||
| 
 | ||||
|         def f(txn): | ||||
|  | @ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
|             txn.execute(sql, (user_id,)) | ||||
|             return dict(txn) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) | ||||
|         return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) | ||||
| 
 | ||||
|     async def get_user_by_external_id( | ||||
|         self, auth_provider: str, external_id: str | ||||
|  | @ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|         return await self.db_pool.runInteraction("count_users", _count_users) | ||||
| 
 | ||||
|     def count_daily_user_type(self): | ||||
|     async def count_daily_user_type(self) -> Dict[str, int]: | ||||
|         """ | ||||
|         Counts 1) native non guest users | ||||
|                2) native guests users | ||||
|  | @ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
|                 results[row[0]] = row[1] | ||||
|             return results | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "count_daily_user_type", _count_daily_user_type | ||||
|         ) | ||||
| 
 | ||||
|  | @ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
|         # Convert the integer into a boolean. | ||||
|         return res == 1 | ||||
| 
 | ||||
|     def get_threepid_validation_session( | ||||
|         self, medium, client_secret, address=None, sid=None, validated=True | ||||
|     ): | ||||
|     async def get_threepid_validation_session( | ||||
|         self, | ||||
|         medium: Optional[str], | ||||
|         client_secret: str, | ||||
|         address: Optional[str] = None, | ||||
|         sid: Optional[str] = None, | ||||
|         validated: Optional[bool] = True, | ||||
|     ) -> Optional[Dict[str, Any]]: | ||||
|         """Gets a session_id and last_send_attempt (if available) for a | ||||
|         combination of validation metadata | ||||
| 
 | ||||
|         Args: | ||||
|             medium (str|None): The medium of the 3PID | ||||
|             address (str|None): The address of the 3PID | ||||
|             sid (str|None): The ID of the validation session | ||||
|             client_secret (str): A unique string provided by the client to help identify this | ||||
|             medium: The medium of the 3PID | ||||
|             client_secret: A unique string provided by the client to help identify this | ||||
|                 validation attempt | ||||
|             validated (bool|None): Whether sessions should be filtered by | ||||
|             address: The address of the 3PID | ||||
|             sid: The ID of the validation session | ||||
|             validated: Whether sessions should be filtered by | ||||
|                 whether they have been validated already or not. None to | ||||
|                 perform no filtering | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[dict|None]: A dict containing the following: | ||||
|             A dict containing the following: | ||||
|                 * address - address of the 3pid | ||||
|                 * medium - medium of the 3pid | ||||
|                 * client_secret - a secret provided by the client for this validation session | ||||
|  | @ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
| 
 | ||||
|             return rows[0] | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_threepid_validation_session", get_threepid_validation_session_txn | ||||
|         ) | ||||
| 
 | ||||
|     def delete_threepid_session(self, session_id): | ||||
|     async def delete_threepid_session(self, session_id: str) -> None: | ||||
|         """Removes a threepid validation session from the database. This can | ||||
|         be done after validation has been performed and whatever action was | ||||
|         waiting on it has been carried out | ||||
| 
 | ||||
|         Args: | ||||
|             session_id (str): The ID of the session to delete | ||||
|             session_id: The ID of the session to delete | ||||
|         """ | ||||
| 
 | ||||
|         def delete_threepid_session_txn(txn): | ||||
|  | @ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore): | |||
|                 keyvalues={"session_id": session_id}, | ||||
|             ) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         await self.db_pool.runInteraction( | ||||
|             "delete_threepid_session", delete_threepid_session_txn | ||||
|         ) | ||||
| 
 | ||||
|  | @ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
|             desc="add_access_token_to_user", | ||||
|         ) | ||||
| 
 | ||||
|     def register_user( | ||||
|     async def register_user( | ||||
|         self, | ||||
|         user_id, | ||||
|         password_hash=None, | ||||
|         was_guest=False, | ||||
|         make_guest=False, | ||||
|         appservice_id=None, | ||||
|         create_profile_with_displayname=None, | ||||
|         admin=False, | ||||
|         user_type=None, | ||||
|         shadow_banned=False, | ||||
|     ): | ||||
|         user_id: str, | ||||
|         password_hash: Optional[str] = None, | ||||
|         was_guest: bool = False, | ||||
|         make_guest: bool = False, | ||||
|         appservice_id: Optional[str] = None, | ||||
|         create_profile_with_displayname: Optional[str] = None, | ||||
|         admin: bool = False, | ||||
|         user_type: Optional[str] = None, | ||||
|         shadow_banned: bool = False, | ||||
|     ) -> None: | ||||
|         """Attempts to register an account. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id (str): The desired user ID to register. | ||||
|             password_hash (str|None): Optional. The password hash for this user. | ||||
|             was_guest (bool): Optional. Whether this is a guest account being | ||||
|                 upgraded to a non-guest account. | ||||
|             make_guest (boolean): True if the the new user should be guest, | ||||
|                 false to add a regular user account. | ||||
|             appservice_id (str): The ID of the appservice registering the user. | ||||
|             create_profile_with_displayname (unicode): Optionally create a profile for | ||||
|             user_id: The desired user ID to register. | ||||
|             password_hash: Optional. The password hash for this user. | ||||
|             was_guest: Whether this is a guest account being upgraded to a | ||||
|                 non-guest account. | ||||
|             make_guest: True if the the new user should be guest, false to add a | ||||
|                 regular user account. | ||||
|             appservice_id: The ID of the appservice registering the user. | ||||
|             create_profile_with_displayname: Optionally create a profile for | ||||
|                 the user, setting their displayname to the given value | ||||
|             admin (boolean): is an admin user? | ||||
|             user_type (str|None): type of user. One of the values from | ||||
|                 api.constants.UserTypes, or None for a normal user. | ||||
|             shadow_banned (bool): Whether the user is shadow-banned, | ||||
|                 i.e. they may be told their requests succeeded but we ignore them. | ||||
|             admin: is an admin user? | ||||
|             user_type: type of user. One of the values from api.constants.UserTypes, | ||||
|                 or None for a normal user. | ||||
|             shadow_banned: Whether the user is shadow-banned, i.e. they may be | ||||
|                 told their requests succeeded but we ignore them. | ||||
| 
 | ||||
|         Raises: | ||||
|             StoreError if the user_id could not be registered. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred | ||||
|         """ | ||||
|         return self.db_pool.runInteraction( | ||||
|         await self.db_pool.runInteraction( | ||||
|             "register_user", | ||||
|             self._register_user, | ||||
|             user_id, | ||||
|  | @ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
|             desc="record_user_external_id", | ||||
|         ) | ||||
| 
 | ||||
|     def user_set_password_hash(self, user_id, password_hash): | ||||
|     async def user_set_password_hash(self, user_id: str, password_hash: str) -> None: | ||||
|         """ | ||||
|         NB. This does *not* evict any cache because the one use for this | ||||
|             removes most of the entries subsequently anyway so it would be | ||||
|  | @ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
|             ) | ||||
|             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         await self.db_pool.runInteraction( | ||||
|             "user_set_password_hash", user_set_password_hash_txn | ||||
|         ) | ||||
| 
 | ||||
|     def user_set_consent_version(self, user_id, consent_version): | ||||
|     async def user_set_consent_version( | ||||
|         self, user_id: str, consent_version: str | ||||
|     ) -> None: | ||||
|         """Updates the user table to record privacy policy consent | ||||
| 
 | ||||
|         Args: | ||||
|             user_id (str): full mxid of the user to update | ||||
|             consent_version (str): version of the policy the user has consented | ||||
|                 to | ||||
|             user_id: full mxid of the user to update | ||||
|             consent_version: version of the policy the user has consented to | ||||
| 
 | ||||
|         Raises: | ||||
|             StoreError(404) if user not found | ||||
|  | @ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
|             ) | ||||
|             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction("user_set_consent_version", f) | ||||
|         await self.db_pool.runInteraction("user_set_consent_version", f) | ||||
| 
 | ||||
|     def user_set_consent_server_notice_sent(self, user_id, consent_version): | ||||
|     async def user_set_consent_server_notice_sent( | ||||
|         self, user_id: str, consent_version: str | ||||
|     ) -> None: | ||||
|         """Updates the user table to record that we have sent the user a server | ||||
|         notice about privacy policy consent | ||||
| 
 | ||||
|         Args: | ||||
|             user_id (str): full mxid of the user to update | ||||
|             consent_version (str): version of the policy we have notified the | ||||
|                 user about | ||||
|             user_id: full mxid of the user to update | ||||
|             consent_version: version of the policy we have notified the user about | ||||
| 
 | ||||
|         Raises: | ||||
|             StoreError(404) if user not found | ||||
|  | @ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
|             ) | ||||
|             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) | ||||
|         await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) | ||||
| 
 | ||||
|     def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): | ||||
|     async def user_delete_access_tokens( | ||||
|         self, | ||||
|         user_id: str, | ||||
|         except_token_id: Optional[str] = None, | ||||
|         device_id: Optional[str] = None, | ||||
|     ) -> List[Tuple[str, int, Optional[str]]]: | ||||
|         """ | ||||
|         Invalidate access tokens belonging to a user | ||||
| 
 | ||||
|         Args: | ||||
|             user_id (str):  ID of user the tokens belong to | ||||
|             except_token_id (str): list of access_tokens IDs which should | ||||
|                 *not* be deleted | ||||
|             device_id (str|None):  ID of device the tokens are associated with. | ||||
|             user_id: ID of user the tokens belong to | ||||
|             except_token_id: access_tokens ID which should *not* be deleted | ||||
|             device_id: ID of device the tokens are associated with. | ||||
|                 If None, tokens associated with any device (or no device) will | ||||
|                 be deleted | ||||
|         Returns: | ||||
|             defer.Deferred[list[str, int, str|None, int]]: a list of | ||||
|                 (token, token id, device id) for each of the deleted tokens | ||||
|             A tuple of (token, token id, device id) for each of the deleted tokens | ||||
|         """ | ||||
| 
 | ||||
|         def f(txn): | ||||
|  | @ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
| 
 | ||||
|             return tokens_and_devices | ||||
| 
 | ||||
|         return self.db_pool.runInteraction("user_delete_access_tokens", f) | ||||
|         return await self.db_pool.runInteraction("user_delete_access_tokens", f) | ||||
| 
 | ||||
|     def delete_access_token(self, access_token): | ||||
|     async def delete_access_token(self, access_token: str) -> None: | ||||
|         def f(txn): | ||||
|             self.db_pool.simple_delete_one_txn( | ||||
|                 txn, table="access_tokens", keyvalues={"token": access_token} | ||||
|  | @ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
|                 txn, self.get_user_by_access_token, (access_token,) | ||||
|             ) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction("delete_access_token", f) | ||||
|         await self.db_pool.runInteraction("delete_access_token", f) | ||||
| 
 | ||||
|     @cached() | ||||
|     async def is_guest(self, user_id: str) -> bool: | ||||
|  | @ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
|             desc="get_users_pending_deactivation", | ||||
|         ) | ||||
| 
 | ||||
|     def validate_threepid_session(self, session_id, client_secret, token, current_ts): | ||||
|     async def validate_threepid_session( | ||||
|         self, session_id: str, client_secret: str, token: str, current_ts: int | ||||
|     ) -> Optional[str]: | ||||
|         """Attempt to validate a threepid session using a token | ||||
| 
 | ||||
|         Args: | ||||
|             session_id (str): The id of a validation session | ||||
|             client_secret (str): A unique string provided by the client to | ||||
|                 help identify this validation attempt | ||||
|             token (str): A validation token | ||||
|             current_ts (int): The current unix time in milliseconds. Used for | ||||
|                 checking token expiry status | ||||
|             session_id: The id of a validation session | ||||
|             client_secret: A unique string provided by the client to help identify | ||||
|                 this validation attempt | ||||
|             token: A validation token | ||||
|             current_ts: The current unix time in milliseconds. Used for checking | ||||
|                 token expiry status | ||||
| 
 | ||||
|         Raises: | ||||
|             ThreepidValidationError: if a matching validation token was not found or has | ||||
|                 expired | ||||
| 
 | ||||
|         Returns: | ||||
|             deferred str|None: A str representing a link to redirect the user | ||||
|             to if there is one. | ||||
|             A str representing a link to redirect the user to if there is one. | ||||
|         """ | ||||
| 
 | ||||
|         # Insert everything into a transaction in order to run atomically | ||||
|  | @ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
|             return next_link | ||||
| 
 | ||||
|         # Return next_link if it exists | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "validate_threepid_session_txn", validate_threepid_session_txn | ||||
|         ) | ||||
| 
 | ||||
|     def start_or_continue_validation_session( | ||||
|     async def start_or_continue_validation_session( | ||||
|         self, | ||||
|         medium, | ||||
|         address, | ||||
|         session_id, | ||||
|         client_secret, | ||||
|         send_attempt, | ||||
|         next_link, | ||||
|         token, | ||||
|         token_expires, | ||||
|     ): | ||||
|         medium: str, | ||||
|         address: str, | ||||
|         session_id: str, | ||||
|         client_secret: str, | ||||
|         send_attempt: int, | ||||
|         next_link: Optional[str], | ||||
|         token: str, | ||||
|         token_expires: int, | ||||
|     ) -> None: | ||||
|         """Creates a new threepid validation session if it does not already | ||||
|         exist and associates a new validation token with it | ||||
| 
 | ||||
|         Args: | ||||
|             medium (str): The medium of the 3PID | ||||
|             address (str): The address of the 3PID | ||||
|             session_id (str): The id of this validation session | ||||
|             client_secret (str): A unique string provided by the client to | ||||
|                 help identify this validation attempt | ||||
|             send_attempt (int): The latest send_attempt on this session | ||||
|             next_link (str|None): The link to redirect the user to upon | ||||
|                 successful validation | ||||
|             token (str): The validation token | ||||
|             token_expires (int): The timestamp for which after the token | ||||
|                 will no longer be valid | ||||
|             medium: The medium of the 3PID | ||||
|             address: The address of the 3PID | ||||
|             session_id: The id of this validation session | ||||
|             client_secret: A unique string provided by the client to help | ||||
|                 identify this validation attempt | ||||
|             send_attempt: The latest send_attempt on this session | ||||
|             next_link: The link to redirect the user to upon successful validation | ||||
|             token: The validation token | ||||
|             token_expires: The timestamp for which after the token will no | ||||
|                 longer be valid | ||||
|         """ | ||||
| 
 | ||||
|         def start_or_continue_validation_session_txn(txn): | ||||
|  | @ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
|                 }, | ||||
|             ) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         await self.db_pool.runInteraction( | ||||
|             "start_or_continue_validation_session", | ||||
|             start_or_continue_validation_session_txn, | ||||
|         ) | ||||
| 
 | ||||
|     def cull_expired_threepid_validation_tokens(self): | ||||
|     async def cull_expired_threepid_validation_tokens(self) -> None: | ||||
|         """Remove threepid validation tokens with expiry dates that have passed""" | ||||
| 
 | ||||
|         def cull_expired_threepid_validation_tokens_txn(txn, ts): | ||||
|  | @ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): | |||
|             DELETE FROM threepid_validation_token WHERE | ||||
|             expires < ? | ||||
|             """ | ||||
|             return txn.execute(sql, (ts,)) | ||||
|             txn.execute(sql, (ts,)) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         await self.db_pool.runInteraction( | ||||
|             "cull_expired_threepid_validation_tokens", | ||||
|             cull_expired_threepid_validation_tokens_txn, | ||||
|             self.clock.time_msec(), | ||||
|  |  | |||
|  | @ -15,7 +15,7 @@ | |||
| # limitations under the License. | ||||
| 
 | ||||
| import logging | ||||
| from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set | ||||
| from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set | ||||
| 
 | ||||
| from synapse.api.constants import EventTypes, Membership | ||||
| from synapse.events import EventBase | ||||
|  | @ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
|             ) | ||||
| 
 | ||||
|     @cached(max_entries=100000, iterable=True) | ||||
|     def get_users_in_room(self, room_id: str): | ||||
|         return self.db_pool.runInteraction( | ||||
|     async def get_users_in_room(self, room_id: str) -> List[str]: | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_users_in_room", self.get_users_in_room_txn, room_id | ||||
|         ) | ||||
| 
 | ||||
|  | @ -180,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
|         return [r[0] for r in txn] | ||||
| 
 | ||||
|     @cached(max_entries=100000) | ||||
|     def get_room_summary(self, room_id: str): | ||||
|     async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: | ||||
|         """ Get the details of a room roughly suitable for use by the room | ||||
|         summary extension to /sync. Useful when lazy loading room members. | ||||
|         Args: | ||||
|             room_id: The room ID to query | ||||
|         Returns: | ||||
|             Deferred[dict[str, MemberSummary]: | ||||
|                 dict of membership states, pointing to a MemberSummary named tuple. | ||||
|             dict of membership states, pointing to a MemberSummary named tuple. | ||||
|         """ | ||||
| 
 | ||||
|         def _get_room_summary_txn(txn): | ||||
|  | @ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
| 
 | ||||
|             return res | ||||
| 
 | ||||
|         return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_room_summary", _get_room_summary_txn | ||||
|         ) | ||||
| 
 | ||||
|     @cached() | ||||
|     def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]: | ||||
|     async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser: | ||||
|         """Get all the rooms the *local* user is invited to. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id: The user ID. | ||||
| 
 | ||||
|         Returns: | ||||
|             A awaitable list of RoomsForUser. | ||||
|             A list of RoomsForUser. | ||||
|         """ | ||||
| 
 | ||||
|         return self.get_rooms_for_local_user_where_membership_is( | ||||
|         return await self.get_rooms_for_local_user_where_membership_is( | ||||
|             user_id, [Membership.INVITE] | ||||
|         ) | ||||
| 
 | ||||
|  | @ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
|         return results | ||||
| 
 | ||||
|     @cached(max_entries=500000, iterable=True) | ||||
|     def get_rooms_for_user_with_stream_ordering(self, user_id: str): | ||||
|     async def get_rooms_for_user_with_stream_ordering( | ||||
|         self, user_id: str | ||||
|     ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: | ||||
|         """Returns a set of room_ids the user is currently joined to. | ||||
| 
 | ||||
|         If a remote user only returns rooms this server is currently | ||||
|  | @ -367,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
|             user_id | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns | ||||
|             the rooms the user is in currently, along with the stream ordering | ||||
|             of the most recent join for that user and room. | ||||
|             Returns the rooms the user is in currently, along with the stream | ||||
|             ordering of the most recent join for that user and room. | ||||
|         """ | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_rooms_for_user_with_stream_ordering", | ||||
|             self._get_rooms_for_user_with_stream_ordering_txn, | ||||
|             user_id, | ||||
|         ) | ||||
| 
 | ||||
|     def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str): | ||||
|     def _get_rooms_for_user_with_stream_ordering_txn( | ||||
|         self, txn, user_id: str | ||||
|     ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: | ||||
|         # We use `current_state_events` here and not `local_current_membership` | ||||
|         # as a) this gets called with remote users and b) this only gets called | ||||
|         # for rooms the server is participating in. | ||||
|  | @ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
|             """ | ||||
| 
 | ||||
|         txn.execute(sql, (user_id, Membership.JOIN)) | ||||
|         results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) | ||||
| 
 | ||||
|         return results | ||||
|         return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) | ||||
| 
 | ||||
|     async def get_users_server_still_shares_room_with( | ||||
|         self, user_ids: Collection[str] | ||||
|  | @ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
|         return count == 0 | ||||
| 
 | ||||
|     @cached() | ||||
|     def get_forgotten_rooms_for_user(self, user_id: str): | ||||
|     async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]: | ||||
|         """Gets all rooms the user has forgotten. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id | ||||
|             user_id: The user ID to query the rooms of. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[set[str]] | ||||
|             The forgotten rooms. | ||||
|         """ | ||||
| 
 | ||||
|         def _get_forgotten_rooms_for_user_txn(txn): | ||||
|  | @ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
|             txn.execute(sql, (user_id,)) | ||||
|             return {row[0] for row in txn if row[1] == 0} | ||||
| 
 | ||||
|         return self.db_pool.runInteraction( | ||||
|         return await self.db_pool.runInteraction( | ||||
|             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn | ||||
|         ) | ||||
| 
 | ||||
|  | @ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): | |||
|     def __init__(self, database: DatabasePool, db_conn, hs): | ||||
|         super(RoomMemberStore, self).__init__(database, db_conn, hs) | ||||
| 
 | ||||
|     def forget(self, user_id: str, room_id: str): | ||||
|     async def forget(self, user_id: str, room_id: str) -> None: | ||||
|         """Indicate that user_id wishes to discard history for room_id.""" | ||||
| 
 | ||||
|         def f(txn): | ||||
|  | @ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): | |||
|                 txn, self.get_forgotten_rooms_for_user, (user_id,) | ||||
|             ) | ||||
| 
 | ||||
|         return self.db_pool.runInteraction("forget_membership", f) | ||||
|         await self.db_pool.runInteraction("forget_membership", f) | ||||
| 
 | ||||
| 
 | ||||
| class _JoinedHostsCache(object): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Patrick Cloke
						Patrick Cloke