From d58fda99ffd29c15254788476fb8775370eebec5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 11:34:50 -0400 Subject: [PATCH 01/31] Convert `event_push_actions`, `registration`, and `roommember` datastores to async (#8197) --- changelog.d/8197.misc | 1 + .../databases/main/event_push_actions.py | 38 ++- .../storage/databases/main/registration.py | 238 +++++++++--------- synapse/storage/databases/main/roommember.py | 52 ++-- 4 files changed, 169 insertions(+), 160 deletions(-) create mode 100644 changelog.d/8197.misc diff --git a/changelog.d/8197.misc b/changelog.d/8197.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8197.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index e8834b2162..384c39f4d7 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import List +from typing import Dict, List, Union from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json @@ -383,19 +383,20 @@ class EventPushActionsWorkerStore(SQLBaseStore): # Now return the first `limit` return notifs[:limit] - def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): + async def get_if_maybe_push_in_range_for_user( + self, user_id: str, min_stream_ordering: int + ) -> bool: """A fast check to see if there might be something to push for the user since the given stream ordering. May return false positives. Useful to know whether to bother starting a pusher on start up or not. Args: - user_id (str) - min_stream_ordering (int) + user_id + min_stream_ordering Returns: - Deferred[bool]: True if there may be push to process, False if - there definitely isn't. + True if there may be push to process, False if there definitely isn't. """ def _get_if_maybe_push_in_range_for_user_txn(txn): @@ -408,22 +409,20 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, min_stream_ordering)) return bool(txn.fetchone()) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_if_maybe_push_in_range_for_user", _get_if_maybe_push_in_range_for_user_txn, ) - async def add_push_actions_to_staging(self, event_id, user_id_actions): + async def add_push_actions_to_staging( + self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]] + ) -> None: """Add the push actions for the event to the push action staging area. Args: - event_id (str) - user_id_actions (dict[str, list[dict|str])]): A dictionary mapping - user_id to list of push actions, where an action can either be - a string or dict. - - Returns: - Deferred + event_id + user_id_actions: A mapping of user_id to list of push actions, where + an action can either be a string or dict. """ if not user_id_actions: @@ -507,7 +506,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago ) - def find_first_stream_ordering_after_ts(self, ts): + async def find_first_stream_ordering_after_ts(self, ts: int) -> int: """Gets the stream ordering corresponding to a given timestamp. 
Specifically, finds the stream_ordering of the first event that was @@ -516,13 +515,12 @@ class EventPushActionsWorkerStore(SQLBaseStore): relatively slow. Args: - ts (int): timestamp in millis + ts: timestamp in millis Returns: - Deferred[int]: stream ordering of the first event received on/after - the timestamp + stream ordering of the first event received on/after the timestamp """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "_find_first_stream_ordering_after_ts_txn", self._find_first_stream_ordering_after_ts_txn, ts, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 12689f4308..000b8ccff4 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -84,17 +84,17 @@ class RegistrationWorkerStore(SQLBaseStore): return is_trial @cached() - def get_user_by_access_token(self, token): + async def get_user_by_access_token(self, token: str) -> Optional[dict]: """Get a user from the given access token. Args: - token (str): The access token of a user. + token: The access token of a user. Returns: - defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`, - `valid_until_ms`. + None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`, + `valid_until_ms`. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_user_by_access_token", self._query_for_auth, token ) @@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore): return bool(res) if res else False - def set_server_admin(self, user, admin): + async def set_server_admin(self, user: UserID, admin: bool) -> None: """Sets whether a user is an admin of this homeserver. Args: - user (UserID): user ID of the user to test - admin (bool): true iff the user is to be a server admin, - false otherwise. + user: user ID of the user to test + admin: true iff the user is to be a server admin, false otherwise. """ def set_server_admin_txn(txn): @@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn, self.get_user_by_id, (user.to_string(),) ) - return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) + await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): sql = ( @@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore): ) return True if res == UserTypes.SUPPORT else False - def get_users_by_id_case_insensitive(self, user_id): + async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]: """Gets users that match user_id case insensitively. - Returns a mapping of user_id -> password_hash. + + Returns: + A mapping of user_id -> password_hash. 
""" def f(txn): @@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return dict(txn) - return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) + return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) async def get_user_by_external_id( self, auth_provider: str, external_id: str @@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction("count_users", _count_users) - def count_daily_user_type(self): + async def count_daily_user_type(self) -> Dict[str, int]: """ Counts 1) native non guest users 2) native guests users @@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore): results[row[0]] = row[1] return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_daily_user_type", _count_daily_user_type ) @@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore): # Convert the integer into a boolean. return res == 1 - def get_threepid_validation_session( - self, medium, client_secret, address=None, sid=None, validated=True - ): + async def get_threepid_validation_session( + self, + medium: Optional[str], + client_secret: str, + address: Optional[str] = None, + sid: Optional[str] = None, + validated: Optional[bool] = True, + ) -> Optional[Dict[str, Any]]: """Gets a session_id and last_send_attempt (if available) for a combination of validation metadata Args: - medium (str|None): The medium of the 3PID - address (str|None): The address of the 3PID - sid (str|None): The ID of the validation session - client_secret (str): A unique string provided by the client to help identify this + medium: The medium of the 3PID + client_secret: A unique string provided by the client to help identify this validation attempt - validated (bool|None): Whether sessions should be filtered by + address: The address of the 3PID + sid: The ID of the validation session + validated: Whether sessions should be filtered by whether they have been validated already or not. None to perform no filtering Returns: - Deferred[dict|None]: A dict containing the following: + A dict containing the following: * address - address of the 3pid * medium - medium of the 3pid * client_secret - a secret provided by the client for this validation session @@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore): return rows[0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn ) - def delete_threepid_session(self, session_id): + async def delete_threepid_session(self, session_id: str) -> None: """Removes a threepid validation session from the database. 
This can be done after validation has been performed and whatever action was waiting on it has been carried out Args: - session_id (str): The ID of the session to delete + session_id: The ID of the session to delete """ def delete_threepid_session_txn(txn): @@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore): keyvalues={"session_id": session_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_threepid_session", delete_threepid_session_txn ) @@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_access_token_to_user", ) - def register_user( + async def register_user( self, - user_id, - password_hash=None, - was_guest=False, - make_guest=False, - appservice_id=None, - create_profile_with_displayname=None, - admin=False, - user_type=None, - shadow_banned=False, - ): + user_id: str, + password_hash: Optional[str] = None, + was_guest: bool = False, + make_guest: bool = False, + appservice_id: Optional[str] = None, + create_profile_with_displayname: Optional[str] = None, + admin: bool = False, + user_type: Optional[str] = None, + shadow_banned: bool = False, + ) -> None: """Attempts to register an account. Args: - user_id (str): The desired user ID to register. - password_hash (str|None): Optional. The password hash for this user. - was_guest (bool): Optional. Whether this is a guest account being - upgraded to a non-guest account. - make_guest (boolean): True if the the new user should be guest, - false to add a regular user account. - appservice_id (str): The ID of the appservice registering the user. - create_profile_with_displayname (unicode): Optionally create a profile for + user_id: The desired user ID to register. + password_hash: Optional. The password hash for this user. + was_guest: Whether this is a guest account being upgraded to a + non-guest account. + make_guest: True if the the new user should be guest, false to add a + regular user account. + appservice_id: The ID of the appservice registering the user. + create_profile_with_displayname: Optionally create a profile for the user, setting their displayname to the given value - admin (boolean): is an admin user? - user_type (str|None): type of user. One of the values from - api.constants.UserTypes, or None for a normal user. - shadow_banned (bool): Whether the user is shadow-banned, - i.e. they may be told their requests succeeded but we ignore them. + admin: is an admin user? + user_type: type of user. One of the values from api.constants.UserTypes, + or None for a normal user. + shadow_banned: Whether the user is shadow-banned, i.e. they may be + told their requests succeeded but we ignore them. Raises: StoreError if the user_id could not be registered. - - Returns: - Deferred """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "register_user", self._register_user, user_id, @@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="record_user_external_id", ) - def user_set_password_hash(self, user_id, password_hash): + async def user_set_password_hash(self, user_id: str, password_hash: str) -> None: """ NB. 
This does *not* evict any cache because the one use for this removes most of the entries subsequently anyway so it would be @@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "user_set_password_hash", user_set_password_hash_txn ) - def user_set_consent_version(self, user_id, consent_version): + async def user_set_consent_version( + self, user_id: str, consent_version: str + ) -> None: """Updates the user table to record privacy policy consent Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy the user has consented - to + user_id: full mxid of the user to update + consent_version: version of the policy the user has consented to Raises: StoreError(404) if user not found @@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db_pool.runInteraction("user_set_consent_version", f) + await self.db_pool.runInteraction("user_set_consent_version", f) - def user_set_consent_server_notice_sent(self, user_id, consent_version): + async def user_set_consent_server_notice_sent( + self, user_id: str, consent_version: str + ) -> None: """Updates the user table to record that we have sent the user a server notice about privacy policy consent Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy we have notified the - user about + user_id: full mxid of the user to update + consent_version: version of the policy we have notified the user about Raises: StoreError(404) if user not found @@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) + await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) - def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): + async def user_delete_access_tokens( + self, + user_id: str, + except_token_id: Optional[str] = None, + device_id: Optional[str] = None, + ) -> List[Tuple[str, int, Optional[str]]]: """ Invalidate access tokens belonging to a user Args: - user_id (str): ID of user the tokens belong to - except_token_id (str): list of access_tokens IDs which should - *not* be deleted - device_id (str|None): ID of device the tokens are associated with. + user_id: ID of user the tokens belong to + except_token_id: access_tokens ID which should *not* be deleted + device_id: ID of device the tokens are associated with. 
If None, tokens associated with any device (or no device) will be deleted Returns: - defer.Deferred[list[str, int, str|None, int]]: a list of - (token, token id, device id) for each of the deleted tokens + A tuple of (token, token id, device id) for each of the deleted tokens """ def f(txn): @@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return tokens_and_devices - return self.db_pool.runInteraction("user_delete_access_tokens", f) + return await self.db_pool.runInteraction("user_delete_access_tokens", f) - def delete_access_token(self, access_token): + async def delete_access_token(self, access_token: str) -> None: def f(txn): self.db_pool.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} @@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, self.get_user_by_access_token, (access_token,) ) - return self.db_pool.runInteraction("delete_access_token", f) + await self.db_pool.runInteraction("delete_access_token", f) @cached() async def is_guest(self, user_id: str) -> bool: @@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="get_users_pending_deactivation", ) - def validate_threepid_session(self, session_id, client_secret, token, current_ts): + async def validate_threepid_session( + self, session_id: str, client_secret: str, token: str, current_ts: int + ) -> Optional[str]: """Attempt to validate a threepid session using a token Args: - session_id (str): The id of a validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - token (str): A validation token - current_ts (int): The current unix time in milliseconds. Used for - checking token expiry status + session_id: The id of a validation session + client_secret: A unique string provided by the client to help identify + this validation attempt + token: A validation token + current_ts: The current unix time in milliseconds. Used for checking + token expiry status Raises: ThreepidValidationError: if a matching validation token was not found or has expired Returns: - deferred str|None: A str representing a link to redirect the user - to if there is one. + A str representing a link to redirect the user to if there is one. 
""" # Insert everything into a transaction in order to run atomically @@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return next_link # Return next_link if it exists - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "validate_threepid_session_txn", validate_threepid_session_txn ) - def start_or_continue_validation_session( + async def start_or_continue_validation_session( self, - medium, - address, - session_id, - client_secret, - send_attempt, - next_link, - token, - token_expires, - ): + medium: str, + address: str, + session_id: str, + client_secret: str, + send_attempt: int, + next_link: Optional[str], + token: str, + token_expires: int, + ) -> None: """Creates a new threepid validation session if it does not already exist and associates a new validation token with it Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - session_id (str): The id of this validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - next_link (str|None): The link to redirect the user to upon - successful validation - token (str): The validation token - token_expires (int): The timestamp for which after the token - will no longer be valid + medium: The medium of the 3PID + address: The address of the 3PID + session_id: The id of this validation session + client_secret: A unique string provided by the client to help + identify this validation attempt + send_attempt: The latest send_attempt on this session + next_link: The link to redirect the user to upon successful validation + token: The validation token + token_expires: The timestamp for which after the token will no + longer be valid """ def start_or_continue_validation_session_txn(txn): @@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "start_or_continue_validation_session", start_or_continue_validation_session_txn, ) - def cull_expired_threepid_validation_tokens(self): + async def cull_expired_threepid_validation_tokens(self) -> None: """Remove threepid validation tokens with expiry dates that have passed""" def cull_expired_threepid_validation_tokens_txn(txn, ts): @@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): DELETE FROM threepid_validation_token WHERE expires < ? """ - return txn.execute(sql, (ts,)) + txn.execute(sql, (ts,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, self.clock.time_msec(), diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 161edbeccb..c4ce5c17c2 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,7 +15,7 @@ # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase @@ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id: str): - return self.db_pool.runInteraction( + async def get_users_in_room(self, room_id: str) -> List[str]: + return await self.db_pool.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) @@ -180,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return [r[0] for r in txn] @cached(max_entries=100000) - def get_room_summary(self, room_id: str): + async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: """ Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: room_id: The room ID to query Returns: - Deferred[dict[str, MemberSummary]: - dict of membership states, pointing to a MemberSummary named tuple. + dict of membership states, pointing to a MemberSummary named tuple. """ def _get_room_summary_txn(txn): @@ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): return res - return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) + return await self.db_pool.runInteraction( + "get_room_summary", _get_room_summary_txn + ) @cached() - def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]: + async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser: """Get all the rooms the *local* user is invited to. Args: user_id: The user ID. Returns: - A awaitable list of RoomsForUser. + A list of RoomsForUser. """ - return self.get_rooms_for_local_user_where_membership_is( + return await self.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE] ) @@ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results @cached(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id: str): + async def get_rooms_for_user_with_stream_ordering( + self, user_id: str + ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently @@ -367,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore): user_id Returns: - Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns - the rooms the user is in currently, along with the stream ordering - of the most recent join for that user and room. + Returns the rooms the user is in currently, along with the stream + ordering of the most recent join for that user and room. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_for_user_with_stream_ordering", self._get_rooms_for_user_with_stream_ordering_txn, user_id, ) - def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str): + def _get_rooms_for_user_with_stream_ordering_txn( + self, txn, user_id: str + ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called # for rooms the server is participating in. 
@@ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (user_id, Membership.JOIN)) - results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) - - return results + return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] @@ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): return count == 0 @cached() - def get_forgotten_rooms_for_user(self, user_id: str): + async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]: """Gets all rooms the user has forgotten. Args: - user_id + user_id: The user ID to query the rooms of. Returns: - Deferred[set[str]] + The forgotten rooms. """ def _get_forgotten_rooms_for_user_txn(txn): @@ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (user_id,)) return {row[0] for row in txn if row[1] == 0} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) @@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - def forget(self, user_id: str, room_id: str): + async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): txn, self.get_forgotten_rooms_for_user, (user_id,) ) - return self.db_pool.runInteraction("forget_membership", f) + await self.db_pool.runInteraction("forget_membership", f) class _JoinedHostsCache(object): From 3b4556cf87666bb6f40d89c8c7fff42d336237b6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 28 Aug 2020 17:12:45 +0100 Subject: [PATCH 02/31] Fix `wait_for_stream_position` for multiple waiters. (#8196) This fixes a bug where having multiple callers waiting on the same stream and position will cause it to try and compare two deferreds, which fails (due to the sorted list having an entry of `Tuple[int, Deferred]`). --- changelog.d/8196.misc | 1 + synapse/python_dependencies.py | 4 +++- synapse/replication/tcp/client.py | 6 ++---- 3 files changed, 6 insertions(+), 5 deletions(-) create mode 100644 changelog.d/8196.misc diff --git a/changelog.d/8196.misc b/changelog.d/8196.misc new file mode 100644 index 0000000000..c42baf0e56 --- /dev/null +++ b/changelog.d/8196.misc @@ -0,0 +1 @@ +Fix `wait_for_stream_position` to allow multiple waiters on same stream ID. diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index dd77a44b8d..2d995ec456 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -66,7 +66,9 @@ REQUIREMENTS = [ "msgpack>=0.5.2", "phonenumbers>=8.2.0", "prometheus_client>=0.0.18,<0.9.0", - # we use attr.validators.deep_iterable, which arrived in 19.1.0 + # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note: + # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33 + # is out in November.) "attrs>=19.1.0", "netaddr>=0.7.18", "Jinja2>=2.9", diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index fcf8ebf1e7..d6ecf5b327 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,7 +14,6 @@ # limitations under the License. 
"""A replication client for use by synapse workers. """ -import heapq import logging from typing import TYPE_CHECKING, Dict, List, Tuple @@ -219,9 +218,8 @@ class ReplicationDataHandler: waiting_list = self._streams_to_waiters.setdefault(stream_name, []) - # We insert into the list using heapq as it is more efficient than - # pushing then resorting each time. - heapq.heappush(waiting_list, (position, deferred)) + waiting_list.append((position, deferred)) + waiting_list.sort(key=lambda t: t[0]) # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): From b4826d6eb18eeb05c973882efd4c0f5d68359449 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 28 Aug 2020 17:39:14 +0100 Subject: [PATCH 03/31] Fix incorrect return signature --- synapse/storage/databases/main/registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 000b8ccff4..01f20c03c2 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -99,7 +99,7 @@ class RegistrationWorkerStore(SQLBaseStore): ) @cached() - async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]: + async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]: """Get the expiration timestamp for the account bearing a given user ID. Args: From d2ac767de2ce895a965fb2fcdcb883636f19a5c5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 16:47:11 -0400 Subject: [PATCH 04/31] Convert ReadWriteLock to async/await. (#8202) --- changelog.d/8202.misc | 1 + synapse/handlers/pagination.py | 49 ++++++++++++++++++---------------- synapse/util/async_helpers.py | 16 +++++------ tests/util/test_rwlock.py | 6 +++-- 4 files changed, 39 insertions(+), 33 deletions(-) create mode 100644 changelog.d/8202.misc diff --git a/changelog.d/8202.misc b/changelog.d/8202.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8202.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index ac3418d69d..5a1aa7d830 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -14,15 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Any, Dict, Optional from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError +from synapse.api.filtering import Filter from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter -from synapse.types import RoomStreamToken +from synapse.streams.config import PaginationConfig +from synapse.types import Requester, RoomStreamToken from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client @@ -247,15 +250,16 @@ class PaginationHandler(object): ) return purge_id - async def _purge_history(self, purge_id, room_id, token, delete_local_events): + async def _purge_history( + self, purge_id: str, room_id: str, token: str, delete_local_events: bool + ) -> None: """Carry out a history purge on a room. 
Args: - purge_id (str): The id for this purge - room_id (str): The room to purge from - token (str): topological token to delete events before - delete_local_events (bool): True to delete local events as well as - remote ones + purge_id: The id for this purge + room_id: The room to purge from + token: topological token to delete events before + delete_local_events: True to delete local events as well as remote ones """ self._purges_in_progress_by_room.add(room_id) try: @@ -291,9 +295,9 @@ class PaginationHandler(object): """ return self._purges_by_id.get(purge_id) - async def purge_room(self, room_id): + async def purge_room(self, room_id: str) -> None: """Purge the given room from the database""" - with (await self.pagination_lock.write(room_id)): + with await self.pagination_lock.write(room_id): # check we know about the room await self.store.get_room_version_id(room_id) @@ -307,23 +311,22 @@ class PaginationHandler(object): async def get_messages( self, - requester, - room_id=None, - pagin_config=None, - as_client_event=True, - event_filter=None, - ): + requester: Requester, + room_id: Optional[str] = None, + pagin_config: Optional[PaginationConfig] = None, + as_client_event: bool = True, + event_filter: Optional[Filter] = None, + ) -> Dict[str, Any]: """Get messages in a room. Args: - requester (Requester): The user requesting messages. - room_id (str): The room they want messages from. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config rules to apply, if any. - as_client_event (bool): True to get events in client-server format. - event_filter (Filter): Filter to apply to results or None + requester: The user requesting messages. + room_id: The room they want messages from. + pagin_config: The pagination config rules to apply, if any. + as_client_event: True to get events in client-server format. + event_filter: Filter to apply to results or None Returns: - dict: Pagination API results + Pagination API results """ user_id = requester.user.to_string() @@ -343,7 +346,7 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") - with (await self.pagination_lock.read(room_id)): + with await self.pagination_lock.read(room_id): ( membership, member_event_id, diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index f562770922..dfefbd996d 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -20,6 +20,7 @@ from contextlib import contextmanager from typing import Dict, Sequence, Set, Union import attr +from typing_extensions import ContextManager from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -338,11 +339,11 @@ class Linearizer(object): class ReadWriteLock(object): - """A deferred style read write lock. + """An async read write lock. Example: - with (yield read_write_lock.read("test_key")): + with await read_write_lock.read("test_key"): # do some work """ @@ -365,8 +366,7 @@ class ReadWriteLock(object): # Latest writer queued self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] - @defer.inlineCallbacks - def read(self, key): + async def read(self, key: str) -> ContextManager: new_defer = defer.Deferred() curr_readers = self.key_to_current_readers.setdefault(key, set()) @@ -376,7 +376,8 @@ class ReadWriteLock(object): # We wait for the latest writer to finish writing. We can safely ignore # any existing readers... as they're readers. 
- yield make_deferred_yieldable(curr_writer) + if curr_writer: + await make_deferred_yieldable(curr_writer) @contextmanager def _ctx_manager(): @@ -388,8 +389,7 @@ class ReadWriteLock(object): return _ctx_manager() - @defer.inlineCallbacks - def write(self, key): + async def write(self, key: str) -> ContextManager: new_defer = defer.Deferred() curr_readers = self.key_to_current_readers.get(key, set()) @@ -405,7 +405,7 @@ class ReadWriteLock(object): curr_readers.clear() self.key_to_current_writer[key] = new_defer - yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) + await make_deferred_yieldable(defer.gatherResults(to_wait_on)) @contextmanager def _ctx_manager(): diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index bd32e2cee7..d3dea3b52a 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer from synapse.util.async_helpers import ReadWriteLock @@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase): rwlock.read(key), # 5 rwlock.write(key), # 6 ] + ds = [defer.ensureDeferred(d) for d in ds] self._assert_called_before_not_after(ds, 2) @@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase): with ds[6].result: pass - d = rwlock.write(key) + d = defer.ensureDeferred(rwlock.write(key)) self.assertTrue(d.called) with d.result: pass - d = rwlock.read(key) + d = defer.ensureDeferred(rwlock.read(key)) self.assertTrue(d.called) with d.result: pass From 8027166dd546811316e7c8395deaf929110c2f3a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Sat, 29 Aug 2020 00:05:25 +0100 Subject: [PATCH 05/31] Add a comment about _LimitedHostnameResolver --- synapse/app/_base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 2b2cd795e0..a43dc5b2c9 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -334,6 +334,13 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100): This is to workaround https://twistedmatrix.com/trac/ticket/9620, where we can run out of file descriptors and infinite loop if we attempt to do too many DNS queries at once + + XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless + you explicitly install twisted.names as the resolver; rather it uses a GAIResolver + backed by the reactor's default threadpool (which is limited to 10 threads). So + (a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't + understand why we would run out of FDs if we did too many lookups at once. + -- richvdh 2020/08/29 """ new_resolver = _LimitedHostnameResolver( reactor.nameResolver, max_dns_requests_in_flight From 45e8f7726f24d98a9d3fa06ea52ae960cc1d8689 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Sat, 29 Aug 2020 00:14:17 +0100 Subject: [PATCH 06/31] Rename `get_e2e_device_keys` to better reflect its purpose (#8205) ... and to show that it does something slightly different to `_get_e2e_device_keys_txn`. `include_all_devices` and `include_deleted_devices` were never used (and `include_deleted_devices` was broken, since that would cause `None`s in the result which were not handled in the loop below. Add some typing too. 
--- changelog.d/8205.misc | 1 + synapse/handlers/e2e_keys.py | 4 ++-- .../storage/databases/main/end_to_end_keys.py | 20 ++++++------------- tests/storage/test_end_to_end_keys.py | 8 +++++--- 4 files changed, 14 insertions(+), 19 deletions(-) create mode 100644 changelog.d/8205.misc diff --git a/changelog.d/8205.misc b/changelog.d/8205.misc new file mode 100644 index 0000000000..fb8fd83278 --- /dev/null +++ b/changelog.d/8205.misc @@ -0,0 +1 @@ + Refactor queries for device keys and cross-signatures. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d8def45e38..dfd1c78549 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -353,7 +353,7 @@ class E2eKeysHandler(object): # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = await self.store.get_e2e_device_keys(local_query) + results = await self.store.get_e2e_device_keys_for_cs_api(local_query) # Build the result structure for user_id, device_keys in results.items(): @@ -734,7 +734,7 @@ class E2eKeysHandler(object): # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what # was sent if the device was signed - devices = await self.store.get_e2e_device_keys([(user_id, None)]) + devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)]) if user_id not in devices: raise NotFoundError("No device keys found") diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index af0b85e2c9..50ecddf7fa 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -23,6 +23,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -33,17 +34,12 @@ if TYPE_CHECKING: class EndToEndKeyWorkerStore(SQLBaseStore): @trace - async def get_e2e_device_keys( - self, query_list, include_all_devices=False, include_deleted_devices=False - ): - """Fetch a list of device keys. + async def get_e2e_device_keys_for_cs_api( + self, query_list: List[Tuple[str, Optional[str]]] + ) -> Dict[str, Dict[str, JsonDict]]: + """Fetch a list of device keys, formatted suitably for the C/S API. Args: query_list(list): List of pairs of user_ids and device_ids. - include_all_devices (bool): whether to include entries for devices - that don't have device keys - include_deleted_devices (bool): whether to include null entries for - devices which no longer exist (but were in the query_list). - This option only takes effect if include_all_devices is true. Returns: Dict mapping from user-id to dict mapping from device_id to key data. 
The key data will be a dict in the same format as the @@ -54,11 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return {} results = await self.db_pool.runInteraction( - "get_e2e_device_keys", - self._get_e2e_device_keys_txn, - query_list, - include_all_devices, - include_deleted_devices, + "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, ) # Build the result structure, un-jsonify the results, and add the diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 261bf5b08b..3fc4bb13b6 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -37,7 +37,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user", "device"),)) + self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) self.assertIn("device", res["user"]) @@ -76,7 +76,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user", "device"),)) + self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) self.assertIn("device", res["user"]) @@ -108,7 +108,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) + self.store.get_e2e_device_keys_for_cs_api( + (("user1", "device1"), ("user2", "device2")) + ) ) self.assertIn("user1", res) self.assertIn("device1", res["user1"]) From aa07c37cf0a3b812e6aa1bb2d97d543e6925c8e2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Sep 2020 12:41:21 +0100 Subject: [PATCH 07/31] Move and rename `get_devices_with_keys_by_user` (#8204) * Move `get_devices_with_keys_by_user` to `EndToEndKeyWorkerStore` this seems a better fit for it. This commit simply moves the existing code: no other changes at all. * Rename `get_devices_with_keys_by_user` to better reflect what it does. * get_device_stream_token abstract method To avoid referencing fields which are declared in the derived classes, make `get_device_stream_token` abstract, and define that in the classes which define `_device_list_id_gen`. --- changelog.d/8204.misc | 1 + synapse/handlers/device.py | 4 +- synapse/replication/slave/storage/devices.py | 3 ++ synapse/storage/databases/main/__init__.py | 3 ++ synapse/storage/databases/main/devices.py | 52 ++---------------- .../storage/databases/main/end_to_end_keys.py | 53 ++++++++++++++++++- 6 files changed, 67 insertions(+), 49 deletions(-) create mode 100644 changelog.d/8204.misc diff --git a/changelog.d/8204.misc b/changelog.d/8204.misc new file mode 100644 index 0000000000..979c8b227b --- /dev/null +++ b/changelog.d/8204.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index db417d60de..ee4666337a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler): return result async def on_federation_query_user_devices(self, user_id): - stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id) + stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query( + user_id + ) master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") self_signing_key = await self.store.get_e2e_cross_signing_key( user_id, "self_signing" diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 596c72eb92..3b788c9625 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,6 +48,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 70cf15dd7f..e6536c8456 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -264,6 +264,9 @@ class DataStore( # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index def96637a2..e8379c73c4 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging from typing import Any, Dict, Iterable, List, Optional, Set, Tuple @@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore): update included in the response), and the list of updates, where each update is a pair of EDU type and EDU contents. """ - now_stream_id = self._device_list_id_gen.get_current_token() + now_stream_id = self.get_device_stream_token() has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) @@ -412,8 +413,10 @@ class DeviceWorkerStore(SQLBaseStore): }, ) + @abc.abstractmethod def get_device_stream_token(self) -> int: - return self._device_list_id_gen.get_current_token() + """Get the current stream id from the _device_list_id_gen""" + ... 
@trace async def get_user_devices_from_cache( @@ -481,51 +484,6 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id: str): - """Get all devices (with any device keys) for a user - - Returns: - Deferred which resolves to (stream_id, devices) - """ - return self.db_pool.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, - user_id, - ) - - def _get_devices_with_keys_by_user_txn( - self, txn: LoggingTransaction, user_id: str - ) -> Tuple[int, List[JsonDict]]: - now_stream_id = self._device_list_id_gen.get_current_token() - - devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) - - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in user_devices.items(): - result = {"device_id": device_id} - - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) - - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - - results.append(result) - - return now_stream_id, results - - return now_stream_id, [] - async def get_users_whose_devices_changed( self, from_key: str, user_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 50ecddf7fa..fb3b1f94de 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import abc from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -22,7 +23,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -33,6 +34,51 @@ if TYPE_CHECKING: class EndToEndKeyWorkerStore(SQLBaseStore): + def get_e2e_device_keys_for_federation_query(self, user_id: str): + """Get all devices (with any device keys) for a user + + Returns: + Deferred which resolves to (stream_id, devices) + """ + return self.db_pool.runInteraction( + "get_e2e_device_keys_for_federation_query", + self._get_e2e_device_keys_for_federation_query_txn, + user_id, + ) + + def _get_e2e_device_keys_for_federation_query_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Tuple[int, List[JsonDict]]: + now_stream_id = self.get_device_stream_token() + + devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.items(): + result = {"device_id": device_id} + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + @trace async def get_e2e_device_keys_for_cs_api( self, query_list: List[Tuple[str, Optional[str]]] @@ -533,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore): _get_all_user_signature_changes_for_remotes_txn, ) + @abc.abstractmethod + def get_device_stream_token(self) -> int: + """Get the current stream id from the _device_list_id_gen""" + ... + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): From 318245eaa6d37a27ca72168356198fdd90abfbb7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 08:16:58 -0400 Subject: [PATCH 08/31] Do not install setuptools 50.0. (#8212) This is due to compatibility issues with old Python versions. --- INSTALL.md | 2 +- changelog.d/8212.bugfix | 1 + synapse/python_dependencies.py | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8212.bugfix diff --git a/INSTALL.md b/INSTALL.md index 22f7b7c029..bdb7769fe9 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -73,7 +73,7 @@ mkdir -p ~/synapse virtualenv -p python3 ~/synapse/env source ~/synapse/env/bin/activate pip install --upgrade pip -pip install --upgrade setuptools +pip install --upgrade setuptools!=50.0 # setuptools==50.0 fails on some older Python versions pip install matrix-synapse ``` diff --git a/changelog.d/8212.bugfix b/changelog.d/8212.bugfix new file mode 100644 index 0000000000..0f8c0aed92 --- /dev/null +++ b/changelog.d/8212.bugfix @@ -0,0 +1 @@ +Do not install setuptools 50.0. It can lead to a broken configuration on some older Python versions. 
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 2d995ec456..d666f22674 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -74,6 +74,10 @@ REQUIREMENTS = [ "Jinja2>=2.9", "bleach>=1.4.3", "typing-extensions>=3.7.4", + # setuptools is required by a variety of dependencies, unfortunately version + # 50.0 is incompatible with older Python versions, see + # https://github.com/pypa/setuptools/issues/2352 + "setuptools!=50.0", ] CONDITIONAL_REQUIREMENTS = { From bbb3c8641ca1214702dbddd88acfe81cc4fc44ae Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 1 Sep 2020 13:36:25 +0100 Subject: [PATCH 09/31] Make MultiWriterIDGenerator work for streams that use negative stream IDs (#8203) This is so that we can use it for the backfill events stream. --- changelog.d/8203.misc | 1 + synapse/storage/util/id_generators.py | 39 +++++++--- tests/storage/test_id_generators.py | 105 ++++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 11 deletions(-) create mode 100644 changelog.d/8203.misc diff --git a/changelog.d/8203.misc b/changelog.d/8203.misc new file mode 100644 index 0000000000..9fe2224aaa --- /dev/null +++ b/changelog.d/8203.misc @@ -0,0 +1 @@ +Make `MultiWriterIDGenerator` work for streams that use negative values. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b27a4843d0..9f3d23f0a5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -185,6 +185,8 @@ class MultiWriterIdGenerator: id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + positive: Whether the IDs are positive (true) or negative (false). + When using negative IDs we go backwards from -1 to -2, -3, etc. """ def __init__( @@ -196,13 +198,19 @@ class MultiWriterIdGenerator: instance_column: str, id_column: str, sequence_name: str, + positive: bool = True, ): self._db = db self._instance_name = instance_name + self._positive = positive + self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. self._lock = threading.Lock() + # Note: If we are a negative stream then we still store all the IDs as + # positive to make life easier for us, and simply negate the IDs when we + # return them. self._current_positions = self._load_current_ids( db_conn, table, instance_column, id_column ) @@ -233,13 +241,16 @@ class MultiWriterIdGenerator: def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ) -> Dict[str, int]: + # If positive stream aggregate via MAX. For negative stream use MIN + # *and* negate the result to get a positive number. sql = """ - SELECT %(instance)s, MAX(%(id)s) FROM %(table)s + SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s GROUP BY %(instance)s """ % { "instance": instance_column, "id": id_column, "table": table, + "agg": "MAX" if self._positive else "-MIN", } cur = db_conn.cursor() @@ -269,15 +280,16 @@ class MultiWriterIdGenerator: # Assert the fetched ID is actually greater than what we currently # believe the ID to be. If not, then the sequence and table have got # out of sync somehow. 
- assert self.get_current_token_for_writer(self._instance_name) < next_id - with self._lock: + assert self._current_positions.get(self._instance_name, 0) < next_id + self._unfinished_ids.add(next_id) @contextlib.contextmanager def manager(): try: - yield next_id + # Multiply by the return factor so that the ID has correct sign. + yield self._return_factor * next_id finally: self._mark_id_as_finished(next_id) @@ -296,15 +308,15 @@ class MultiWriterIdGenerator: # Assert the fetched ID is actually greater than any ID we've already # seen. If not, then the sequence and table have got out of sync # somehow. - assert max(self.get_positions().values(), default=0) < min(next_ids) - with self._lock: + assert max(self._current_positions.values(), default=0) < min(next_ids) + self._unfinished_ids.update(next_ids) @contextlib.contextmanager def manager(): try: - yield next_ids + yield [self._return_factor * i for i in next_ids] finally: for i in next_ids: self._mark_id_as_finished(i) @@ -327,7 +339,7 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) - return next_id + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): """The ID has finished being processed so we should advance the @@ -359,20 +371,25 @@ class MultiWriterIdGenerator: """ with self._lock: - return self._current_positions.get(instance_name, 0) + return self._return_factor * self._current_positions.get(instance_name, 0) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. """ with self._lock: - return dict(self._current_positions) + return { + name: self._return_factor * i + for name, i in self._current_positions.items() + } def advance(self, instance_name: str, new_id: int): """Advance the postion of the named writer to the given ID, if greater than existing entry. """ + new_id *= self._return_factor + with self._lock: self._current_positions[instance_name] = max( new_id, self._current_positions.get(instance_name, 0) @@ -390,7 +407,7 @@ class MultiWriterIdGenerator: """ with self._lock: - return self._persisted_upto_position + return self._return_factor * self._persisted_upto_position def _add_persisted_position(self, new_id: int): """Record that we have persisted a position. diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 14ce21c786..f0a8e32f1e 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -264,3 +264,108 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # We assume that so long as `get_next` does correctly advance the # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). + + +class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): + """Tests MultiWriterIdGenerator that produce *negative* stream IDs. 
+ """ + + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db_pool = self.store.db_pool # type: DatabasePool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db_pool, + instance_name=instance_name, + table="foobar", + instance_column="instance_name", + id_column="stream_id", + sequence_name="foobar_seq", + positive=False, + ) + + return self.get_success(self.db_pool.runWithConnection(_create)) + + def _insert_row(self, instance_name: str, stream_id: int): + """Insert one row as the given instance with given stream_id. + """ + + def _insert(txn): + txn.execute( + "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), + ) + + self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) + + def test_single_instance(self): + """Test that reads and writes from a single process are handled + correctly. + """ + id_gen = self._create_id_generator() + + with self.get_success(id_gen.get_next()) as stream_id: + self._insert_row("master", stream_id) + + self.assertEqual(id_gen.get_positions(), {"master": -1}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) + self.assertEqual(id_gen.get_persisted_upto_position(), -1) + + with self.get_success(id_gen.get_next_mult(3)) as stream_ids: + for stream_id in stream_ids: + self._insert_row("master", stream_id) + + self.assertEqual(id_gen.get_positions(), {"master": -4}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), -4) + self.assertEqual(id_gen.get_persisted_upto_position(), -4) + + # Test loading from DB by creating a second ID gen + second_id_gen = self._create_id_generator() + + self.assertEqual(second_id_gen.get_positions(), {"master": -4}) + self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4) + self.assertEqual(second_id_gen.get_persisted_upto_position(), -4) + + def test_multiple_instance(self): + """Tests that having multiple instances that get advanced over + federation works corretly. 
+ """ + id_gen_1 = self._create_id_generator("first") + id_gen_2 = self._create_id_generator("second") + + with self.get_success(id_gen_1.get_next()) as stream_id: + self._insert_row("first", stream_id) + id_gen_2.advance("first", stream_id) + + self.assertEqual(id_gen_1.get_positions(), {"first": -1}) + self.assertEqual(id_gen_2.get_positions(), {"first": -1}) + self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) + self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) + + with self.get_success(id_gen_2.get_next()) as stream_id: + self._insert_row("second", stream_id) + id_gen_1.advance("second", stream_id) + + self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) + self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) + self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) + self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) From da77520cd1c414c9341da287967feb1bab14cbec Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 08:39:04 -0400 Subject: [PATCH 10/31] Convert additional databases to async/await part 2 (#8200) --- changelog.d/8200.misc | 1 + synapse/events/builder.py | 19 ++++--- synapse/handlers/message.py | 13 ++--- synapse/handlers/room_member.py | 12 +---- synapse/storage/databases/main/client_ips.py | 4 +- synapse/storage/databases/main/directory.py | 6 +-- synapse/storage/databases/main/filtering.py | 5 +- synapse/storage/databases/main/openid.py | 8 ++- synapse/storage/databases/main/profile.py | 6 ++- synapse/storage/databases/main/push_rule.py | 10 ++-- synapse/storage/databases/main/room.py | 49 +++++++++++-------- synapse/storage/databases/main/signatures.py | 40 ++++++++++++--- synapse/storage/databases/main/ui_auth.py | 4 +- .../databases/main/user_erasure_store.py | 8 +-- tests/test_utils/event_injection.py | 7 ++- 15 files changed, 111 insertions(+), 81 deletions(-) create mode 100644 changelog.d/8200.misc diff --git a/changelog.d/8200.misc b/changelog.d/8200.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8200.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 9ed24380dd..7878cd7044 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple, Union import attr from nacl.signing import SigningKey @@ -97,14 +97,14 @@ class EventBuilder(object): def is_state(self): return self._state_key is not None - async def build(self, prev_event_ids): + async def build(self, prev_event_ids: List[str]) -> EventBase: """Transform into a fully signed and hashed event Args: - prev_event_ids (list[str]): The event IDs to use as the prev events + prev_event_ids: The event IDs to use as the prev events Returns: - FrozenEvent + The signed and hashed event. """ state_ids = await self._state.get_current_state_ids( @@ -114,8 +114,13 @@ class EventBuilder(object): format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: - auth_events = await self._store.add_event_hashes(auth_ids) - prev_events = await self._store.add_event_hashes(prev_event_ids) + # The types of auth/prev events changes between event versions. 
+ auth_events = await self._store.add_event_hashes( + auth_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] + prev_events = await self._store.add_event_hashes( + prev_event_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] else: auth_events = auth_ids prev_events = prev_event_ids @@ -138,7 +143,7 @@ class EventBuilder(object): "unsigned": self.unsigned, "depth": depth, "prev_state": [], - } + } # type: Dict[str, Any] if self.is_state(): event_dict["state_key"] = self._state_key diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9d0c38f4df..72bb638167 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import ( - Collection, - Requester, - RoomAlias, - StreamToken, - UserID, - create_requester, -) +from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder @@ -446,7 +439,7 @@ class EventCreationHandler(object): event_dict: dict, token_id: Optional[str] = None, txn_id: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, require_consent: bool = True, ) -> Tuple[EventBase, EventContext]: """ @@ -786,7 +779,7 @@ class EventCreationHandler(object): self, builder: EventBuilder, requester: Optional[Requester] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index cae4d013b8..a7962b0ada 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -38,15 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser -from synapse.types import ( - Collection, - JsonDict, - Requester, - RoomAlias, - RoomID, - StateMap, - UserID, -) +from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -184,7 +176,7 @@ class RoomMemberHandler(object): target: UserID, room_id: str, membership: str, - prev_event_ids: Collection[str], + prev_event_ids: List[str], txn_id: Optional[str] = None, ratelimit: bool = True, content: Optional[dict] = None, diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 216a5925fc..c2fc847fbc 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) @wrap_as_background_process("update_client_ips") - def _update_client_ips_batch(self): + async def _update_client_ips_batch(self) -> None: # If the DB pool has already terminated, don't try updating if not 
self.db_pool.is_running(): @@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): to_update = self._batch_row_update self._batch_row_update = {} - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 405b5eafa5..e5060d4c46 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore): return room_id - def update_aliases_for_room( + async def update_aliases_for_room( self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, - ): + ) -> None: """Repoint all of the aliases for a given room, to a different room. Args: @@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 45a1760170..d2f5b9a502 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) - def add_user_filter(self, user_localpart, user_filter): + async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str: def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then @@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore): return filter_id - return self.db_pool.runInteraction("add_user_filter", _do_txn) + return await self.db_pool.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index 4db8949da7..2aac64901b 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -1,3 +1,5 @@ +from typing import Optional + from synapse.storage._base import SQLBaseStore @@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore): desc="insert_open_id_token", ) - def get_user_id_for_open_id_token(self, token, ts_now_ms): + async def get_user_id_for_open_id_token( + self, token: str, ts_now_ms: int + ) -> Optional[str]: def get_user_id_for_token_txn(txn): sql = ( "SELECT user_id FROM open_id_tokens" @@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore): else: return rows[0][0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_user_id_for_token", get_user_id_for_token_txn ) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 301875a672..d2e0685e9e 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore): desc="delete_remote_profile_cache", ) - def get_remote_profile_cache_entries_that_expire(self, last_checked): + async def get_remote_profile_cache_entries_that_expire( + self, last_checked: int + ) -> Dict[str, str]: """Get all users who haven't 
been checked since `last_checked` """ @@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): return self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 2fb5b02d7d..0de802a86b 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -18,8 +18,6 @@ import abc import logging from typing import List, Tuple, Union -from twisted.internet import defer - from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json @@ -149,9 +147,11 @@ class PushRulesWorkerStore( ) return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - def have_push_rules_changed_for_user(self, user_id, last_id): + async def have_push_rules_changed_for_user( + self, user_id: str, last_id: int + ) -> bool: if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) + return False else: def have_push_rules_changed_txn(txn): @@ -163,7 +163,7 @@ class PushRulesWorkerStore( (count,) = txn.fetchone() return bool(count) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index a92641c339..717df97301 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore): allow_none=True, ) - def get_room_with_stats(self, room_id: str): + async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: """Retrieve room with statistics. Args: @@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore): res["public"] = bool(res["public"]) return res - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id ) @@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore): desc="get_public_room_ids", ) - def count_public_rooms(self, network_tuple, ignore_non_federatable): + async def count_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + ignore_non_federatable: bool, + ) -> int: """Counts the number of public rooms as tracked in the room_stats_current and room_stats_state table. 
Args: - network_tuple (ThirdPartyInstanceID|None) - ignore_non_federatable (bool): If true filters out non-federatable rooms + network_tuple + ignore_non_federatable: If true filters out non-federatable rooms """ def _count_public_rooms_txn(txn): @@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_public_rooms", _count_public_rooms_txn ) @@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore): return row - def get_media_mxcs_in_room(self, room_id): + async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room Args: - room_id (str) + room_id Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. + The local and remote media as a lists of the media IDs. """ def _get_media_mxcs_in_room_txn(txn): @@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_media_ids_in_room", _get_media_mxcs_in_room_txn ) - def quarantine_media_ids_in_room(self, room_id, quarantined_by): + async def quarantine_media_ids_in_room( + self, room_id: str, quarantined_by: str + ) -> int: """For a room loops through all events with media and quarantines the associated media """ @@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - def quarantine_media_by_id( + async def quarantine_media_by_id( self, server_name: str, media_id: str, quarantined_by: str, - ): + ) -> int: """quarantines a single local or remote media id Args: @@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_id_txn ) - def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): + async def quarantine_media_ids_by_user( + self, user_id: str, quarantined_by: str + ) -> int: """quarantines all local media associated with a single user Args: @@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore): local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_user_txn ) @@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - def get_room_count(self): - """Retrieve a list of all rooms + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. 
""" def f(txn): @@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.db_pool.runInteraction("get_rooms", f) + return await self.db_pool.runInteraction("get_rooms", f) async def add_event_report( self, diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index be191dd870..c8c67953e4 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable, List, Tuple + from unpaddedbase64 import encode_base64 from synapse.storage._base import SQLBaseStore +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore): @cachedList( cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 ) - def get_event_reference_hashes(self, event_ids): + async def get_event_reference_hashes( + self, event_ids: Iterable[str] + ) -> Dict[str, Dict[str, bytes]]: + """Get all hashes for given events. + + Args: + event_ids: The event IDs to get hashes for. + + Returns: + A mapping of event ID to a mapping of algorithm to hash. + """ + def f(txn): return { event_id: self._get_event_reference_hashes_txn(txn, event_id) for event_id in event_ids } - return self.db_pool.runInteraction("get_event_reference_hashes", f) + return await self.db_pool.runInteraction("get_event_reference_hashes", f) - async def add_event_hashes(self, event_ids): + async def add_event_hashes( + self, event_ids: Iterable[str] + ) -> List[Tuple[str, Dict[str, str]]]: + """ + + Args: + event_ids: The event IDs + + Returns: + A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash. + """ hashes = await self.get_event_reference_hashes(event_ids) hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} @@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore): return list(hashes.items()) - def _get_event_reference_hashes_txn(self, txn, event_id): + def _get_event_reference_hashes_txn( + self, txn: Cursor, event_id: str + ) -> Dict[str, bytes]: """Get all the hashes for a given PDU. Args: - txn (cursor): - event_id (str): Id for the Event. + txn: + event_id: Id for the Event. Returns: - A dict[unicode, bytes] of algorithm -> hash. + A mapping of algorithm -> hash. """ query = ( "SELECT algorithm, hash" diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 9eef8e57c5..b89668d561 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore): class UIAuthStore(UIAuthWorkerStore): - def delete_old_ui_auth_sessions(self, expiration_time: int): + async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore): This is an epoch time in milliseconds. 
""" - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_old_ui_auth_sessions", self._delete_old_ui_auth_sessions_txn, expiration_time, diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index e3547e53b3..2f7c95fc74 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore): class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id: str) -> None: + async def mark_user_erased(self, user_id: str) -> None: """Indicate that user_id wishes their message history to be erased. Args: @@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_erased", f) + await self.db_pool.runInteraction("mark_user_erased", f) - def mark_user_not_erased(self, user_id: str) -> None: + async def mark_user_not_erased(self, user_id: str) -> None: """Indicate that user_id is no longer erased. Args: @@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_not_erased", f) + await self.db_pool.runInteraction("mark_user_not_erased", f) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 8522c6fc09..fb1ca90336 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -13,14 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import List, Optional, Tuple import synapse.server from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import Collection """ Utility functions for poking events into the storage of the server under test. 
@@ -58,7 +57,7 @@ async def inject_member_event( async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> EventBase: """Inject a generic event into a room @@ -80,7 +79,7 @@ async def inject_event( async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> Tuple[EventBase, EventContext]: if room_version is None: From 5bf8e5f55b49f9e46a7fe7d7872e6b16d38bffd3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 09:15:22 -0400 Subject: [PATCH 11/31] Convert the well known resolver to async (#8214) --- changelog.d/8214.misc | 1 + mypy.ini | 1 + .../federation/matrix_federation_agent.py | 4 +- .../http/federation/well_known_resolver.py | 57 ++++++++++--------- .../test_matrix_federation_agent.py | 24 ++++++-- 5 files changed, 53 insertions(+), 34 deletions(-) create mode 100644 changelog.d/8214.misc diff --git a/changelog.d/8214.misc b/changelog.d/8214.misc new file mode 100644 index 0000000000..e26764dea1 --- /dev/null +++ b/changelog.d/8214.misc @@ -0,0 +1 @@ + Convert various parts of the codebase to async/await. diff --git a/mypy.ini b/mypy.ini index 4213e31b03..21c6f523a0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -28,6 +28,7 @@ files = synapse/handlers/saml_handler.py, synapse/handlers/sync.py, synapse/handlers/ui_auth, + synapse/http/federation/well_known_resolver.py, synapse/http/server.py, synapse/http/site.py, synapse/logging/, diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 369bf9c2fc..782d39d4ca 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -134,8 +134,8 @@ class MatrixFederationAgent(object): and not _is_ip_literal(parsed_uri.hostname) and not parsed_uri.port ): - well_known_result = yield self._well_known_resolver.get_well_known( - parsed_uri.hostname + well_known_result = yield defer.ensureDeferred( + self._well_known_resolver.get_well_known(parsed_uri.hostname) ) delegated_server = well_known_result.delegated_server diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index f794315deb..cdb6bec56e 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -16,6 +16,7 @@ import logging import random import time +from typing import Callable, Dict, Optional, Tuple import attr @@ -23,6 +24,7 @@ from twisted.internet import defer from twisted.web.client import RedirectAgent, readBody from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock, json_decoder @@ -99,15 +101,14 @@ class WellKnownResolver(object): self._well_known_agent = RedirectAgent(agent) self.user_agent = user_agent - @defer.inlineCallbacks - def get_well_known(self, server_name): + async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult: """Attempt to fetch and parse a .well-known file for the given server Args: - server_name (bytes): name of the server, from the requested url + server_name: name of the server, from the requested url Returns: - Deferred[WellKnownLookupResult]: The 
result of the lookup + The result of the lookup """ try: prev_result, expiry, ttl = self._well_known_cache.get_with_expiry( @@ -124,7 +125,9 @@ class WellKnownResolver(object): # requests for the same server in parallel? try: with Measure(self._clock, "get_well_known"): - result, cache_period = yield self._fetch_well_known(server_name) + result, cache_period = await self._fetch_well_known( + server_name + ) # type: Tuple[Optional[bytes], float] except _FetchWellKnownFailure as e: if prev_result and e.temporary: @@ -153,18 +156,17 @@ class WellKnownResolver(object): return WellKnownLookupResult(delegated_server=result) - @defer.inlineCallbacks - def _fetch_well_known(self, server_name): + async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]: """Actually fetch and parse a .well-known, without checking the cache Args: - server_name (bytes): name of the server, from the requested url + server_name: name of the server, from the requested url Raises: _FetchWellKnownFailure if we fail to lookup a result Returns: - Deferred[Tuple[bytes,int]]: The lookup result and cache period. + The lookup result and cache period. """ had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False) @@ -172,7 +174,7 @@ class WellKnownResolver(object): # We do this in two steps to differentiate between possibly transient # errors (e.g. can't connect to host, 503 response) and more permenant # errors (such as getting a 404 response). - response, body = yield self._make_well_known_request( + response, body = await self._make_well_known_request( server_name, retry=had_valid_well_known ) @@ -215,20 +217,20 @@ class WellKnownResolver(object): return result, cache_period - @defer.inlineCallbacks - def _make_well_known_request(self, server_name, retry): + async def _make_well_known_request( + self, server_name: bytes, retry: bool + ) -> Tuple[IResponse, bytes]: """Make the well known request. This will retry the request if requested and it fails (with unable to connect or receives a 5xx error). Args: - server_name (bytes) - retry (bool): Whether to retry the request if it fails. + server_name: name of the server, from the requested url + retry: Whether to retry the request if it fails. Returns: - Deferred[tuple[IResponse, bytes]] Returns the response object and - body. Response may be a non-200 response. + Returns the response object and body. Response may be a non-200 response. """ uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") @@ -243,12 +245,12 @@ class WellKnownResolver(object): logger.info("Fetching %s", uri_str) try: - response = yield make_deferred_yieldable( + response = await make_deferred_yieldable( self._well_known_agent.request( b"GET", uri, headers=Headers(headers) ) ) - body = yield make_deferred_yieldable(readBody(response)) + body = await make_deferred_yieldable(readBody(response)) if 500 <= response.code < 600: raise Exception("Non-200 response %s" % (response.code,)) @@ -265,21 +267,24 @@ class WellKnownResolver(object): logger.info("Error fetching %s: %s. 
Retrying", uri_str, e) # Sleep briefly in the hopes that they come back up - yield self._clock.sleep(0.5) + await self._clock.sleep(0.5) -def _cache_period_from_headers(headers, time_now=time.time): +def _cache_period_from_headers( + headers: Headers, time_now: Callable[[], float] = time.time +) -> Optional[float]: cache_controls = _parse_cache_control(headers) if b"no-store" in cache_controls: return 0 if b"max-age" in cache_controls: - try: - max_age = int(cache_controls[b"max-age"]) - return max_age - except ValueError: - pass + max_age = cache_controls[b"max-age"] + if max_age: + try: + return int(max_age) + except ValueError: + pass expires = headers.getRawHeaders(b"expires") if expires is not None: @@ -295,7 +300,7 @@ def _cache_period_from_headers(headers, time_now=time.time): return None -def _parse_cache_control(headers): +def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: cache_controls = {} for hdr in headers.getRawHeaders(b"cache-control", []): for directive in hdr.split(b","): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 69945a8f98..eb78ab412a 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -972,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -995,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase): well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") @@ -1003,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((1000.0,)) # now it should connect again - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) @@ -1026,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -1052,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # another lookup. self.reactor.pump((900.0,)) - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # The resolver may retry a few times, so fonx all requests that come along attempts = 0 @@ -1082,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((10000.0,)) # Repated the request, this time it should fail if the lookup fails. 
- fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) clients = self.reactor.tcpClients (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) From 54f8d73c005cf0401d05fc90e857da253f9d1168 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 09:21:48 -0400 Subject: [PATCH 12/31] Convert additional databases to async/await (#8199) --- changelog.d/8199.misc | 1 + synapse/storage/databases/main/__init__.py | 50 +++++---- synapse/storage/databases/main/devices.py | 38 +++---- .../storage/databases/main/events_worker.py | 48 ++++---- .../storage/databases/main/purge_events.py | 30 ++--- synapse/storage/databases/main/receipts.py | 14 ++- synapse/storage/databases/main/relations.py | 103 ++++++++---------- 7 files changed, 147 insertions(+), 137 deletions(-) create mode 100644 changelog.d/8199.misc diff --git a/changelog.d/8199.misc b/changelog.d/8199.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8199.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index e6536c8456..99890ffbf3 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,7 +18,7 @@ import calendar import logging import time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig @@ -294,16 +294,16 @@ class DataStore( return [UserPresenceState(**row) for row in rows] - def count_daily_users(self): + async def count_daily_users(self) -> int: """ Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_daily_users", self._count_users, yesterday ) - def count_monthly_users(self): + async def count_monthly_users(self) -> int: """ Counts the number of users who used this homeserver in the last 30 days. Note this method is intended for phonehome metrics only and is different @@ -311,7 +311,7 @@ class DataStore( amongst other things, includes a 3 day grace period before a user counts. """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -330,15 +330,15 @@ class DataStore( (count,) = txn.fetchone() return count - def count_r30_users(self): + async def count_r30_users(self) -> Dict[str, int]: """ Counts the number of 30 day retained users, defined as:- * Users who have created their accounts more than 30 days ago * Where last seen at most 30 days ago * Where account creation and last_seen are > 30 days apart - Returns counts globaly for a given user as well as breaking - by platform + Returns: + A mapping of counts globally as well as broken out by platform. 
""" def _count_r30_users(txn): @@ -411,7 +411,7 @@ class DataStore( return results - return self.db_pool.runInteraction("count_r30_users", _count_r30_users) + return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) def _get_start_of_day(self): """ @@ -421,7 +421,7 @@ class DataStore( today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) return today_start * 1000 - def generate_user_daily_visits(self): + async def generate_user_daily_visits(self) -> None: """ Generates daily visit data for use in cohort/ retention analysis """ @@ -476,7 +476,7 @@ class DataStore( # frequently self._last_user_visit_update = now - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "generate_user_daily_visits", _generate_user_daily_visits ) @@ -500,22 +500,28 @@ class DataStore( desc="get_users", ) - def get_users_paginate( - self, start, limit, user_id=None, name=None, guests=True, deactivated=False - ): + async def get_users_paginate( + self, + start: int, + limit: int, + user_id: Optional[str] = None, + name: Optional[str] = None, + guests: bool = True, + deactivated: bool = False, + ) -> Tuple[List[Dict[str, Any]], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the total number of users matching the filter criteria. Args: - start (int): start number to begin the query from - limit (int): number of rows to retrieve - user_id (string): search for user_id. ignored if name is not None - name (string): search for local part of user_id or display name - guests (bool): whether to in include guest users - deactivated (bool): whether to include deactivated users + start: start number to begin the query from + limit: number of rows to retrieve + user_id: search for user_id. ignored if name is not None + name: search for local part of user_id or display name + guests: whether to in include guest users + deactivated: whether to include deactivated users Returns: - defer.Deferred: resolves to list[dict[str, Any]], int + A tuple of a list of mappings from user to information and a count of total users. """ def get_users_paginate_txn(txn): @@ -558,7 +564,7 @@ class DataStore( users = self.db_pool.cursor_to_dict(txn) return users, count - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_paginate_txn", get_users_paginate_txn ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e8379c73c4..a29157d979 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -313,9 +313,9 @@ class DeviceWorkerStore(SQLBaseStore): return results - def _get_last_device_update_for_remote_user( + async def _get_last_device_update_for_remote_user( self, destination: str, user_id: str, from_stream_id: int - ): + ) -> int: def f(txn): prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id @@ -326,12 +326,16 @@ class DeviceWorkerStore(SQLBaseStore): rows = txn.fetchall() return rows[0][0] - return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) + return await self.db_pool.runInteraction( + "get_last_device_update_for_remote_user", f + ) - def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int): + async def mark_as_sent_devices_by_remote( + self, destination: str, stream_id: int + ) -> None: """Mark that updates have successfully been sent to the destination. 
""" - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, @@ -684,7 +688,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="make_remote_user_device_cache_as_stale", ) - def mark_remote_user_device_list_as_unsubscribed(self, user_id: str): + async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: """Mark that we no longer track device lists for remote user. """ @@ -698,7 +702,7 @@ class DeviceWorkerStore(SQLBaseStore): txn, self.get_device_list_last_stream_id_for_remote, (user_id,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "mark_remote_user_device_list_as_unsubscribed", _mark_remote_user_device_list_as_unsubscribed_txn, ) @@ -959,9 +963,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): desc="update_device", ) - def update_remote_device_list_cache_entry( + async def update_remote_device_list_cache_entry( self, user_id: str, device_id: str, content: JsonDict, stream_id: int - ): + ) -> None: """Updates a single device in the cache of a remote user's devicelist. Note: assumes that we are the only thread that can be updating this user's @@ -972,11 +976,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id: ID of decivice being updated content: new data on this device stream_id: the version of the device list - - Returns: - Deferred[None] """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, user_id, @@ -1028,9 +1029,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): lock=False, ) - def update_remote_device_list_cache( + async def update_remote_device_list_cache( self, user_id: str, devices: List[dict], stream_id: int - ): + ) -> None: """Replace the entire cache of the remote user's devices. Note: assumes that we are the only thread that can be updating this user's @@ -1040,11 +1041,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: User to update device list for devices: list of device objects supplied over federation stream_id: the version of the device list - - Returns: - Deferred[None] """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, user_id, @@ -1054,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def _update_remote_device_list_cache_txn( self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int - ): + ) -> None: self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e6247d682d..a7a73cc3d8 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -823,20 +823,24 @@ class EventsWorkerStore(SQLBaseStore): return event_dict - def _maybe_redact_event_row(self, original_ev, redactions, event_map): + def _maybe_redact_event_row( + self, + original_ev: EventBase, + redactions: Iterable[str], + event_map: Dict[str, EventBase], + ) -> Optional[EventBase]: """Given an event object and a list of possible redacting event ids, determine whether to honour any of those redactions and if so return a redacted event. 
Args: - original_ev (EventBase): - redactions (iterable[str]): list of event ids of potential redaction events - event_map (dict[str, EventBase]): other events which have been fetched, in - which we can look up the redaaction events. Map from event id to event. + original_ev: The original event. + redactions: list of event ids of potential redaction events + event_map: other events which have been fetched, in which we can + look up the redaaction events. Map from event id to event. Returns: - Deferred[EventBase|None]: if the event should be redacted, a pruned - event object. Otherwise, None. + If the event should be redacted, a pruned event object. Otherwise, None. """ if original_ev.type == "m.room.create": # we choose to ignore redactions of m.room.create events. @@ -946,17 +950,17 @@ class EventsWorkerStore(SQLBaseStore): row = txn.fetchone() return row[0] if row else 0 - def get_current_state_event_counts(self, room_id): + async def get_current_state_event_counts(self, room_id: str) -> int: """ Gets the current number of state events in a room. Args: - room_id (str) + room_id: The room ID to query. Returns: - Deferred[int] + The current number of state events. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_event_counts", self._get_current_state_event_counts_txn, room_id, @@ -991,7 +995,9 @@ class EventsWorkerStore(SQLBaseStore): """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() - def get_all_new_forward_event_rows(self, last_id, current_id, limit): + async def get_all_new_forward_event_rows( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple]: """Returns new events, for the Events replication stream Args: @@ -999,7 +1005,7 @@ class EventsWorkerStore(SQLBaseStore): current_id: the maximum stream_id to return up to limit: the maximum number of rows to return - Returns: Deferred[List[Tuple]] + Returns: a list of events stream rows. Each tuple consists of a stream id as the first element, followed by fields suitable for casting into an EventsStreamRow. @@ -1020,18 +1026,20 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) - def get_ex_outlier_stream_rows(self, last_id, current_id): + async def get_ex_outlier_stream_rows( + self, last_id: int, current_id: int + ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream Args: last_id: the last stream_id from the previous batch. current_id: the maximum stream_id to return up to - Returns: Deferred[List[Tuple]] + Returns: a list of events stream rows. Each tuple consists of a stream id as the first element, followed by fields suitable for casting into an EventsStreamRow. 
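The read-path methods in this file all follow the same mechanical conversion: instead of returning the Deferred from runInteraction, the method becomes an async def that awaits it, and the Deferred[...] wrapper is dropped from the docstring. A condensed, self-contained sketch of that shape (fake pool and invented method name, purely illustrative):

    import asyncio
    from typing import Callable

    class FakePool:
        """Stand-in for DatabasePool, just enough to make the sketch runnable."""

        async def runInteraction(self, desc: str, func: Callable, *args):
            return func(*args)  # the real pool runs this on a database thread

    class SketchStore:
        db_pool = FakePool()

        # Before: `def get_thing_count(self, room_id)` returned the Deferred
        # from runInteraction, documented as Deferred[int].
        async def get_thing_count(self, room_id: str) -> int:
            def _txn(room_id: str) -> int:
                return 3  # a real transaction function would execute SQL here

            return await self.db_pool.runInteraction("get_thing_count", _txn, room_id)

    print(asyncio.run(SketchStore().get_thing_count("!room:example.org")))  # 3
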
@@ -1054,7 +1062,7 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id)) return txn.fetchall() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn ) @@ -1226,11 +1234,11 @@ class EventsWorkerStore(SQLBaseStore): return (int(res["topological_ordering"]), int(res["stream_ordering"])) - def get_next_event_to_expire(self): + async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: """Retrieve the entry with the lowest expiry timestamp in the event_expiry table, or None if there's no more event to expire. - Returns: Deferred[Optional[Tuple[str, int]]] + Returns: A tuple containing the event ID as its first element and an expiry timestamp as its second one, if there's at least one row in the event_expiry table. None otherwise. @@ -1246,6 +1254,6 @@ class EventsWorkerStore(SQLBaseStore): return txn.fetchone() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 3526b6fd66..ea833829ae 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Tuple +from typing import Any, List, Set, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore @@ -25,25 +25,24 @@ logger = logging.getLogger(__name__) class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): - def purge_history(self, room_id, token, delete_local_events): + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> Set[int]: """Deletes room history before a certain point Args: - room_id (str): - - token (str): A topological token to delete events before - - delete_local_events (bool): + room_id: + token: A topological token to delete events before + delete_local_events: if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). Returns: - Deferred[set[int]]: The set of state groups that are referenced by - deleted events. + The set of state groups that are referenced by deleted events. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "purge_history", self._purge_history_txn, room_id, @@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): return referenced_state_groups - def purge_room(self, room_id): + async def purge_room(self, room_id: str) -> List[int]: """Deletes all record of a room Args: - room_id (str) + room_id Returns: - Deferred[List[int]]: The list of state groups to delete. + The list of state groups to delete. 
""" - - return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id) + return await self.db_pool.runInteraction( + "purge_room", self._purge_room_txn, room_id + ) def _purge_room_txn(self, txn, room_id): # First we fetch all the state groups that should be deleted, before diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 436f22ad2d..4a0d5a320e 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore): } return results - def get_users_sent_receipts_between(self, last_id: int, current_id: int): + async def get_users_sent_receipts_between( + self, last_id: int, current_id: int + ) -> List[str]: """Get all users who sent receipts between `last_id` exclusive and `current_id` inclusive. Returns: - Deferred[List[str]] + The list of users. """ if last_id == current_id: @@ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return [r[0] for r in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn ) @@ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore): return stream_id, max_persisted_id - def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.db_pool.runInteraction( + async def insert_graph_receipt( + self, room_id, receipt_type, user_id, event_ids, data + ): + return await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index a9ceffc20e..5cd61547f7 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -34,38 +34,33 @@ logger = logging.getLogger(__name__) class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) - def get_relations_for_event( + async def get_relations_for_event( self, - event_id, - relation_type=None, - event_type=None, - aggregation_key=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): + event_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + aggregation_key: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[RelationPaginationToken] = None, + to_token: Optional[RelationPaginationToken] = None, + ) -> PaginationChunk: """Get a list of relations for an event, ordered by topological ordering. Args: - event_id (str): Fetch events that relate to this event ID. - relation_type (str|None): Only fetch events with this relation - type, if given. - event_type (str|None): Only fetch events with this event type, if - given. - aggregation_key (str|None): Only fetch events with this aggregation - key, if given. - limit (int): Only fetch the most recent `limit` events. - direction (str): Whether to fetch the most recent first (`"b"`) or - the oldest first (`"f"`). - from_token (RelationPaginationToken|None): Fetch rows from the given - token, or from the start if None. - to_token (RelationPaginationToken|None): Fetch rows up to the given - token, or up to the end if None. + event_id: Fetch events that relate to this event ID. + relation_type: Only fetch events with this relation type, if given. + event_type: Only fetch events with this event type, if given. + aggregation_key: Only fetch events with this aggregation key, if given. 
+ limit: Only fetch the most recent `limit` events. + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`). + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. Returns: - Deferred[PaginationChunk]: List of event IDs that match relations - requested. The rows are of the form `{"event_id": "..."}`. + List of event IDs that match relations requested. The rows are of + the form `{"event_id": "..."}`. """ where_clause = ["relates_to_id = ?"] @@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn ) @cached(tree=True) - def get_aggregation_groups_for_event( + async def get_aggregation_groups_for_event( self, - event_id, - event_type=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): + event_id: str, + event_type: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[AggregationPaginationToken] = None, + to_token: Optional[AggregationPaginationToken] = None, + ) -> PaginationChunk: """Get a list of annotations on the event, grouped by event type and aggregation key, sorted by count. @@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore): on an event. Args: - event_id (str): Fetch events that relate to this event ID. - event_type (str|None): Only fetch events with this event type, if - given. - limit (int): Only fetch the `limit` groups. - direction (str): Whether to fetch the highest count first (`"b"`) or + event_id: Fetch events that relate to this event ID. + event_type: Only fetch events with this event type, if given. + limit: Only fetch the `limit` groups. + direction: Whether to fetch the highest count first (`"b"`) or the lowest count first (`"f"`). - from_token (AggregationPaginationToken|None): Fetch rows from the - given token, or from the start if None. - to_token (AggregationPaginationToken|None): Fetch rows up to the - given token, or up to the end if None. - + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. Returns: - Deferred[PaginationChunk]: List of groups of annotations that - match. Each row is a dict with `type`, `key` and `count` fields. + List of groups of annotations that match. Each row is a dict with + `type`, `key` and `count` fields. """ where_clause = ["relates_to_id = ?", "relation_type = ?"] @@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) @@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore): return await self.get_event(edit_id, allow_none=True) - def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + async def has_user_annotated_event( + self, parent_id: str, event_type: str, aggregation_key: str, sender: str + ) -> bool: """Check if a user has already annotated an event with the same key (e.g. already liked an event). 
Args: - parent_id (str): The event being annotated - event_type (str): The event type of the annotation - aggregation_key (str): The aggregation key of the annotation - sender (str): The sender of the annotation + parent_id: The event being annotated + event_type: The event type of the annotation + aggregation_key: The aggregation key of the annotation + sender: The sender of the annotation Returns: - Deferred[bool] + True if the event is already annotated. """ sql = """ @@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore): return bool(txn.fetchone()) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) From 5615eb5cb48d63df15c391fe395c8740dc4af017 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Sep 2020 16:02:17 +0100 Subject: [PATCH 13/31] Rename `_get_e2e_device_keys_txn` (#8222) ... to `_get_e2e_device_keys_and_signatures_txn`, to better reflect what it does. --- changelog.d/8222.misc | 1 + synapse/storage/databases/main/devices.py | 4 ++-- synapse/storage/databases/main/end_to_end_keys.py | 10 ++++++---- 3 files changed, 9 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8222.misc diff --git a/changelog.d/8222.misc b/changelog.d/8222.misc new file mode 100644 index 0000000000..979c8b227b --- /dev/null +++ b/changelog.d/8222.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a29157d979..11956cc48e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -256,8 +256,8 @@ class DeviceWorkerStore(SQLBaseStore): """ devices = ( await self.db_pool.runInteraction( - "_get_e2e_device_keys_txn", - self._get_e2e_device_keys_txn, + "get_e2e_device_keys_and_signatures_txn", + self._get_e2e_device_keys_and_signatures_txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index fb3b1f94de..1ee062e3c4 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -51,7 +51,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) -> Tuple[int, List[JsonDict]]: now_stream_id = self.get_device_stream_token() - devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) + devices = self._get_e2e_device_keys_and_signatures_txn(txn, [(user_id, None)]) if devices: user_devices = devices[user_id] @@ -96,7 +96,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return {} results = await self.db_pool.runInteraction( - "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, + "get_e2e_device_keys_and_signatures_txn", + self._get_e2e_device_keys_and_signatures_txn, + query_list, ) # Build the result structure, un-jsonify the results, and add the @@ -120,9 +122,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return rv @trace - def _get_e2e_device_keys_txn( + def _get_e2e_device_keys_and_signatures_txn( self, txn, query_list, include_all_devices=False, include_deleted_devices=False - ): + ) -> Dict[str, Dict[str, Optional[Dict]]]: set_tag("include_all_devices", include_all_devices) set_tag("include_deleted_devices", include_deleted_devices) From 7d103a594e10f4040096bf38716c9cdddf42eca0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 11:03:49 
-0400 Subject: [PATCH 14/31] Convert appservice code to async/await. (#8207) --- changelog.d/8207.misc | 1 + synapse/appservice/api.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) create mode 100644 changelog.d/8207.misc diff --git a/changelog.d/8207.misc b/changelog.d/8207.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8207.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index e72a0b9ac0..bb6fa8299a 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,18 +14,20 @@ # limitations under the License. import logging import urllib +from typing import TYPE_CHECKING, Optional from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient -from synapse.types import ThirdPartyInstanceID +from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache +if TYPE_CHECKING: + from synapse.appservice import ApplicationService + logger = logging.getLogger(__name__) sent_transactions_counter = Counter( @@ -163,19 +165,20 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_3pe to %s threw exception %s", uri, ex) return [] - def get_3pe_protocol(self, service, protocol): + async def get_3pe_protocol( + self, service: "ApplicationService", protocol: str + ) -> Optional[JsonDict]: if service.url is None: return {} - @defer.inlineCallbacks - def _get(): + async def _get() -> Optional[JsonDict]: uri = "%s%s/thirdparty/protocol/%s" % ( service.url, APP_SERVICE_PREFIX, urllib.parse.quote(protocol), ) try: - info = yield defer.ensureDeferred(self.get_json(uri, {})) + info = await self.get_json(uri, {}) if not _is_valid_3pe_metadata(info): logger.warning( @@ -196,7 +199,7 @@ class ApplicationServiceApi(SimpleHttpClient): return None key = (service.id, protocol) - return self.protocol_meta_cache.wrap(key, _get) + return await self.protocol_meta_cache.wrap(key, _get) async def push_bulk(self, service, events, txn_id=None): if service.url is None: From 37db6252b7ee0f3e9798a561e2919a67299e08f4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 11:04:17 -0400 Subject: [PATCH 15/31] Convert additional databases to async/await part 3 (#8201) --- changelog.d/8201.misc | 1 + synapse/storage/background_updates.py | 4 +- .../storage/databases/main/account_data.py | 59 ++++++++++--------- .../storage/databases/main/end_to_end_keys.py | 54 ++++++++++------- .../databases/main/media_repository.py | 31 ++++++---- synapse/storage/databases/main/search.py | 15 +++-- .../storage/databases/main/user_directory.py | 44 ++++++++------ 7 files changed, 121 insertions(+), 87 deletions(-) create mode 100644 changelog.d/8201.misc diff --git a/changelog.d/8201.misc b/changelog.d/8201.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8201.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
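The `get_3pe_protocol` conversion in the previous patch, and the storage conversions that follow, apply the same mechanical pattern: an `@defer.inlineCallbacks` generator that yields Deferreds becomes a native coroutine that awaits them. A minimal before/after sketch, using a made-up `fetch_protocol_info` helper and client object rather than code from any of these patches:

    from twisted.internet import defer

    # Before: a generator wrapped by inlineCallbacks; results arrive via `yield`.
    @defer.inlineCallbacks
    def fetch_protocol_info(client, uri):
        info = yield defer.ensureDeferred(client.get_json(uri, {}))
        return info

    # After: a native coroutine; the same call is awaited directly, and callers
    # either await it or wrap it in defer.ensureDeferred where a Deferred is
    # still required by Twisted APIs.
    async def fetch_protocol_info_async(client, uri):
        return await client.get_json(uri, {})
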
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 0db900fa0e..67a89cd51a 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -433,7 +433,7 @@ class BackgroundUpdater(object): "background_updates", keyvalues={"update_name": update_name} ) - def _background_update_progress(self, update_name: str, progress: dict): + async def _background_update_progress(self, update_name: str, progress: dict): """Update the progress of a background update Args: @@ -441,7 +441,7 @@ class BackgroundUpdater(object): progress: The progress of the update. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "background_update_progress", self._background_update_progress_txn, update_name, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 04042a2c98..4436b1a83d 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -16,9 +16,7 @@ import abc import logging -from typing import List, Optional, Tuple - -from twisted.internet import defer +from typing import Dict, List, Optional, Tuple from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool @@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore): raise NotImplementedError() @cached() - def get_account_data_for_user(self, user_id): + async def get_account_data_for_user( + self, user_id: str + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: """Get all the client account_data for a user. Args: - user_id(str): The user to get the account_data for. + user_id: The user to get the account_data for. Returns: - A deferred pair of a dict of global account_data and a dict - mapping from room_id string to per room account_data dicts. + A 2-tuple of a dict of global account_data and a dict mapping from + room_id string to per room account_data dicts. """ def get_account_data_for_user_txn(txn): @@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore): return global_account_data, by_room - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) @@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore): return None @cached(num_args=2) - def get_account_data_for_room(self, user_id, room_id): + async def get_account_data_for_room( + self, user_id: str, room_id: str + ) -> Dict[str, JsonDict]: """Get all the client account_data for a user for a room. Args: - user_id(str): The user to get the account_data for. - room_id(str): The room to get the account_data for. + user_id: The user to get the account_data for. + room_id: The room to get the account_data for. 
Returns: - A deferred dict of the room account_data + A dict of the room account_data """ def get_account_data_for_room_txn(txn): @@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: db_to_json(row["content"]) for row in rows } - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @cached(num_args=3, max_entries=5000) - def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): + async def get_account_data_for_room_and_type( + self, user_id: str, room_id: str, account_data_type: str + ) -> Optional[JsonDict]: """Get the client account_data of given type for a user for a room. Args: - user_id(str): The user to get the account_data for. - room_id(str): The room to get the account_data for. - account_data_type (str): The account data type to get. + user_id: The user to get the account_data for. + room_id: The room to get the account_data for. + account_data_type: The account data type to get. Returns: - A deferred of the room account_data for that type, or None if - there isn't any set. + The room account_data for that type, or None if there isn't any set. """ def get_account_data_for_room_and_type_txn(txn): @@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore): return db_to_json(content_json) if content_json else None - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) @@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore): "get_updated_room_account_data", get_updated_room_account_data_txn ) - def get_updated_account_data_for_user(self, user_id, stream_id): + async def get_updated_account_data_for_user( + self, user_id: str, stream_id: int + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: """Get all the client account_data for a that's changed for a user Args: - user_id(str): The user to get the account_data for. - stream_id(int): The point in the stream since which to get updates + user_id: The user to get the account_data for. + stream_id: The point in the stream since which to get updates Returns: A deferred pair of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. 
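As a usage note for the account-data getters converted above, callers now receive plain values from `await` rather than Deferreds. A hedged sketch of a hypothetical caller (the store object, room ID and data type are assumptions, not taken from the patch):

    # Hypothetical caller of the converted AccountDataWorkerStore methods.
    async def collect_account_data(store, user_id: str, since_stream_id: int):
        # 2-tuple of (global account_data, per-room account_data), per the docstring.
        global_data, by_room = await store.get_updated_account_data_for_user(
            user_id, since_stream_id
        )
        # None if no account data of that type is set for the room.
        room_tags = await store.get_account_data_for_room_and_type(
            user_id, "!room:example.com", "m.tag"
        )
        return global_data, by_room, room_tags
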
@@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore): user_id, int(stream_id) ) if not changed: - return defer.succeed(({}, {})) + return ({}, {}) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) @@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore): return self._account_data_id_gen.get_current_token() - def _update_max_stream_id(self, next_id: int): + async def _update_max_stream_id(self, next_id: int) -> None: """Update the max stream_id Args: @@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore): ) txn.execute(update_max_id_sql, (next_id, next_id)) - return self.db_pool.runInteraction("update_account_data_max_stream_id", _update) + await self.db_pool.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1ee062e3c4..5a7de44b33 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -34,13 +34,15 @@ if TYPE_CHECKING: class EndToEndKeyWorkerStore(SQLBaseStore): - def get_e2e_device_keys_for_federation_query(self, user_id: str): + async def get_e2e_device_keys_for_federation_query( + self, user_id: str + ) -> Tuple[int, List[JsonDict]]: """Get all devices (with any device keys) for a user Returns: - Deferred which resolves to (stream_id, devices) + (stream_id, devices) """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_device_keys_for_federation_query", self._get_e2e_device_keys_for_federation_query_txn, user_id, @@ -292,10 +294,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) @cached(max_entries=10000) - def count_e2e_one_time_keys(self, user_id, device_id): + async def count_e2e_one_time_keys( + self, user_id: str, device_id: str + ) -> Dict[str, int]: """ Count the number of one time keys the server has for a device Returns: - Dict mapping from algorithm to number of keys for that algorithm. + A mapping from algorithm to number of keys for that algorithm. """ def _count_e2e_one_time_keys(txn): @@ -310,7 +314,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): result[algorithm] = key_count return result - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_e2e_one_time_keys", _count_e2e_one_time_keys ) @@ -348,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): list_name="user_ids", num_args=1, ) - def _get_bare_e2e_cross_signing_keys_bulk( + async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: List[str] ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. The output of this @@ -356,16 +360,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore): the signatures for the calling user need to be fetched. Args: - user_ids (list[str]): the users whose keys are being requested + user_ids: the users whose keys are being requested Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. If a user's cross-signing keys were not found, either - their user ID will not be in the dict, or their user ID will map - to None. + A mapping from user ID to key type to key data. If a user's cross-signing + keys were not found, either their user ID will not be in the dict, or + their user ID will map to None. 
""" - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_bare_e2e_cross_signing_keys_bulk", self._get_bare_e2e_cross_signing_keys_bulk_txn, user_ids, @@ -588,7 +591,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + async def set_e2e_device_keys( + self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict + ) -> bool: """Stores device keys for a device. Returns whether there was a change or the keys were already in the database. """ @@ -624,12 +629,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"message": "Device keys stored."}) return True - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "set_e2e_device_keys", _set_e2e_device_keys_txn ) - def claim_e2e_one_time_keys(self, query_list): - """Take a list of one time keys out of the database""" + async def claim_e2e_one_time_keys( + self, query_list: Iterable[Tuple[str, str, str]] + ) -> Dict[str, Dict[str, Dict[str, bytes]]]: + """Take a list of one time keys out of the database. + + Args: + query_list: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A map of user ID -> a map device ID -> a map of key ID -> JSON bytes. + """ @trace def _claim_e2e_one_time_keys(txn): @@ -665,11 +679,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) return result - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - def delete_e2e_keys_by_device(self, user_id, device_id): + async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: def delete_e2e_keys_by_device_txn(txn): log_kv( { @@ -692,7 +706,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 3919ecad69..86557d5512 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="mark_local_media_as_safe", ) - def get_url_cache(self, url, ts): + async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]: """Get the media_id and ts for a cached URL as of the given timestamp Returns: None if the URL isn't cached. 
@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) ) - return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) + return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) async def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts @@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_cached_remote_media", ) - def update_cached_last_access_time(self, local_media, remote_media, time_ms): + async def update_cached_last_access_time( + self, + local_media: Iterable[str], + remote_media: Iterable[Tuple[str, str]], + time_ms: int, + ): """Updates the last access time of the given media Args: - local_media (iterable[str]): Set of media_ids - remote_media (iterable[(str, str)]): Set of (server_name, media_id) + local_media: Set of media_ids + remote_media: Set of (server_name, media_id) time_ms: Current time in milliseconds """ @@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn ) @@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts ) - def delete_remote_media(self, media_origin, media_id): + async def delete_remote_media(self, media_origin: str, media_id: str) -> None: def delete_remote_media_txn(txn): self.db_pool.simple_delete_txn( txn, @@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_remote_media", delete_remote_media_txn ) - def get_expired_url_cache(self, now_ts): + async def get_expired_url_cache(self, now_ts: int) -> List[str]: sql = ( "SELECT media_id FROM local_media_repository_url_cache" " WHERE expires_ts < ?" @@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute(sql, (now_ts,)) return [row[0] for row in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_expired_url_cache", _get_expired_url_cache_txn ) @@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "delete_url_cache", _delete_url_cache_txn ) - def get_url_cache_media_before(self, before_ts): + async def get_url_cache_media_before(self, before_ts: int) -> List[str]: sql = ( "SELECT media_id FROM local_media_repository" " WHERE created_ts < ? 
AND url_cache IS NOT NULL" @@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute(sql, (before_ts,)) return [row[0] for row in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_url_cache_media_before", _get_url_cache_media_before_txn ) diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 7f8d1880e5..f01cf2fd02 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -16,9 +16,10 @@ import logging import re from collections import namedtuple -from typing import List, Optional +from typing import List, Optional, Set from synapse.api.errors import SynapseError +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventRedactBehaviour @@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore): "count": count, } - def _find_highlights_in_postgres(self, search_query, events): + async def _find_highlights_in_postgres( + self, search_query: str, events: List[EventBase] + ) -> Set[str]: """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore): highlight the matching parts. Args: - search_query (str) - events (list): A list of events + search_query + events: A list of events Returns: - deferred : A set of strings. + A set of strings. """ def f(txn): @@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore): return highlight_words - return self.db_pool.runInteraction("_find_highlights", f) + return await self.db_pool.runInteraction("_find_highlights", f) def _to_postgres_options(options_dict): diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index a9f2e93614..1e96ae7828 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -15,7 +15,7 @@ import logging import re -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterable, Optional, Tuple from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool @@ -365,7 +365,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return False - def update_profile_in_user_dir(self, user_id, display_name, avatar_url): + async def update_profile_in_user_dir( + self, user_id: str, display_name: str, avatar_url: str + ) -> None: """ Update or add a user's profile in the user directory. """ @@ -458,17 +460,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_profile_in_user_dir", _update_profile_in_user_dir_txn ) - def add_users_who_share_private_room(self, room_id, user_id_tuples): + async def add_users_who_share_private_room( + self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]] + ) -> None: """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. Args: - room_id (str) - user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + room_id + user_id_tuples: iterable of 2-tuple of user IDs. 
""" def _add_users_who_share_room_txn(txn): @@ -484,17 +488,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): value_values=None, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_users_who_share_room", _add_users_who_share_room_txn ) - def add_users_in_public_rooms(self, room_id, user_ids): + async def add_users_in_public_rooms( + self, room_id: str, user_ids: Iterable[str] + ) -> None: """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. Args: - room_id (str) - user_ids (list[str]) + room_id + user_ids """ def _add_users_in_public_rooms_txn(txn): @@ -508,11 +514,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): value_values=None, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_users_in_public_rooms", _add_users_in_public_rooms_txn ) - def delete_all_from_user_dir(self): + async def delete_all_from_user_dir(self) -> None: """Delete the entire user directory """ @@ -523,7 +529,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): txn.execute("DELETE FROM users_who_share_private_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) @@ -555,7 +561,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(UserDirectoryStore, self).__init__(database, db_conn, hs) - def remove_from_user_dir(self, user_id): + async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): self.db_pool.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} @@ -578,7 +584,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_from_user_dir", _remove_from_user_dir_txn ) @@ -605,14 +611,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return user_ids - def remove_user_who_share_room(self, user_id, room_id): + async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None: """ Deletes entries in the users_who_share_*_rooms table. The first user should be a local user. Args: - user_id (str) - room_id (str) + user_id + room_id """ def _remove_user_who_share_room_txn(txn): @@ -632,7 +638,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): keyvalues={"user_id": user_id, "room_id": room_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_user_who_share_room", _remove_user_who_share_room_txn ) From b5133dd97f693ca213b30f4f3e874e9ab3958ea7 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 1 Sep 2020 16:31:59 +0100 Subject: [PATCH 16/31] Explain better what GDPR-erased means (#8189) Fixes https://github.com/matrix-org/synapse/issues/8185 --- changelog.d/8189.doc | 1 + docs/admin_api/user_admin_api.rst | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8189.doc diff --git a/changelog.d/8189.doc b/changelog.d/8189.doc new file mode 100644 index 0000000000..800ff89dc5 --- /dev/null +++ b/changelog.d/8189.doc @@ -0,0 +1 @@ +Explain better what GDPR-erased means when deactivating a user. 
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index d6e3194cda..e21c78a9c6 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -214,9 +214,11 @@ Deactivate Account This API deactivates an account. It removes active access tokens, resets the password, and deletes third-party IDs (to prevent the user requesting a -password reset). It can also mark the user as GDPR-erased (stopping their data -from distributed further, and deleting it entirely if there are no other -references to it). +password reset). + +It can also mark the user as GDPR-erased. This means messages sent by the +user will still be visible by anyone that was in the room when these messages +were sent, but hidden from users joining the room afterwards. The api is:: From b939251c37d748a4be6346eb27bd5fdfaff17738 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 13:02:41 -0400 Subject: [PATCH 17/31] Fix errors when updating the user directory with invalid data (#8223) --- changelog.d/8223.bugfix | 1 + synapse/handlers/profile.py | 6 ++++++ synapse/handlers/user_directory.py | 8 +++++++- synapse/storage/databases/main/user_directory.py | 5 +++++ 4 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8223.bugfix diff --git a/changelog.d/8223.bugfix b/changelog.d/8223.bugfix new file mode 100644 index 0000000000..60655ce3e1 --- /dev/null +++ b/changelog.d/8223.bugfix @@ -0,0 +1 @@ +Fixes a longstanding bug where user directory updates could break when unexpected profile data was included in events. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 96c9d6bab4..0cb8fad89a 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -161,6 +161,9 @@ class BaseProfileHandler(BaseHandler): Codes.FORBIDDEN, ) + if not isinstance(new_displayname, str): + raise SynapseError(400, "Invalid displayname") + if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -235,6 +238,9 @@ class BaseProfileHandler(BaseHandler): 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN ) + if not isinstance(new_avatar_url, str): + raise SynapseError(400, "Invalid displayname") + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 521b6d620d..e21f8dbc58 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -234,7 +234,7 @@ class UserDirectoryHandler(StateDeltasHandler): async def _handle_room_publicity_change( self, room_id, prev_event_id, event_id, typ ): - """Handle a room having potentially changed from/to world_readable/publically + """Handle a room having potentially changed from/to world_readable/publicly joinable. Args: @@ -388,9 +388,15 @@ class UserDirectoryHandler(StateDeltasHandler): prev_name = prev_event.content.get("displayname") new_name = event.content.get("displayname") + # If the new name is an unexpected form, do not update the directory. + if not isinstance(new_name, str): + new_name = prev_name prev_avatar = prev_event.content.get("avatar_url") new_avatar = event.content.get("avatar_url") + # If the new avatar is an unexpected form, do not update the directory. 
+ if not isinstance(new_avatar, str): + new_avatar = prev_avatar if prev_name != new_name or prev_avatar != new_avatar: await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 1e96ae7828..c977db042e 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -371,6 +371,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): """ Update or add a user's profile in the user directory. """ + # If the display name or avatar URL are unexpected types, overwrite them. + if not isinstance(display_name, str): + display_name = None + if not isinstance(avatar_url, str): + avatar_url = None def _update_profile_in_user_dir_txn(txn): new_entry = self.db_pool.simple_upsert_txn( From abeab964d5a0cc41fa421cf9d89dc12b7a796391 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 2 Sep 2020 11:47:26 +0100 Subject: [PATCH 18/31] Make _get_e2e_device_keys_and_signatures_txn return an attrs (#8224) this makes it a bit clearer what's going on. --- changelog.d/8224.misc | 1 + synapse/storage/databases/main/devices.py | 8 +-- .../storage/databases/main/end_to_end_keys.py | 52 +++++++++++++------ 3 files changed, 41 insertions(+), 20 deletions(-) create mode 100644 changelog.d/8224.misc diff --git a/changelog.d/8224.misc b/changelog.d/8224.misc new file mode 100644 index 0000000000..979c8b227b --- /dev/null +++ b/changelog.d/8224.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 11956cc48e..8bedcdbdff 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -293,17 +293,17 @@ class DeviceWorkerStore(SQLBaseStore): prev_id = stream_id if device is not None: - key_json = device.get("key_json", None) + key_json = device.key_json if key_json: result["keys"] = db_to_json(key_json) - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): + if device.signatures: + for sig_user_id, sigs in device.signatures.items(): result["keys"].setdefault("signatures", {}).setdefault( sig_user_id, {} ).update(sigs) - device_display_name = device.get("device_display_name", None) + device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name else: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 5a7de44b33..449d95f31e 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -17,6 +17,7 @@ import abc from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +import attr from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection @@ -33,6 +34,21 @@ if TYPE_CHECKING: from synapse.handlers.e2e_keys import SignatureListItem +@attr.s +class DeviceKeyLookupResult: + """The type returned by _get_e2e_device_keys_and_signatures_txn""" + + display_name = attr.ib(type=Optional[str]) + + # the key data from e2e_device_keys_json. 
Typically includes fields like + # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing + # key) and "signatures" (a signature of the structure by the ed25519 key) + key_json = attr.ib(type=Optional[str]) + + # cross-signing sigs + signatures = attr.ib(type=Optional[Dict], default=None) + + class EndToEndKeyWorkerStore(SQLBaseStore): async def get_e2e_device_keys_for_federation_query( self, user_id: str @@ -61,17 +77,17 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for device_id, device in user_devices.items(): result = {"device_id": device_id} - key_json = device.get("key_json", None) + key_json = device.key_json if key_json: result["keys"] = db_to_json(key_json) - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): + if device.signatures: + for sig_user_id, sigs in device.signatures.items(): result["keys"].setdefault("signatures", {}).setdefault( sig_user_id, {} ).update(sigs) - device_display_name = device.get("device_display_name", None) + device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name @@ -109,13 +125,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for user_id, device_keys in results.items(): rv[user_id] = {} for device_id, device_info in device_keys.items(): - r = db_to_json(device_info.pop("key_json")) + r = db_to_json(device_info.key_json) r["unsigned"] = {} - display_name = device_info["device_display_name"] + display_name = device_info.display_name if display_name is not None: r["unsigned"]["device_display_name"] = display_name - if "signatures" in device_info: - for sig_user_id, sigs in device_info["signatures"].items(): + if device_info.signatures: + for sig_user_id, sigs in device_info.signatures.items(): r.setdefault("signatures", {}).setdefault( sig_user_id, {} ).update(sigs) @@ -126,7 +142,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): @trace def _get_e2e_device_keys_and_signatures_txn( self, txn, query_list, include_all_devices=False, include_deleted_devices=False - ) -> Dict[str, Dict[str, Optional[Dict]]]: + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: set_tag("include_all_devices", include_all_devices) set_tag("include_deleted_devices", include_deleted_devices) @@ -161,7 +177,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): sql = ( "SELECT user_id, device_id, " - " d.display_name AS device_display_name, " + " d.display_name, " " k.key_json" " FROM devices d" " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" @@ -172,13 +188,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) txn.execute(sql, query_params) - rows = self.db_pool.cursor_to_dict(txn) - result = {} - for row in rows: + result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] + for (user_id, device_id, display_name, key_json) in txn: if include_deleted_devices: - deleted_devices.remove((row["user_id"], row["device_id"])) - result.setdefault(row["user_id"], {})[row["device_id"]] = row + deleted_devices.remove((user_id, device_id)) + result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult( + display_name, key_json + ) if include_deleted_devices: for user_id, device_id in deleted_devices: @@ -209,7 +226,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # note that target_device_result will be None for deleted devices. 
continue - target_device_signatures = target_device_result.setdefault("signatures", {}) + target_device_signatures = target_device_result.signatures + if target_device_signatures is None: + target_device_signatures = target_device_result.signatures = {} + signing_user_signatures = target_device_signatures.setdefault( signing_user_id, {} ) From d250521cf5171a2c815b380f184f39cd8b1963e5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Sep 2020 07:44:50 -0400 Subject: [PATCH 19/31] Convert the main methods run by the reactor to async. (#8213) --- changelog.d/8213.misc | 1 + synapse/app/admin_cmd.py | 20 +++++++++----------- synapse/app/homeserver.py | 18 ++++++++---------- 3 files changed, 18 insertions(+), 21 deletions(-) create mode 100644 changelog.d/8213.misc diff --git a/changelog.d/8213.misc b/changelog.d/8213.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8213.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index a37818fe9a..b6c9085670 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -79,8 +79,7 @@ class AdminCmdServer(HomeServer): pass -@defer.inlineCallbacks -def export_data_command(hs, args): +async def export_data_command(hs, args): """Export data for a user. Args: @@ -91,10 +90,8 @@ def export_data_command(hs, args): user_id = args.user_id directory = args.output_directory - res = yield defer.ensureDeferred( - hs.get_handlers().admin_handler.export_user_data( - user_id, FileExfiltrationWriter(user_id, directory=directory) - ) + res = await hs.get_handlers().admin_handler.export_user_data( + user_id, FileExfiltrationWriter(user_id, directory=directory) ) print(res) @@ -232,14 +229,15 @@ def start(config_options): # We also make sure that `_base.start` gets run before we actually run the # command. - @defer.inlineCallbacks - def run(_reactor): + async def run(): with LoggingContext("command"): - yield _base.start(ss, []) - yield args.func(ss, args) + _base.start(ss, []) + await args.func(ss, args) _base.start_worker_reactor( - "synapse-admin-cmd", config, run_command=lambda: task.react(run) + "synapse-admin-cmd", + config, + run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(run())), ) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 98d0d14a12..6014adc850 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -411,26 +411,24 @@ def setup(config_options): return provision - @defer.inlineCallbacks - def reprovision_acme(): + async def reprovision_acme(): """ Provision a certificate from ACME, if required, and reload the TLS certificate if it's renewed. """ - reprovisioned = yield defer.ensureDeferred(do_acme()) + reprovisioned = await do_acme() if reprovisioned: _base.refresh_certificate(hs) - @defer.inlineCallbacks - def start(): + async def start(): try: # Run the ACME provisioning code, if it's enabled. if hs.config.acme_enabled: acme = hs.get_acme_handler() # Start up the webservices which we will respond to ACME # challenges with, and then provision. - yield defer.ensureDeferred(acme.start_listening()) - yield defer.ensureDeferred(do_acme()) + await acme.start_listening() + await do_acme() # Check if it needs to be reprovisioned every day. 
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) @@ -439,8 +437,8 @@ def setup(config_options): if hs.config.oidc_enabled: oidc = hs.get_oidc_handler() # Loading the provider metadata also ensures the provider config is valid. - yield defer.ensureDeferred(oidc.load_metadata()) - yield defer.ensureDeferred(oidc.load_jwks()) + await oidc.load_metadata() + await oidc.load_jwks() _base.start(hs, config.listeners) @@ -456,7 +454,7 @@ def setup(config_options): reactor.stop() sys.exit(1) - reactor.callWhenRunning(start) + reactor.callWhenRunning(lambda: defer.ensureDeferred(start())) return hs From 9356656e67ba6ed23f8f43d03fedaf559241aa84 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Sep 2020 07:59:39 -0400 Subject: [PATCH 20/31] Do not try to store invalid data in the stats table (#8226) --- changelog.d/8226.bugfix | 1 + synapse/storage/databases/main/stats.py | 36 ++++++++++++++++++------- 2 files changed, 28 insertions(+), 9 deletions(-) create mode 100644 changelog.d/8226.bugfix diff --git a/changelog.d/8226.bugfix b/changelog.d/8226.bugfix new file mode 100644 index 0000000000..60bdff576d --- /dev/null +++ b/changelog.d/8226.bugfix @@ -0,0 +1 @@ +Fix a longstanding bug where stats updates could break when unexpected profile data was included in events. diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 9b9bc304a8..55a250ef06 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -224,14 +224,32 @@ class StatsStore(StateDeltasStore): ) async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None: - """ - Args: - room_id - fields - """ + """Update the state of a room. - # For whatever reason some of the fields may contain null bytes, which - # postgres isn't a fan of, so we replace those fields with null. + fields can contain the following keys with string values: + * join_rules + * history_visibility + * encryption + * name + * topic + * avatar + * canonical_alias + + A is_federatable key can also be included with a boolean value. + + Args: + room_id: The room ID to update the state of. + fields: The fields to update. This can include a partial list of the + above fields to only update some room information. + """ + # Ensure that the values to update are valid, they should be strings and + # not contain any null bytes. + # + # Invalid data gets overwritten with null. + # + # Note that a missing value should not be overwritten (it keeps the + # previous value). + sentinel = object() for col in ( "join_rules", "history_visibility", @@ -241,8 +259,8 @@ class StatsStore(StateDeltasStore): "avatar", "canonical_alias", ): - field = fields.get(col) - if field and "\0" in field: + field = fields.get(col, sentinel) + if field is not sentinel and (not isinstance(field, str) or "\0" in field): fields[col] = None await self.db_pool.simple_upsert( From b257c788c0541b1116b65e007f47b4f3a1de7760 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Wed, 2 Sep 2020 13:18:40 +0100 Subject: [PATCH 21/31] Add /user/{user_id}/shared_rooms/ api (#7785) * Add shared_rooms api * Add changelog * Add . * Wrap response in {"rooms": } * linting * Add unstable_features key * Remove options from isort that aren't part of 5.x `-y` and `-rc` are now default behaviour and no longer exist. 
`dont-skip` is no longer required https://timothycrosley.github.io/isort/CHANGELOG/#500-penny-july-4-2020 * Update imports to make isort happy * Add changelog * Update tox.ini file with correct invocation * fix linting again for isort * Vendor prefix unstable API * Fix to match spec * import Codes * import Codes * Use FORBIDDEN * Update changelog.d/7785.feature Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> * Implement get_shared_rooms_for_users * a comma * trailing whitespace * Handle the easy feedback * Switch to using runInteraction * Add tests * Feedback * Seperate unstable endpoint from v2 * Add upgrade node * a line * Fix style by adding a blank line at EOF. * Update synapse/storage/databases/main/user_directory.py Co-authored-by: Tulir Asokan * Update synapse/storage/databases/main/user_directory.py Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> * Update UPGRADE.rst Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> * Fix UPGRADE/CHANGELOG unstable paths unstable unstable unstable Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Co-authored-by: Tulir Asokan Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Co-authored-by: Patrick Cloke Co-authored-by: Tulir Asokan --- UPGRADE.rst | 13 ++ changelog.d/7785.feature | 1 + docs/workers.md | 1 + synapse/rest/__init__.py | 4 + synapse/rest/client/v2_alpha/shared_rooms.py | 68 +++++++++ synapse/rest/client/versions.py | 2 + .../storage/databases/main/user_directory.py | 44 +++++- .../rest/client/v2_alpha/test_shared_rooms.py | 138 ++++++++++++++++++ 8 files changed, 270 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7785.feature create mode 100644 synapse/rest/client/v2_alpha/shared_rooms.py create mode 100644 tests/rest/client/v2_alpha/test_shared_rooms.py diff --git a/UPGRADE.rst b/UPGRADE.rst index 6492fa011f..77be1b2952 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -1,3 +1,16 @@ +Upgrading to v1.20.0 +==================== + +Shared rooms endpoint (MSC2666) +------------------------------- + +This release contains a new unstable endpoint `/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*` +for fetching rooms one user has in common with another. This feature requires the +`update_user_directory` config flag to be `True`. If you are you are using a `synapse.app.user_dir` +worker, requests to this endpoint must be handled by that worker. +See `docs/workers.md `_ for more details. + + Upgrading Synapse ================= diff --git a/changelog.d/7785.feature b/changelog.d/7785.feature new file mode 100644 index 0000000000..c7e51c9320 --- /dev/null +++ b/changelog.d/7785.feature @@ -0,0 +1 @@ +Add an endpoint to query your shared rooms with another user as an implementation of [MSC2666](https://github.com/matrix-org/matrix-doc/pull/2666). diff --git a/docs/workers.md b/docs/workers.md index bfec745897..7a8f5c89fc 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -380,6 +380,7 @@ Handles searches in the user directory. 
It can handle REST endpoints matching the following regular expressions: ^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$ + ^/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*$ When using this worker you must also set `update_user_directory: False` in the shared configuration file to stop the main synapse running background diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 46e458e95b..87f927890c 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import ( room_keys, room_upgrade_rest_servlet, sendtodevice, + shared_rooms, sync, tags, thirdparty, @@ -125,3 +126,6 @@ class ClientRestResource(JsonResource): synapse.rest.admin.register_servlets_for_client_rest_resource( hs, client_resource ) + + # unstable + shared_rooms.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py new file mode 100644 index 0000000000..2492634dac --- /dev/null +++ b/synapse/rest/client/v2_alpha/shared_rooms.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet +from synapse.types import UserID + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class UserSharedRoomsServlet(RestServlet): + """ + GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/uk.half-shot.msc2666/user/shared_rooms/(?P[^/]*)", + releases=(), # This is an unstable feature + ) + + def __init__(self, hs): + super(UserSharedRoomsServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.user_directory_active = hs.config.update_user_directory + + async def on_GET(self, request, user_id): + + if not self.user_directory_active: + raise SynapseError( + code=400, + msg="The user directory is disabled on this server. 
Cannot determine shared rooms.", + errcode=Codes.FORBIDDEN, + ) + + UserID.from_string(user_id) + + requester = await self.auth.get_user_by_req(request) + if user_id == requester.user.to_string(): + raise SynapseError( + code=400, + msg="You cannot request a list of shared rooms with yourself", + errcode=Codes.FORBIDDEN, + ) + rooms = await self.store.get_shared_rooms_for_users( + requester.user.to_string(), user_id + ) + + return 200, {"joined": list(rooms)} + + +def register_servlets(hs, http_server): + UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 0d668df0b6..24ac57f35d 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -60,6 +60,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.e2e_cross_signing": True, # Implements additional endpoints as described in MSC2432 "org.matrix.msc2432": True, + # Implements additional endpoints as described in MSC2666 + "uk.half-shot.msc2666": True, }, }, ) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index c977db042e..f2f9a5799a 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -15,7 +15,7 @@ import logging import re -from typing import Any, Dict, Iterable, Optional, Tuple +from typing import Any, Dict, Iterable, Optional, Set, Tuple from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool @@ -675,6 +675,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) + @cached() + async def get_shared_rooms_for_users( + self, user_id: str, other_user_id: str + ) -> Set[str]: + """ + Returns the rooms that a local user shares with another local or remote user. + + Args: + user_id: The MXID of a local user + other_user_id: The MXID of the other user + + Returns: + A set of room ID's that the users share. + """ + + def _get_shared_rooms_for_users_txn(txn): + txn.execute( + """ + SELECT p1.room_id + FROM users_in_public_rooms as p1 + INNER JOIN users_in_public_rooms as p2 + ON p1.room_id = p2.room_id + AND p1.user_id = ? + AND p2.user_id = ? + UNION + SELECT room_id + FROM users_who_share_private_rooms + WHERE + user_id = ? + AND other_user_id = ? + """, + (user_id, other_user_id, user_id, other_user_id), + ) + rows = self.db_pool.cursor_to_dict(txn) + return rows + + rows = await self.db_pool.runInteraction( + "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn + ) + + return {row["room_id"] for row in rows} + async def get_user_directory_stream_pos(self) -> int: return await self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py new file mode 100644 index 0000000000..5ae72fd008 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import synapse.rest.admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import shared_rooms + +from tests import unittest + + +class UserSharedRoomsTest(unittest.HomeserverTestCase): + """ + Tests the UserSharedRoomsServlet. + """ + + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + shared_rooms.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["update_user_directory"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + def _get_shared_rooms(self, token, other_user): + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" + % other_user, + access_token=token, + ) + self.render(request) + return request, channel + + def test_shared_room_list_public(self): + """ + A room should show up in the shared list of rooms between two users + if it is public. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room) + + def test_shared_room_list_private(self): + """ + A room should show up in the shared list of rooms between two users + if it is private. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room) + + def test_shared_room_list_mixed(self): + """ + The shared room list between two users should contain both public and private + rooms. 
+ """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token) + self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token) + self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token) + self.helper.join(room_public, user=u2, tok=u2_token) + self.helper.join(room_private, user=u1, tok=u1_token) + + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 2) + self.assertTrue(room_public in channel.json_body["joined"]) + self.assertTrue(room_private in channel.json_body["joined"]) + + def test_shared_room_list_after_leave(self): + """ + A room should no longer be considered shared if the other + user has left it. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Assert user directory is not empty + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room) + + self.helper.leave(room, user=u1, tok=u1_token) + + request, channel = self._get_shared_rooms(u2_token, u1) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 0) From 82c1ee1c22a87b9e6e3179947014b0f11c0a1ac3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Sep 2020 15:48:37 +0100 Subject: [PATCH 22/31] Add experimental support for sharding event persister. (#8170) This is *not* ready for production yet. Caveats: 1. We should write some tests... 2. The stream token that we use for events can get stalled at the minimum position of all writers. This means that new events may not be processed and e.g. sent down sync streams if a writer isn't writing or is slow. 
--- changelog.d/8170.feature | 1 + synapse/config/_base.py | 21 +++++- synapse/config/_base.pyi | 1 + synapse/config/workers.py | 37 ++++++++--- synapse/handlers/federation.py | 44 +++++++++---- synapse/handlers/message.py | 14 ++-- synapse/handlers/room.py | 14 ++-- synapse/handlers/room_member.py | 7 -- synapse/replication/http/federation.py | 12 +++- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/streams/events.py | 4 +- synapse/storage/databases/__init__.py | 2 +- .../databases/main/event_federation.py | 2 +- synapse/storage/databases/main/events.py | 4 +- .../storage/databases/main/events_worker.py | 66 +++++++++++++------ .../delta/58/14events_instance_name.sql | 16 +++++ .../58/14events_instance_name.sql.postgres | 26 ++++++++ synapse/storage/util/id_generators.py | 10 +-- 18 files changed, 206 insertions(+), 77 deletions(-) create mode 100644 changelog.d/8170.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres diff --git a/changelog.d/8170.feature b/changelog.d/8170.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8170.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1417487427..73f0717b0d 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -832,11 +832,26 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key. """ - - # If multiple instances are not defined we always return true. + # If multiple instances are not defined we always return true if not self.instances or len(self.instances) == 1: return True + return self.get_instance(key) == instance_name + + def get_instance(self, key: str) -> str: + """Get the instance responsible for handling the given key. + + Note: For things like federation sending the config for which instance + is sending is known only to the sender instance if there is only one. + Therefore `should_handle` should be used where possible. + """ + + if not self.instances: + return "master" + + if len(self.instances) == 1: + return self.instances[0] + # We shard by taking the hash, modulo it by the number of instances and # then checking whether this instance matches the instance at that # index. @@ -846,7 +861,7 @@ class ShardedWorkerHandlingConfig: dest_hash = sha256(key.encode("utf8")).digest() dest_int = int.from_bytes(dest_hash, byteorder="little") remainder = dest_int % (len(self.instances)) - return self.instances[remainder] == instance_name + return self.instances[remainder] __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index eb911e8f9f..b8faafa9bd 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... + def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c784a71508..f23e42cdf9 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -13,12 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import List, Union + import attr from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from .server import ListenerConfig, parse_listener_def +def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: + """Helper for allowing parsing a string or list of strings to a config + option expecting a list of strings. + """ + + if isinstance(obj, str): + return [obj] + return obj + + @attr.s class InstanceLocationConfig: """The host and port to talk to an instance via HTTP replication. @@ -33,11 +45,13 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instance that writes to the event and backfill streams. - events: The instance that writes to the typing stream. + events: The instances that write to the event and backfill streams. + typing: The instance that writes to the typing stream. """ - events = attr.ib(default="master", type=str) + events = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter + ) typing = attr.ib(default="master", type=str) @@ -105,15 +119,18 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events and typing also appears in + # Check that the configured writers for events and typing also appears in # `instance_map`. for stream in ("events", "typing"): - instance = getattr(self.writers, stream) - if instance != "master" and instance not in self.instance_map: - raise ConfigError( - "Instance %r is configured to write %s but does not appear in `instance_map` config." - % (instance, stream) - ) + instances = _instance_to_list_converter(getattr(self.writers, stream)) + for instance in instances: + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) + + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 16389a0dca..bd8efbb768 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -923,7 +923,8 @@ class FederationHandler(BaseHandler): ) ) - await self._handle_new_events(dest, ev_infos, backfilled=True) + if ev_infos: + await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1216,7 +1217,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, event_infos, + destination, room_id, event_infos, ) def _sanity_check_event(self, ev): @@ -1363,15 +1364,15 @@ class FederationHandler(BaseHandler): ) max_stream_id = await self._persist_auth_tree( - origin, auth_chain, state, event, room_version_obj + origin, room_id, auth_chain, state, event, room_version_obj ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. 
- # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.config.worker.writers.events, "events", max_stream_id + self.config.worker.events_shard_config.get_instance(room_id), + "events", + max_stream_id, ) # Check whether this room is the result of an upgrade of a room we already know @@ -1625,7 +1626,7 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + await self.persist_events_and_notify(event.room_id, [(event, context)]) return event @@ -1652,7 +1653,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify( + event.room_id, [(event, context)] + ) return event, stream_id @@ -1900,7 +1903,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - [(event, context)], backfilled=backfilled + event.room_id, [(event, context)], backfilled=backfilled ) except Exception: run_in_background( @@ -1913,6 +1916,7 @@ class FederationHandler(BaseHandler): async def _handle_new_events( self, origin: str, + room_id: str, event_infos: Iterable[_NewEventInfo], backfilled: bool = False, ) -> None: @@ -1944,6 +1948,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( + room_id, [ (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) @@ -1954,6 +1959,7 @@ class FederationHandler(BaseHandler): async def _persist_auth_tree( self, origin: str, + room_id: str, auth_events: List[EventBase], state: List[EventBase], event: EventBase, @@ -1968,6 +1974,7 @@ class FederationHandler(BaseHandler): Args: origin: Where the events came from + room_id, auth_events state event @@ -2042,17 +2049,20 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR await self.persist_events_and_notify( + room_id, [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ] + ], ) new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) - return await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify( + room_id, [(event, new_event_context)] + ) async def _prep_event( self, @@ -2903,6 +2913,7 @@ class FederationHandler(BaseHandler): async def persist_events_and_notify( self, + room_id: str, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> int: @@ -2910,14 +2921,19 @@ class FederationHandler(BaseHandler): necessary. Args: - event_and_contexts: + room_id: The room ID of events being persisted. + event_and_contexts: Sequence of events with their associated + context that should be persisted. All events must belong to + the same room. 
backfilled: Whether these events are a result of backfilling or not """ - if self.config.worker.writers.events != self._instance_name: + instance = self.config.worker.events_shard_config.get_instance(room_id) + if instance != self._instance_name: result = await self._send_events( - instance_name=self.config.worker.writers.events, + instance_name=instance, store=self.store, + room_id=room_id, event_and_contexts=event_and_contexts, backfilled=backfilled, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 72bb638167..6398774c02 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -376,9 +376,8 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._is_event_writer = ( - self.config.worker.writers.events == hs.get_instance_name() - ) + self._events_shard_config = self.config.worker.events_shard_config + self._instance_name = hs.get_instance_name() self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -904,9 +903,10 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. - if not self._is_event_writer: + writer_instance = self._events_shard_config.get_instance(event.room_id) + if writer_instance != self._instance_name: result = await self.send_event( - instance_name=self.config.worker.writers.events, + instance_name=writer_instance, event_id=event.event_id, store=self.store, requester=requester, @@ -974,7 +974,9 @@ class EventCreationHandler(object): This should only be run on the instance in charge of persisting events. """ - assert self._is_event_writer + assert self._events_shard_config.should_handle( + self._instance_name, event.room_id + ) if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9d5b1828df..55794c3057 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler): # Always wait for room creation to progate before returning await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", last_stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + last_stream_id, ) return result, last_stream_id @@ -1260,10 +1262,10 @@ class RoomShutdownHandler(object): # We now wait for the create room to come back in via replication so # that we can assume that all the joins/invites have propogated before # we try and auto join below. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(new_room_id), + "events", + stream_id, ) else: new_room_id = None @@ -1293,7 +1295,9 @@ class RoomShutdownHandler(object): # Wait for leave to come in over replication before trying to forget. 
await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + stream_id, ) await self.room_member_handler.forget(target_requester.user, room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a7962b0ada..e5cb3c2d5c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -82,13 +82,6 @@ class RoomMemberHandler(object): self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles - self._event_stream_writer_instance = hs.config.worker.writers.events - self._is_on_event_persistence_instance = ( - self._event_stream_writer_instance == hs.get_instance_name() - ) - if self._is_on_event_persistence_instance: - self.persist_event_storage = hs.get_storage().persistence - self._join_rate_limiter_local = Ratelimiter( clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 6b56315148..5c8be747e1 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - async def _serialize_payload(store, event_and_contexts, backfilled): + async def _serialize_payload(store, room_id, event_and_contexts, backfilled): """ Args: store + room_id (str) event_and_contexts (list[tuple[FrozenEvent, EventContext]]) backfilled (bool): Whether or not the events are the result of backfilling @@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): } ) - payload = {"events": event_payloads, "backfilled": backfilled} + payload = { + "events": event_payloads, + "backfilled": backfilled, + "room_id": room_id, + } return payload @@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) + room_id = content["room_id"] backfilled = content["backfilled"] event_payloads = content["events"] @@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) max_stream_id = await self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled + room_id, event_and_contexts, backfilled ) return 200, {"max_stream_id": max_stream_id} diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1c303f3a46..b323841f73 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -109,7 +109,7 @@ class ReplicationCommandHandler: if isinstance(stream, (EventsStream, BackfillStream)): # Only add EventStream and BackfillStream as a source on the # instance in charge of event persistence. 
- if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: self._streams_to_replicate.append(stream) continue diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 16c63ff4ec..3705618b4f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance +from ._base import Stream, StreamUpdateResult, Token """Handling of the 'events' replication stream @@ -117,7 +117,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - current_token_without_instance(self._store.get_current_events_token), + self._store._stream_id_gen.get_current_token_for_writer, self._update_function, ) diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 0ac854aee2..c73d54fb67 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -68,7 +68,7 @@ class Databases(object): # If we're on a process that can persist events also # instantiate a `PersistEventsStore` - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: persist_events = PersistEventsStore(hs, database, main) if "state" in database_config.databases: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 0b69aa6a94..4c3c162acf 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old") + raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) sql = """ SELECT event_id FROM stream_ordering_to_exterm diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6313b41eef..46b11e705b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -97,6 +97,7 @@ class PersistEventsStore: self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -108,7 +109,7 @@ class PersistEventsStore: # This should only exist on instances that are configured to write assert ( - hs.config.worker.writers.events == hs.get_instance_name() + hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" async def _persist_events_and_state_updates( @@ -800,6 +801,7 @@ class PersistEventsStore: table="events", values=[ { + "instance_name": self._instance_name, "stream_ordering": event.internal_metadata.stream_ordering, "topological_ordering": event.depth, "depth": event.depth, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7a73cc3d8..17f5997b89 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream from 
synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import Collection, get_domain_from_id
 from synapse.util.caches.descriptors import Cache, cached
 from synapse.util.iterutils import batch_iter
@@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
-        if hs.config.worker.writers.events == hs.get_instance_name():
-            # We are the process in charge of generating stream ids for events,
-            # so instantiate ID generators based on the database
-            self._stream_id_gen = StreamIdGenerator(
-                db_conn, "events", "stream_ordering",
+        if isinstance(database.engine, PostgresEngine):
+            # If we're using Postgres then we can use `MultiWriterIdGenerator`
+            # regardless of whether this process writes to the streams or not.
+            self._stream_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                instance_name=hs.get_instance_name(),
+                table="events",
+                instance_column="instance_name",
+                id_column="stream_ordering",
+                sequence_name="events_stream_seq",
             )
-            self._backfill_id_gen = StreamIdGenerator(
-                db_conn,
-                "events",
-                "stream_ordering",
-                step=-1,
-                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+            self._backfill_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                instance_name=hs.get_instance_name(),
+                table="events",
+                instance_column="instance_name",
+                id_column="stream_ordering",
+                sequence_name="events_backfill_stream_seq",
+                positive=False,
             )
         else:
-            # Another process is in charge of persisting events and generating
-            # stream IDs: rely on the replication streams to let us know which
-            # IDs we can process.
-            self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
-            self._backfill_id_gen = SlavedIdTracker(
-                db_conn, "events", "stream_ordering", step=-1
-            )
+            # We shouldn't be running in worker mode with SQLite, but it's useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer then we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.events:
+                self._stream_id_gen = StreamIdGenerator(
+                    db_conn, "events", "stream_ordering",
+                )
+                self._backfill_id_gen = StreamIdGenerator(
+                    db_conn,
+                    "events",
+                    "stream_ordering",
+                    step=-1,
+                    extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+                )
+            else:
+                self._stream_id_gen = SlavedIdTracker(
+                    db_conn, "events", "stream_ordering"
+                )
+                self._backfill_id_gen = SlavedIdTracker(
+                    db_conn, "events", "stream_ordering", step=-1
+                )
 
         self._get_event_cache = Cache(
             "*getEvent*",
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
new file mode 100644
index 0000000000..98ff76d709
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE events ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres new file mode 100644 index 0000000000..97c1e6a0c5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -0,0 +1,26 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS events_stream_seq; + +SELECT setval('events_stream_seq', ( + SELECT COALESCE(MAX(stream_ordering), 1) FROM events +)); + +CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; + +SELECT setval('events_backfill_stream_seq', ( + SELECT COALESCE(-MIN(stream_ordering), 1) FROM events +)); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9f3d23f0a5..8fd21c2bf8 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -231,8 +231,12 @@ class MultiWriterIdGenerator: # gaps should be relatively rare it's still worth doing the book keeping # that allows us to skip forwards when there are gapless runs of # positions. + # + # We start at 1 here as a) the first generated stream ID will be 2, and + # b) other parts of the code assume that stream IDs are strictly greater + # than 0. self._persisted_upto_position = ( - min(self._current_positions.values()) if self._current_positions else 0 + min(self._current_positions.values()) if self._current_positions else 1 ) self._known_persisted_positions = [] # type: List[int] @@ -362,9 +366,7 @@ class MultiWriterIdGenerator: equal to it have been successfully persisted. """ - # Currently we don't support this operation, as it's not obvious how to - # condense the stream positions of multiple writers into a single int. - raise NotImplementedError() + return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: """Returns the position of the given writer. 
From 0d4f614fdae1221bd1833d9cdeb5adc3850d3af9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 2 Sep 2020 15:53:26 +0100 Subject: [PATCH 23/31] Refactor `_get_e2e_device_keys_for_federation_query_txn` (#8225) We can use the existing `_get_e2e_device_keys_and_signatures_txn` instead of creating our own txn function --- changelog.d/8225.misc | 1 + .../storage/databases/main/end_to_end_keys.py | 17 ++++++----------- 2 files changed, 7 insertions(+), 11 deletions(-) create mode 100644 changelog.d/8225.misc diff --git a/changelog.d/8225.misc b/changelog.d/8225.misc new file mode 100644 index 0000000000..979c8b227b --- /dev/null +++ b/changelog.d/8225.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 449d95f31e..4059701cfd 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -24,7 +24,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import make_in_list_sql_clause from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -58,18 +58,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Returns: (stream_id, devices) """ - return await self.db_pool.runInteraction( - "get_e2e_device_keys_for_federation_query", - self._get_e2e_device_keys_for_federation_query_txn, - user_id, - ) - - def _get_e2e_device_keys_for_federation_query_txn( - self, txn: LoggingTransaction, user_id: str - ) -> Tuple[int, List[JsonDict]]: now_stream_id = self.get_device_stream_token() - devices = self._get_e2e_device_keys_and_signatures_txn(txn, [(user_id, None)]) + devices = await self.db_pool.runInteraction( + "get_e2e_device_keys_and_signatures_txn", + self._get_e2e_device_keys_and_signatures_txn, + [(user_id, None)], + ) if devices: user_devices = devices[user_id] From 5a1dd297c3ce105a7f516d9d9fe87b94b9d356c8 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 2 Sep 2020 17:19:37 +0100 Subject: [PATCH 24/31] Re-implement unread counts (again) (#8059) --- changelog.d/8059.feature | 1 + synapse/handlers/sync.py | 33 +-- synapse/push/bulk_push_rule_evaluator.py | 92 ++++--- synapse/push/push_tools.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 1 + .../databases/main/event_push_actions.py | 224 +++++++++++++----- synapse/storage/databases/main/events.py | 4 +- .../main/schema/delta/58/15unread_count.sql | 26 ++ synapse/storage/databases/main/stream.py | 21 +- .../replication/slave/storage/test_events.py | 10 +- tests/rest/client/v2_alpha/test_sync.py | 157 +++++++++++- tests/storage/test_event_push_actions.py | 8 +- 12 files changed, 457 insertions(+), 122 deletions(-) create mode 100644 changelog.d/8059.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/15unread_count.sql diff --git a/changelog.d/8059.feature b/changelog.d/8059.feature new file mode 100644 index 0000000000..feb02be234 --- /dev/null +++ b/changelog.d/8059.feature @@ -0,0 +1 @@ +Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654). 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9a86eb01c9..43fc40fc2f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -95,7 +95,12 @@ class TimelineBatch: __bool__ = __nonzero__ # python3 -@attr.s(slots=True, frozen=True) +# We can't freeze this class, because we need to update it after it's instantiated to +# update its unread count. This is because we calculate the unread count for a room only +# if there are updates for it, which we check after the instance has been created. +# This should not be a big deal because we update the notification counts afterwards as +# well anyway. +@attr.s(slots=True) class JoinedSyncResult: room_id = attr.ib(type=str) timeline = attr.ib(type=TimelineBatch) @@ -104,6 +109,7 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) + unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -931,7 +937,7 @@ class SyncHandler(object): async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig - ) -> Optional[Dict[str, str]]: + ) -> Dict[str, int]: with Measure(self.clock, "unread_notifs_for_room_id"): last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), @@ -939,15 +945,10 @@ class SyncHandler(object): receipt_type="m.read", ) - if last_unread_event_id: - notifs = await self.store.get_unread_event_push_actions_by_room_for_user( - room_id, sync_config.user.to_string(), last_unread_event_id - ) - return notifs - - # There is no new information in this period, so your notification - # count is whatever it was last time. 
- return None + notifs = await self.store.get_unread_event_push_actions_by_room_for_user( + room_id, sync_config.user.to_string(), last_unread_event_id + ) + return notifs async def generate_sync_result( self, @@ -1886,7 +1887,7 @@ class SyncHandler(object): ) if room_builder.rtype == "joined": - unread_notifications = {} # type: Dict[str, str] + unread_notifications = {} # type: Dict[str, int] room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1895,14 +1896,16 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, + unread_count=0, ) if room_sync or always_include: notifs = await self.unread_notifs_for_room_id(room_id, sync_config) - if notifs is not None: - unread_notifications["notification_count"] = notifs["notify_count"] - unread_notifications["highlight_count"] = notifs["highlight_count"] + unread_notifications["notification_count"] = notifs["notify_count"] + unread_notifications["highlight_count"] = notifs["highlight_count"] + + room_sync.unread_count = notifs["unread_count"] sync_result_builder.joined.append(room_sync) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index e7fcee0e87..e7fa02b78b 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -19,8 +19,10 @@ from collections import namedtuple from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.event_auth import get_user_power_level +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.util.async_helpers import Linearizer from synapse.util.caches import register_cache @@ -51,6 +53,48 @@ push_rules_delta_state_cache_metric = register_cache( ) +STATE_EVENT_TYPES_TO_MARK_UNREAD = { + EventTypes.Topic, + EventTypes.Name, + EventTypes.RoomAvatar, + EventTypes.Tombstone, +} + + +def _should_count_as_unread(event: EventBase, context: EventContext) -> bool: + # Exclude rejected and soft-failed events. + if context.rejected or event.internal_metadata.is_soft_failed(): + return False + + # Exclude notices. + if ( + not event.is_state() + and event.type == EventTypes.Message + and event.content.get("msgtype") == "m.notice" + ): + return False + + # Exclude edits. + relates_to = event.content.get("m.relates_to", {}) + if relates_to.get("rel_type") == RelationTypes.REPLACE: + return False + + # Mark events that have a non-empty string body as unread. + body = event.content.get("body") + if isinstance(body, str) and body: + return True + + # Mark some state events as unread. + if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: + return True + + # Mark encrypted events as unread. + if not event.is_state() and event.type == EventTypes.Encrypted: + return True + + return False + + class BulkPushRuleEvaluator(object): """Calculates the outcome of push rules for an event for all users in the room at once. @@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object): return pl_event.content if pl_event else {}, sender_level async def action_for_event_by_user(self, event, context) -> None: - """Given an event and context, evaluate the push rules and insert the - results into the event_push_actions_staging table. 
+ """Given an event and context, evaluate the push rules, check if the message + should increment the unread count, and insert the results into the + event_push_actions_staging table. """ + count_as_unread = _should_count_as_unread(event, context) + rules_by_user = await self._get_rules_for_event(event, context) actions_by_user = {} @@ -172,6 +219,8 @@ class BulkPushRuleEvaluator(object): if event.type == EventTypes.Member and event.state_key == uid: display_name = event.content.get("displayname", None) + actions_by_user[uid] = [] + for rule in rules: if "enabled" in rule and not rule["enabled"]: continue @@ -189,7 +238,9 @@ class BulkPushRuleEvaluator(object): # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) - await self.store.add_push_actions_to_staging(event.event_id, actions_by_user) + await self.store.add_push_actions_to_staging( + event.event_id, actions_by_user, count_as_unread, + ) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -369,8 +420,8 @@ class RulesForRoom(object): Args: ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets updated with any new rules. - member_event_ids (list): List of event ids for membership events that - have happened since the last time we filled rules_by_user + member_event_ids (dict): Dict of user id to event id for membership events + that have happened since the last time we filled rules_by_user state_group: The state group we are currently computing push rules for. Used when updating the cache. """ @@ -390,34 +441,19 @@ class RulesForRoom(object): if logger.isEnabledFor(logging.DEBUG): logger.debug("Found members %r: %r", self.room_id, members.values()) - interested_in_user_ids = { + user_ids = { user_id for user_id, membership in members.values() if membership == Membership.JOIN } - logger.debug("Joined: %r", interested_in_user_ids) + logger.debug("Joined: %r", user_ids) - if_users_with_pushers = await self.store.get_if_users_have_pushers( - interested_in_user_ids, on_invalidate=self.invalidate_all_cb - ) - - user_ids = { - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - } - - logger.debug("With pushers: %r", user_ids) - - users_with_receipts = await self.store.get_users_with_read_receipts_in_room( - self.room_id, on_invalidate=self.invalidate_all_cb - ) - - logger.debug("With receipts: %r", users_with_receipts) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in interested_in_user_ids: - user_ids.add(uid) + # Previously we only considered users with pushers or read receipts in that + # room. We can't do this anymore because we use push actions to calculate unread + # counts, which don't rely on the user having pushers or sent a read receipt into + # the room. Therefore we just need to filter for local users here. 
+ user_ids = list(filter(self.is_mine_id, user_ids)) rules_by_user = await self.store.bulk_get_push_rules( user_ids, on_invalidate=self.invalidate_all_cb diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index d0145666bf..f7a25571f3 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -36,7 +36,7 @@ async def get_badge_count(store, user_id): ) # return one badge count per conversation, as count per # message is so noisy as to be almost useless - badge += 1 if notifs["notify_count"] else 0 + badge += 1 if notifs["unread_count"] else 0 return badge diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 96488b131a..a0b00135e1 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -425,6 +425,7 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary + result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 384c39f4d7..001d06378d 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -15,7 +15,9 @@ # limitations under the License. import logging -from typing import Dict, List, Union +from typing import Dict, List, Optional, Tuple, Union + +import attr from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json @@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore): @cached(num_args=3, tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( - self, room_id, user_id, last_read_event_id - ): + self, room_id: str, user_id: str, last_read_event_id: Optional[str], + ) -> Dict[str, int]: + """Get the notification count, the highlight count and the unread message count + for a given user in a given room after the given read receipt. + + Note that this function assumes the user to be a current member of the room, + since it's either called by the sync handler to handle joined room entries, or by + the HTTP pusher to calculate the badge of unread joined rooms. + + Args: + room_id: The room to retrieve the counts in. + user_id: The user to retrieve the counts for. + last_read_event_id: The event associated with the latest read receipt for + this user in this room. None if no receipt for this user in this room. + + Returns + A dict containing the counts mentioned earlier in this docstring, + respectively under the keys "notify_count", "highlight_count" and + "unread_count". + """ return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, @@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _get_unread_counts_by_receipt_txn( - self, txn, room_id, user_id, last_read_event_id + self, txn, room_id, user_id, last_read_event_id, ): - sql = ( - "SELECT stream_ordering" - " FROM events" - " WHERE room_id = ? AND event_id = ?" 
- ) - txn.execute(sql, (room_id, last_read_event_id)) - results = txn.fetchall() - if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0} + stream_ordering = None - stream_ordering = results[0][0] + if last_read_event_id is not None: + stream_ordering = self.get_stream_id_for_event_txn( + txn, last_read_event_id, allow_none=True, + ) + + if stream_ordering is None: + # Either last_read_event_id is None, or it's an event we don't have (e.g. + # because it's been purged), in which case retrieve the stream ordering for + # the latest membership event from this user in this room (which we assume is + # a join). + event_id = self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="local_current_membership", + keyvalues={"room_id": room_id, "user_id": user_id}, + retcol="event_id", + ) + + stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) return self._get_unread_counts_by_pos_txn( txn, room_id, user_id, stream_ordering ) def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): - - # First get number of notifications. - # We don't need to put a notif=1 clause as all rows always have - # notif=1 sql = ( - "SELECT count(*)" + "SELECT" + " COUNT(CASE WHEN notif = 1 THEN 1 END)," + " COUNT(CASE WHEN highlight = 1 THEN 1 END)," + " COUNT(CASE WHEN unread = 1 THEN 1 END)" " FROM event_push_actions ea" - " WHERE" - " user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" + " WHERE user_id = ?" + " AND room_id = ?" + " AND stream_ordering > ?" ) txn.execute(sql, (user_id, room_id, stream_ordering)) row = txn.fetchone() - notify_count = row[0] if row else 0 + + (notif_count, highlight_count, unread_count) = (0, 0, 0) + + if row: + (notif_count, highlight_count, unread_count) = row txn.execute( """ - SELECT notif_count FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering > ? - """, + SELECT notif_count, unread_count FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering > ? + """, (room_id, user_id, stream_ordering), ) - rows = txn.fetchall() - if rows: - notify_count += rows[0][0] - - # Now get the number of highlights - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " highlight = 1" - " AND user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" - ) - - txn.execute(sql, (user_id, room_id, stream_ordering)) row = txn.fetchone() - highlight_count = row[0] if row else 0 - return {"notify_count": notify_count, "highlight_count": highlight_count} + if row: + notif_count += row[0] + unread_count += row[1] + + return { + "notify_count": notif_count, + "unread_count": unread_count, + "highlight_count": highlight_count, + } async def get_push_action_users_in_range( self, min_stream_ordering, max_stream_ordering @@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" 
+ " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -402,7 +428,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _get_if_maybe_push_in_range_for_user_txn(txn): sql = """ SELECT 1 FROM event_push_actions - WHERE user_id = ? AND stream_ordering > ? + WHERE user_id = ? AND stream_ordering > ? AND notif = 1 LIMIT 1 """ @@ -415,7 +441,10 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) async def add_push_actions_to_staging( - self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]] + self, + event_id: str, + user_id_actions: Dict[str, List[Union[dict, str]]], + count_as_unread: bool, ) -> None: """Add the push actions for the event to the push action staging area. @@ -423,21 +452,23 @@ class EventPushActionsWorkerStore(SQLBaseStore): event_id user_id_actions: A mapping of user_id to list of push actions, where an action can either be a string or dict. + count_as_unread: Whether this event should increment unread counts. """ - if not user_id_actions: return # This is a helper function for generating the necessary tuple that - # can be used to inert into the `event_push_actions_staging` table. + # can be used to insert into the `event_push_actions_staging` table. def _gen_entry(user_id, actions): is_highlight = 1 if _action_has_highlight(actions) else 0 + notif = 1 if "notify" in actions else 0 return ( event_id, # event_id column user_id, # user_id column _serialize_action(actions, is_highlight), # actions column - 1, # notif column + notif, # notif column is_highlight, # highlight column + int(count_as_unread), # unread column ) def _add_push_actions_to_staging_txn(txn): @@ -446,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): sql = """ INSERT INTO event_push_actions_staging - (event_id, user_id, actions, notif, highlight) - VALUES (?, ?, ?, ?, ?) + (event_id, user_id, actions, notif, highlight, unread) + VALUES (?, ?, ?, ?, ?, ?) """ txn.executemany( @@ -811,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, - coalesce(old.notif_count, 0) + upd.notif_count, + coalesce(old.%s, 0) + upd.cnt, upd.stream_ordering, old.user_id FROM ( - SELECT user_id, room_id, count(*) as notif_count, + SELECT user_id, room_id, count(*) as cnt, max(stream_ordering) as stream_ordering FROM event_push_actions WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0 + AND %s = 1 GROUP BY user_id, room_id ) AS upd LEFT JOIN event_push_summary AS old USING (user_id, room_id) """ - txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) - rows = txn.fetchall() + # First get the count of unread messages. + txn.execute( + sql % ("unread_count", "unread"), + (old_rotate_stream_ordering, rotate_to_stream_ordering), + ) - logger.info("Rotating notifications, handling %d rows", len(rows)) + # We need to merge results from the two requests (the one that retrieves the + # unread count and the one that retrieves the notifications count) into a single + # object because we might not have the same amount of rows in each of them. 
To do + # this, we use a dict indexed on the user ID and room ID to make it easier to + # populate. + summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary] + for row in txn: + summaries[(row[0], row[1])] = _EventPushSummary( + unread_count=row[2], + stream_ordering=row[3], + old_user_id=row[4], + notif_count=0, + ) + + # Then get the count of notifications. + txn.execute( + sql % ("notif_count", "notif"), + (old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + for row in txn: + if (row[0], row[1]) in summaries: + summaries[(row[0], row[1])].notif_count = row[2] + else: + # Because the rules on notifying are different than the rules on marking + # a message unread, we might end up with messages that notify but aren't + # marked unread, so we might not have a summary for this (user, room) + # tuple to complete. + summaries[(row[0], row[1])] = _EventPushSummary( + unread_count=0, + stream_ordering=row[3], + old_user_id=row[4], + notif_count=row[2], + ) + + logger.info("Rotating notifications, handling %d rows", len(summaries)) # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the @@ -838,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore): table="event_push_summary", values=[ { - "user_id": row[0], - "room_id": row[1], - "notif_count": row[2], - "stream_ordering": row[3], + "user_id": user_id, + "room_id": room_id, + "notif_count": summary.notif_count, + "unread_count": summary.unread_count, + "stream_ordering": summary.stream_ordering, } - for row in rows - if row[4] is None + for ((user_id, room_id), summary) in summaries.items() + if summary.old_user_id is None ], ) txn.executemany( """ - UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? + UPDATE event_push_summary + SET notif_count = ?, unread_count = ?, stream_ordering = ? WHERE user_id = ? AND room_id = ? """, - ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), + ( + ( + summary.notif_count, + summary.unread_count, + summary.stream_ordering, + user_id, + room_id, + ) + for ((user_id, room_id), summary) in summaries.items() + if summary.old_user_id is not None + ), ) txn.execute( @@ -879,3 +961,15 @@ def _action_has_highlight(actions): pass return False + + +@attr.s +class _EventPushSummary: + """Summary of pending event push actions for a given user in a given room. + Used in _rotate_notifs_before_txn to manipulate results from event_push_actions. + """ + + unread_count = attr.ib(type=int) + stream_ordering = attr.ib(type=int) + old_user_id = attr.ib(type=str) + notif_count = attr.ib(type=int) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 46b11e705b..b94fe7ac17 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1298,9 +1298,9 @@ class PersistEventsStore: sql = """ INSERT INTO event_push_actions ( room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight + topological_ordering, notif, highlight, unread ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread FROM event_push_actions_staging WHERE event_id = ? 
""" diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql new file mode 100644 index 0000000000..b451e8663a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql @@ -0,0 +1,26 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We're hijacking the push actions to store unread messages and unread counts (specified +-- in MSC2654) because doing otherwise would result in either performance issues or +-- reimplementing a consequent bit of the push actions. + +-- Add columns to event_push_actions and event_push_actions_staging to track unread +-- messages and calculate unread counts. +ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0; +ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0; + +-- Add column to event_push_summary +ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0; \ No newline at end of file diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 24f44a7e36..83c1ddf95a 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -47,7 +47,11 @@ from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.types import RoomStreamToken @@ -593,8 +597,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A stream ID. 
""" - return await self.db_pool.simple_select_one_onecol( - table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" + return await self.db_pool.runInteraction( + "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id, + ) + + def get_stream_id_for_event_txn( + self, txn: LoggingTransaction, event_id: str, allow_none=False, + ) -> int: + return self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="events", + keyvalues={"event_id": event_id}, + retcol="stream_ordering", + allow_none=allow_none, ) async def get_stream_token_for_event(self, event_id: str) -> str: diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 0b5204654c..561258a356 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 0}, + {"highlight_count": 0, "unread_count": 0, "notify_count": 0}, ) self.persist( @@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 1}, + {"highlight_count": 0, "unread_count": 0, "notify_count": 1}, ) self.persist( @@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 1, "notify_count": 2}, + {"highlight_count": 1, "unread_count": 0, "notify_count": 2}, ) def test_get_rooms_for_user_with_stream_ordering(self): @@ -368,7 +368,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.get_success( self.master_store.add_push_actions_to_staging( - event.event_id, {user_id: actions for user_id, actions in push_actions} + event.event_id, + {user_id: actions for user_id, actions in push_actions}, + False, ) ) return event, context diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index fa3a3ec1bd..a31e44c97e 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v2_alpha import read_marker, sync from tests import unittest from tests.server import TimedOutException @@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) + + +class UnreadMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + read_marker.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to check the unread counts). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll check unread counts for. 
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user (used to send events to the room). + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Change the power levels of the room so that the second user can send state + # events. + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + { + "users": {self.user_id: 100, self.user2: 100}, + "users_default": 0, + "events": { + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.history_visibility": 100, + "m.room.canonical_alias": 50, + "m.room.avatar": 50, + "m.room.tombstone": 100, + "m.room.server_acl": 100, + "m.room.encryption": 100, + }, + "events_default": 0, + "state_default": 50, + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 0, + }, + tok=self.tok, + ) + + def test_unread_counts(self): + """Tests that /sync returns the right value for the unread count (MSC2654).""" + + # Check that our own messages don't increase the unread count. + self.helper.send(self.room_id, "hello", tok=self.tok) + self._check_unread_count(0) + + # Join the new user and check that this doesn't increase the unread count. + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + self._check_unread_count(0) + + # Check that the new user sending a message increases our unread count. + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + request, channel = self.make_request( + "POST", + "/rooms/%s/read_markers" % self.room_id, + body, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that room name changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, + ) + self._check_unread_count(1) + + # Check that room topic changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, + ) + self._check_unread_count(2) + + # Check that encrypted messages increase the unread counter. + self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) + self._check_unread_count(3) + + # Check that custom events with a body increase the unread counter. + self.helper.send_event( + self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that edits don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": "hello", + "msgtype": "m.text", + "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + }, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that notices don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"body": "hello", "msgtype": "m.notice"}, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that tombstone events changes increase the unread counter. 
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.Tombstone,
+            {"replacement_room": "!someroom:test"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(5)
+
+    def _check_unread_count(self, expected_count: int):
+        """Syncs and compares the unread count with the expected value."""
+
+        request, channel = self.make_request(
+            "GET", self.url % self.next_batch, access_token=self.tok,
+        )
+        self.render(request)
+
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        room_entry = channel.json_body["rooms"]["join"][self.room_id]
+        self.assertEqual(
+            room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+        )
+
+        # Store the next batch for the next request.
+        self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index cdfd2634aa..c0595963dd 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -67,7 +67,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         )
         self.assertEquals(
             counts,
-            {"notify_count": noitf_count, "highlight_count": highlight_count},
+            {
+                "notify_count": noitf_count,
+                "unread_count": 0,  # Unread counts are tested in the sync tests.
+                "highlight_count": highlight_count,
+            },
         )
 
     @defer.inlineCallbacks
@@ -80,7 +84,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
         yield defer.ensureDeferred(
             self.store.add_push_actions_to_staging(
-                event.event_id, {user_id: action}
+                event.event_id, {user_id: action}, False,
            )
         )
         yield defer.ensureDeferred(

From 112266eafd457204a34a76fa51d7074d0809a1db Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 2 Sep 2020 17:52:38 +0100
Subject: [PATCH 25/31] Add StreamStore to mypy (#8232)

---
 changelog.d/8232.misc                    |  1 +
 mypy.ini                                 |  1 +
 synapse/events/__init__.py               |  4 +--
 synapse/storage/database.py              | 34 ++++++++++++++++++
 synapse/storage/databases/main/stream.py | 46 ++++++++++++++----------
 5 files changed, 66 insertions(+), 20 deletions(-)
 create mode 100644 changelog.d/8232.misc

diff --git a/changelog.d/8232.misc b/changelog.d/8232.misc
new file mode 100644
index 0000000000..3a7a352c4f
--- /dev/null
+++ b/changelog.d/8232.misc
@@ -0,0 +1 @@
+Add type hints to `StreamStore`.
diff --git a/mypy.ini b/mypy.ini index 21c6f523a0..ae3290d5bb 100644 --- a/mypy.ini +++ b/mypy.ini @@ -43,6 +43,7 @@ files = synapse/server_notices, synapse/spam_checker_api, synapse/state, + synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, synapse/storage/database.py, synapse/storage/engines, diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 67db763dbf..62ea44fa49 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -18,7 +18,7 @@ import abc import os from distutils.util import strtobool -from typing import Dict, Optional, Type +from typing import Dict, Optional, Tuple, Type from unpaddedbase64 import encode_base64 @@ -120,7 +120,7 @@ class _EventInternalMetadata(object): # be here before = DictProperty("before") # type: str after = DictProperty("after") # type: str - order = DictProperty("order") # type: int + order = DictProperty("order") # type: Tuple[int, int] def get_dict(self) -> JsonDict: return dict(self._dict) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 7ab370efef..af8796ad92 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -604,6 +604,18 @@ class DatabasePool(object): results = [dict(zip(col_headers, row)) for row in cursor] return results + @overload + async def execute( + self, desc: str, decoder: Literal[None], query: str, *args: Any + ) -> List[Tuple[Any, ...]]: + ... + + @overload + async def execute( + self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any + ) -> R: + ... + async def execute( self, desc: str, @@ -1088,6 +1100,28 @@ class DatabasePool(object): desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) + @overload + async def simple_select_one_onecol( + self, + table: str, + keyvalues: Dict[str, Any], + retcol: Iterable[str], + allow_none: Literal[False] = False, + desc: str = "simple_select_one_onecol", + ) -> Any: + ... + + @overload + async def simple_select_one_onecol( + self, + table: str, + keyvalues: Dict[str, Any], + retcol: Iterable[str], + allow_none: Literal[True] = True, + desc: str = "simple_select_one_onecol", + ) -> Optional[Any]: + ... + async def simple_select_one_onecol( self, table: str, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 83c1ddf95a..be6df8a6d1 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -39,7 +39,7 @@ what sort order was used: import abc import logging from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from twisted.internet import defer @@ -54,9 +54,12 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.types import RoomStreamToken +from synapse.types import Collection, RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -206,7 +209,7 @@ def _make_generic_sql_bound( ) -def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]: +def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. 
it may create # "room_id == X AND room_id != X", which postgres doesn't optimise. @@ -264,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super(StreamWorkerStore, self).__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() @@ -297,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self._stream_order_on_start = self.get_room_max_stream_ordering() @abc.abstractmethod - def get_room_max_stream_ordering(self): + def get_room_max_stream_ordering(self) -> int: raise NotImplementedError() @abc.abstractmethod - def get_room_min_stream_ordering(self): + def get_room_min_stream_ordering(self) -> int: raise NotImplementedError() async def get_room_events_stream_for_rooms( self, - room_ids: Iterable[str], + room_ids: Collection[str], from_key: str, to_key: str, limit: int = 0, @@ -360,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return results - def get_rooms_that_changed(self, room_ids, from_key): + def get_rooms_that_changed( + self, room_ids: Collection[str], from_key: str + ) -> Set[str]: """Given a list of rooms and a token, return rooms where there may have been changes. Args: - room_ids (list) - from_key (str): The room_key portion of a StreamToken + room_ids + from_key: The room_key portion of a StreamToken """ - from_key = RoomStreamToken.parse_stream_token(from_key).stream + from_id = RoomStreamToken.parse_stream_token(from_key).stream return { room_id for room_id in room_ids - if self._events_stream_cache.has_entity_changed(room_id, from_key) + if self._events_stream_cache.has_entity_changed(room_id, from_id) } async def get_room_events_stream_for_room( @@ -444,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key - async def get_membership_changes_for_user(self, user_id, from_key, to_key): + async def get_membership_changes_for_user( + self, user_id: str, from_key: str, to_key: str + ) -> List[EventBase]: from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream @@ -661,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) return row[0][0] if row else 0 - def _get_max_topological_txn(self, txn, room_id): + def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int: txn.execute( "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?", (room_id,), @@ -734,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def _get_events_around_txn( self, - txn, + txn: LoggingTransaction, room_id: str, event_id: str, before_limit: int, @@ -762,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): retcols=["stream_ordering", "topological_ordering"], ) + # This cannot happen as `allow_none=False`. + assert results is not None + # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( @@ -871,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): desc="update_federation_out_pos", ) - def _reset_federation_positions_txn(self, txn) -> None: + def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None: """Fiddles with the `federation_stream_position` table to make it match the configured federation sender instances during start up. 
""" @@ -910,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): GROUP BY type """ txn.execute(sql) - min_positions = dict(txn) # Map from type -> min position + min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position # Ensure we do actually have some values here assert set(min_positions) == {"federation", "events"} @@ -937,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def _paginate_room_events_txn( self, - txn, + txn: LoggingTransaction, room_id: str, from_token: RoomStreamToken, to_token: Optional[RoomStreamToken] = None, From 912e024913bdb2b0360f5d2631a5345d3424d2e2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Sep 2020 13:11:02 -0400 Subject: [PATCH 26/31] Convert runInteraction to async/await (#8156) --- changelog.d/8156.misc | 1 + synapse/storage/database.py | 29 ++++++++++++++--------------- 2 files changed, 15 insertions(+), 15 deletions(-) create mode 100644 changelog.d/8156.misc diff --git a/changelog.d/8156.misc b/changelog.d/8156.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8156.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index af8796ad92..8851710d47 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -28,6 +28,7 @@ from typing import ( Optional, Tuple, TypeVar, + cast, overload, ) @@ -35,7 +36,6 @@ from prometheus_client import Histogram from typing_extensions import Literal from twisted.enterprise import adbapi -from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -507,8 +507,9 @@ class DatabasePool(object): self._txn_perf_counters.update(desc, duration) sql_txn_timer.labels(desc).observe(duration) - @defer.inlineCallbacks - def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any): + async def runInteraction( + self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any + ) -> R: """Starts a transaction on the database and runs a given function Arguments: @@ -521,7 +522,7 @@ class DatabasePool(object): kwargs: named args to pass to `func` Returns: - Deferred: The result of func + The result of func """ after_callbacks = [] # type: List[_CallbackListEntry] exception_callbacks = [] # type: List[_CallbackListEntry] @@ -530,16 +531,14 @@ class DatabasePool(object): logger.warning("Starting db txn '%s' from sentinel context", desc) try: - result = yield defer.ensureDeferred( - self.runWithConnection( - self.new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - **kwargs - ) + result = await self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + **kwargs ) for after_callback, after_args, after_kwargs in after_callbacks: @@ -549,7 +548,7 @@ class DatabasePool(object): after_callback(*after_args, **after_kwargs) raise - return result + return cast(R, result) async def runWithConnection( self, func: "Callable[..., R]", *args: Any, **kwargs: Any From c8758cb72fccc5594ce8da7ccd2256315a8aa27e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Sep 2020 15:03:12 -0400 Subject: [PATCH 27/31] Add an overload for simple_select_one_onecol_txn. 
(#8235) --- changelog.d/8235.misc | 1 + synapse/storage/database.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 changelog.d/8235.misc diff --git a/changelog.d/8235.misc b/changelog.d/8235.misc new file mode 100644 index 0000000000..3a7a352c4f --- /dev/null +++ b/changelog.d/8235.misc @@ -0,0 +1 @@ +Add type hints to `StreamStore`. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 8851710d47..78ca6d8346 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1149,6 +1149,30 @@ class DatabasePool(object): allow_none=allow_none, ) + @overload + @classmethod + def simple_select_one_onecol_txn( + cls, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcol: Iterable[str], + allow_none: Literal[False] = False, + ) -> Any: + ... + + @overload + @classmethod + def simple_select_one_onecol_txn( + cls, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcol: Iterable[str], + allow_none: Literal[True] = True, + ) -> Optional[Any]: + ... + @classmethod def simple_select_one_onecol_txn( cls, From 6f6f371a8731353dc1c3db1f20fc392f8b4e780d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 3 Sep 2020 11:50:49 +0100 Subject: [PATCH 28/31] wrap `_get_e2e_device_keys_and_signatures_txn` in a non-txn method (#8231) We have three things which all call `_get_e2e_device_keys_and_signatures_txn` with their own `runInteraction`. Factor out the common code. --- changelog.d/8231.misc | 1 + synapse/storage/databases/main/devices.py | 4 +- .../storage/databases/main/end_to_end_keys.py | 52 ++++++++++++++----- 3 files changed, 40 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8231.misc diff --git a/changelog.d/8231.misc b/changelog.d/8231.misc new file mode 100644 index 0000000000..979c8b227b --- /dev/null +++ b/changelog.d/8231.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. 
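
The `Literal`-typed `@overload` stubs added in PATCH 25 (for `simple_select_one_onecol`) and PATCH 27 (for `simple_select_one_onecol_txn`) are what let mypy know the result can only be `None` when the caller passes `allow_none=True`. A minimal, self-contained sketch of the pattern, illustrative only (the `select_one` function and the toy table below are hypothetical, not Synapse code):

    from typing import Optional, overload

    from typing_extensions import Literal

    _TOY_TABLE = {"$event:test": 42}  # stand-in for a database table

    @overload
    def select_one(key: str, allow_none: Literal[False] = False) -> int:
        ...

    @overload
    def select_one(key: str, allow_none: Literal[True] = True) -> Optional[int]:
        ...

    def select_one(key: str, allow_none: bool = False) -> Optional[int]:
        # Single runtime implementation; the overloads above exist purely for
        # the type checker.
        row = _TOY_TABLE.get(key)
        if row is None and not allow_none:
            raise KeyError(key)
        return row

    value = select_one("$event:test")  # mypy infers int
    maybe = select_one("$missing:test", allow_none=True)  # mypy infers Optional[int]

Because each overload pins `allow_none` to a literal value, callers that keep the default get a non-optional return type and need no spurious `None` checks.
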
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 8bedcdbdff..f8fe948122 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -255,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore): List of objects representing an device update EDU """ devices = ( - await self.db_pool.runInteraction( - "get_e2e_device_keys_and_signatures_txn", - self._get_e2e_device_keys_and_signatures_txn, + await self.get_e2e_device_keys_and_signatures( query_map.keys(), include_all_devices=True, include_deleted_devices=True, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 4059701cfd..cc0b15ae07 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: @attr.s class DeviceKeyLookupResult: - """The type returned by _get_e2e_device_keys_and_signatures_txn""" + """The type returned by get_e2e_device_keys_and_signatures""" display_name = attr.ib(type=Optional[str]) @@ -60,11 +60,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): """ now_stream_id = self.get_device_stream_token() - devices = await self.db_pool.runInteraction( - "get_e2e_device_keys_and_signatures_txn", - self._get_e2e_device_keys_and_signatures_txn, - [(user_id, None)], - ) + devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)]) if devices: user_devices = devices[user_id] @@ -108,11 +104,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): if not query_list: return {} - results = await self.db_pool.runInteraction( - "get_e2e_device_keys_and_signatures_txn", - self._get_e2e_device_keys_and_signatures_txn, - query_list, - ) + results = await self.get_e2e_device_keys_and_signatures(query_list) # Build the result structure, un-jsonify the results, and add the # "unsigned" section @@ -135,12 +127,45 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return rv @trace - def _get_e2e_device_keys_and_signatures_txn( - self, txn, query_list, include_all_devices=False, include_deleted_devices=False + async def get_e2e_device_keys_and_signatures( + self, + query_list: List[Tuple[str, Optional[str]]], + include_all_devices: bool = False, + include_deleted_devices: bool = False, ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + """Fetch a list of device keys, together with their cross-signatures. + + Args: + query_list: List of pairs of user_ids and device_ids. Device id can be None + to indicate "all devices for this user" + + include_all_devices: whether to return devices without device keys + + include_deleted_devices: whether to include null entries for + devices which no longer exist (but were in the query_list). + This option only takes effect if include_all_devices is true. + + Returns: + Dict mapping from user-id to dict mapping from device_id to + key data. 
+ """ set_tag("include_all_devices", include_all_devices) set_tag("include_deleted_devices", include_deleted_devices) + result = await self.db_pool.runInteraction( + "get_e2e_device_keys", + self._get_e2e_device_keys_and_signatures_txn, + query_list, + include_all_devices, + include_deleted_devices, + ) + + log_kv(result) + return result + + def _get_e2e_device_keys_and_signatures_txn( + self, txn, query_list, include_all_devices=False, include_deleted_devices=False + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: query_clauses = [] query_params = [] signature_query_clauses = [] @@ -230,7 +255,6 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) signing_user_signatures[signing_key_id] = signature - log_kv(result) return result async def get_e2e_one_time_keys( From 5bfc79486dbc9f51812a8d901bb1bb9fa2b98a81 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 3 Sep 2020 12:54:10 +0100 Subject: [PATCH 29/31] Fix typing for SyncHandler (#8237) --- changelog.d/8237.misc | 1 + synapse/handlers/sync.py | 12 +++++++----- synapse/storage/databases/main/roommember.py | 6 +++--- synapse/storage/databases/main/tags.py | 4 ++-- 4 files changed, 13 insertions(+), 10 deletions(-) create mode 100644 changelog.d/8237.misc diff --git a/changelog.d/8237.misc b/changelog.d/8237.misc new file mode 100644 index 0000000000..29d946cde6 --- /dev/null +++ b/changelog.d/8237.misc @@ -0,0 +1 @@ +Fix type hints in `SyncHandler`. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 43fc40fc2f..8728403e62 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -16,7 +16,7 @@ import itertools import logging -from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -44,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache from synapse.util.metrics import Measure, measure_func from synapse.visibility import filter_events_for_client +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # Debug logger for https://github.com/matrix-org/synapse/issues/4422 @@ -244,7 +247,7 @@ class SyncResult: class SyncHandler(object): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs_config = hs.config self.store = hs.get_datastore() self.notifier = hs.get_notifier() @@ -717,9 +720,8 @@ class SyncHandler(object): ] missing_hero_state = await self.store.get_events(missing_hero_event_ids) - missing_hero_state = missing_hero_state.values() - for s in missing_hero_state: + for s in missing_hero_state.values(): cache.set(s.state_key, s.event_id) state[(EventTypes.Member, s.state_key)] = s @@ -1771,7 +1773,7 @@ class SyncHandler(object): ignored_users: Set[str], room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], - tags: Optional[List[JsonDict]], + tags: Optional[Dict[str, Dict[str, Any]]], account_data: Dict[str, JsonDict], always_include: bool = False, ): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index c4ce5c17c2..c46f5cd524 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -298,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): return None async def get_rooms_for_local_user_where_membership_is( - self, user_id: str, membership_list: List[str] - ) -> Optional[List[RoomsForUser]]: + self, user_id: str, membership_list: 
Collection[str] + ) -> List[RoomsForUser]: """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. @@ -314,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): The RoomsForUser that the user matches the membership types. """ if not membership_list: - return None + return [] rooms = await self.db_pool.runInteraction( "get_rooms_for_local_user_where_membership_is", diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 0c34bbf21a..96ffe26cc9 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -43,7 +43,7 @@ class TagsWorkerStore(AccountDataWorkerStore): "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) - tags_by_room = {} + tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]] for row in rows: room_tags = tags_by_room.setdefault(row["room_id"], {}) room_tags[row["tag"]] = db_to_json(row["content"]) @@ -123,7 +123,7 @@ class TagsWorkerStore(AccountDataWorkerStore): async def get_updated_tags( self, user_id: str, stream_id: int - ) -> Dict[str, List[str]]: + ) -> Dict[str, Dict[str, JsonDict]]: """Get all the tags for the rooms where the tags have changed since the given version From 2aa127c20701320c5627b82d9fc71e84e02fd114 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 3 Sep 2020 09:45:36 -0400 Subject: [PATCH 30/31] Revert pinning of setuptools (#8239) --- INSTALL.md | 2 +- changelog.d/8212.bugfix | 1 - changelog.d/8239.misc | 1 + synapse/python_dependencies.py | 4 ---- 4 files changed, 2 insertions(+), 6 deletions(-) delete mode 100644 changelog.d/8212.bugfix create mode 100644 changelog.d/8239.misc diff --git a/INSTALL.md b/INSTALL.md index bdb7769fe9..22f7b7c029 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -73,7 +73,7 @@ mkdir -p ~/synapse virtualenv -p python3 ~/synapse/env source ~/synapse/env/bin/activate pip install --upgrade pip -pip install --upgrade setuptools!=50.0 # setuptools==50.0 fails on some older Python versions +pip install --upgrade setuptools pip install matrix-synapse ``` diff --git a/changelog.d/8212.bugfix b/changelog.d/8212.bugfix deleted file mode 100644 index 0f8c0aed92..0000000000 --- a/changelog.d/8212.bugfix +++ /dev/null @@ -1 +0,0 @@ -Do not install setuptools 50.0. It can lead to a broken configuration on some older Python versions. diff --git a/changelog.d/8239.misc b/changelog.d/8239.misc new file mode 100644 index 0000000000..88a3603e61 --- /dev/null +++ b/changelog.d/8239.misc @@ -0,0 +1 @@ +Revert pinning of setuptools. diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index d666f22674..2d995ec456 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -74,10 +74,6 @@ REQUIREMENTS = [ "Jinja2>=2.9", "bleach>=1.4.3", "typing-extensions>=3.7.4", - # setuptools is required by a variety of dependencies, unfortunately version - # 50.0 is incompatible with older Python versions, see - # https://github.com/pypa/setuptools/issues/2352 - "setuptools!=50.0", ] CONDITIONAL_REQUIREMENTS = { From 15c35c250c9e9aa5c0c7334d26344c76ba379a42 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 3 Sep 2020 09:47:41 -0400 Subject: [PATCH 31/31] Remove useless changelog about reverting a #8239. 
--- changelog.d/8239.misc | 1 - 1 file changed, 1 deletion(-) delete mode 100644 changelog.d/8239.misc diff --git a/changelog.d/8239.misc b/changelog.d/8239.misc deleted file mode 100644 index 88a3603e61..0000000000 --- a/changelog.d/8239.misc +++ /dev/null @@ -1 +0,0 @@ -Revert pinning of setuptools.
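
Taken together, the patches above continue the storage layer's migration from `Deferred`-returning methods to native coroutines: the public method becomes `async def`, awaits `db_pool.runInteraction(...)` (itself converted in PATCH 26), and the blocking work stays in a synchronous `_txn` helper. A rough sketch of the shape the converted stores converge on; the store class, method names and SQL here are hypothetical and only assume the `SQLBaseStore`/`LoggingTransaction` plumbing visible in the diffs:

    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import LoggingTransaction

    class ExampleWorkerStore(SQLBaseStore):
        async def get_example_ordering(self, event_id: str) -> int:
            # The public API is now a coroutine; callers simply `await` it.
            return await self.db_pool.runInteraction(
                "get_example_ordering", self._get_example_ordering_txn, event_id
            )

        def _get_example_ordering_txn(
            self, txn: LoggingTransaction, event_id: str
        ) -> int:
            # All blocking database access stays inside the synchronous
            # transaction function, which runs against a pooled connection.
            txn.execute(
                "SELECT stream_ordering FROM events WHERE event_id = ?", (event_id,)
            )
            row = txn.fetchone()
            return row[0] if row else 0

For callers the only visible change is writing `await store.get_example_ordering(...)` in place of yielding a `Deferred`, which is why most hunks in this series are mechanical signature and docstring updates.
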