Convert `event_push_actions`, `registration`, and `roommember` datastores to async (#8197)

pull/8207/head
Patrick Cloke 2020-08-28 11:34:50 -04:00 committed by GitHub
parent 22b926c284
commit d58fda99ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 169 additions and 160 deletions

1
changelog.d/8197.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import List
from typing import Dict, List, Union
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@ -383,19 +383,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Now return the first `limit`
return notifs[:limit]
def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
async def get_if_maybe_push_in_range_for_user(
self, user_id: str, min_stream_ordering: int
) -> bool:
"""A fast check to see if there might be something to push for the
user since the given stream ordering. May return false positives.
Useful to know whether to bother starting a pusher on start up or not.
Args:
user_id (str)
min_stream_ordering (int)
user_id
min_stream_ordering
Returns:
Deferred[bool]: True if there may be push to process, False if
there definitely isn't.
True if there may be push to process, False if there definitely isn't.
"""
def _get_if_maybe_push_in_range_for_user_txn(txn):
@ -408,22 +409,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
async def add_push_actions_to_staging(self, event_id, user_id_actions):
async def add_push_actions_to_staging(
self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]]
) -> None:
"""Add the push actions for the event to the push action staging area.
Args:
event_id (str)
user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
user_id to list of push actions, where an action can either be
a string or dict.
Returns:
Deferred
event_id
user_id_actions: A mapping of user_id to list of push actions, where
an action can either be a string or dict.
"""
if not user_id_actions:
@ -507,7 +506,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
)
def find_first_stream_ordering_after_ts(self, ts):
async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
"""Gets the stream ordering corresponding to a given timestamp.
Specifically, finds the stream_ordering of the first event that was
@ -516,13 +515,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
relatively slow.
Args:
ts (int): timestamp in millis
ts: timestamp in millis
Returns:
Deferred[int]: stream ordering of the first event received on/after
the timestamp
stream ordering of the first event received on/after the timestamp
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,

View File

@ -17,7 +17,7 @@
import logging
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@ -84,17 +84,17 @@ class RegistrationWorkerStore(SQLBaseStore):
return is_trial
@cached()
def get_user_by_access_token(self, token):
async def get_user_by_access_token(self, token: str) -> Optional[dict]:
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
token: The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
return bool(res) if res else False
def set_server_admin(self, user, admin):
async def set_server_admin(self, user: UserID, admin: bool) -> None:
"""Sets whether a user is an admin of this homeserver.
Args:
user (UserID): user ID of the user to test
admin (bool): true iff the user is to be a server admin,
false otherwise.
user: user ID of the user to test
admin: true iff the user is to be a server admin, false otherwise.
"""
def set_server_admin_txn(txn):
@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_user_by_id, (user.to_string(),)
)
return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return True if res == UserTypes.SUPPORT else False
def get_users_by_id_case_insensitive(self, user_id):
async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
Returns:
A mapping of user_id -> password_hash.
"""
def f(txn):
@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return dict(txn)
return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction("count_users", _count_users)
def count_daily_user_type(self):
async def count_daily_user_type(self) -> Dict[str, int]:
"""
Counts 1) native non guest users
2) native guests users
@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1]
return results
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_daily_user_type", _count_daily_user_type
)
@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
# Convert the integer into a boolean.
return res == 1
def get_threepid_validation_session(
self, medium, client_secret, address=None, sid=None, validated=True
):
async def get_threepid_validation_session(
self,
medium: Optional[str],
client_secret: str,
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
) -> Optional[Dict[str, Any]]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
Args:
medium (str|None): The medium of the 3PID
address (str|None): The address of the 3PID
sid (str|None): The ID of the validation session
client_secret (str): A unique string provided by the client to help identify this
medium: The medium of the 3PID
client_secret: A unique string provided by the client to help identify this
validation attempt
validated (bool|None): Whether sessions should be filtered by
address: The address of the 3PID
sid: The ID of the validation session
validated: Whether sessions should be filtered by
whether they have been validated already or not. None to
perform no filtering
Returns:
Deferred[dict|None]: A dict containing the following:
A dict containing the following:
* address - address of the 3pid
* medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session
@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
)
def delete_threepid_session(self, session_id):
async def delete_threepid_session(self, session_id: str) -> None:
"""Removes a threepid validation session from the database. This can
be done after validation has been performed and whatever action was
waiting on it has been carried out
Args:
session_id (str): The ID of the session to delete
session_id: The ID of the session to delete
"""
def delete_threepid_session_txn(txn):
@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
keyvalues={"session_id": session_id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_threepid_session", delete_threepid_session_txn
)
@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
def register_user(
async def register_user(
self,
user_id,
password_hash=None,
was_guest=False,
make_guest=False,
appservice_id=None,
create_profile_with_displayname=None,
admin=False,
user_type=None,
shadow_banned=False,
):
user_id: str,
password_hash: Optional[str] = None,
was_guest: bool = False,
make_guest: bool = False,
appservice_id: Optional[str] = None,
create_profile_with_displayname: Optional[str] = None,
admin: bool = False,
user_type: Optional[str] = None,
shadow_banned: bool = False,
) -> None:
"""Attempts to register an account.
Args:
user_id (str): The desired user ID to register.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str): The ID of the appservice registering the user.
create_profile_with_displayname (unicode): Optionally create a profile for
user_id: The desired user ID to register.
password_hash: Optional. The password hash for this user.
was_guest: Whether this is a guest account being upgraded to a
non-guest account.
make_guest: True if the the new user should be guest, false to add a
regular user account.
appservice_id: The ID of the appservice registering the user.
create_profile_with_displayname: Optionally create a profile for
the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
shadow_banned (bool): Whether the user is shadow-banned,
i.e. they may be told their requests succeeded but we ignore them.
admin: is an admin user?
user_type: type of user. One of the values from api.constants.UserTypes,
or None for a normal user.
shadow_banned: Whether the user is shadow-banned, i.e. they may be
told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
Returns:
Deferred
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"register_user",
self._register_user,
user_id,
@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
def user_set_password_hash(self, user_id, password_hash):
async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
def user_set_consent_version(self, user_id, consent_version):
async def user_set_consent_version(
self, user_id: str, consent_version: str
) -> None:
"""Updates the user table to record privacy policy consent
Args:
user_id (str): full mxid of the user to update
consent_version (str): version of the policy the user has consented
to
user_id: full mxid of the user to update
consent_version: version of the policy the user has consented to
Raises:
StoreError(404) if user not found
@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction("user_set_consent_version", f)
await self.db_pool.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
async def user_set_consent_server_notice_sent(
self, user_id: str, consent_version: str
) -> None:
"""Updates the user table to record that we have sent the user a server
notice about privacy policy consent
Args:
user_id (str): full mxid of the user to update
consent_version (str): version of the policy we have notified the
user about
user_id: full mxid of the user to update
consent_version: version of the policy we have notified the user about
Raises:
StoreError(404) if user not found
@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
async def user_delete_access_tokens(
self,
user_id: str,
except_token_id: Optional[str] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
except_token_id (str): list of access_tokens IDs which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
user_id: ID of user the tokens belong to
except_token_id: access_tokens ID which should *not* be deleted
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
defer.Deferred[list[str, int, str|None, int]]: a list of
(token, token id, device id) for each of the deleted tokens
A tuple of (token, token id, device id) for each of the deleted tokens
"""
def f(txn):
@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices
return self.db_pool.runInteraction("user_delete_access_tokens", f)
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
async def delete_access_token(self, access_token: str) -> None:
def f(txn):
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,)
)
return self.db_pool.runInteraction("delete_access_token", f)
await self.db_pool.runInteraction("delete_access_token", f)
@cached()
async def is_guest(self, user_id: str) -> bool:
@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="get_users_pending_deactivation",
)
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
async def validate_threepid_session(
self, session_id: str, client_secret: str, token: str, current_ts: int
) -> Optional[str]:
"""Attempt to validate a threepid session using a token
Args:
session_id (str): The id of a validation session
client_secret (str): A unique string provided by the client to
help identify this validation attempt
token (str): A validation token
current_ts (int): The current unix time in milliseconds. Used for
checking token expiry status
session_id: The id of a validation session
client_secret: A unique string provided by the client to help identify
this validation attempt
token: A validation token
current_ts: The current unix time in milliseconds. Used for checking
token expiry status
Raises:
ThreepidValidationError: if a matching validation token was not found or has
expired
Returns:
deferred str|None: A str representing a link to redirect the user
to if there is one.
A str representing a link to redirect the user to if there is one.
"""
# Insert everything into a transaction in order to run atomically
@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link
# Return next_link if it exists
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn
)
def start_or_continue_validation_session(
async def start_or_continue_validation_session(
self,
medium,
address,
session_id,
client_secret,
send_attempt,
next_link,
token,
token_expires,
):
medium: str,
address: str,
session_id: str,
client_secret: str,
send_attempt: int,
next_link: Optional[str],
token: str,
token_expires: int,
) -> None:
"""Creates a new threepid validation session if it does not already
exist and associates a new validation token with it
Args:
medium (str): The medium of the 3PID
address (str): The address of the 3PID
session_id (str): The id of this validation session
client_secret (str): A unique string provided by the client to
help identify this validation attempt
send_attempt (int): The latest send_attempt on this session
next_link (str|None): The link to redirect the user to upon
successful validation
token (str): The validation token
token_expires (int): The timestamp for which after the token
will no longer be valid
medium: The medium of the 3PID
address: The address of the 3PID
session_id: The id of this validation session
client_secret: A unique string provided by the client to help
identify this validation attempt
send_attempt: The latest send_attempt on this session
next_link: The link to redirect the user to upon successful validation
token: The validation token
token_expires: The timestamp for which after the token will no
longer be valid
"""
def start_or_continue_validation_session_txn(txn):
@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
def cull_expired_threepid_validation_tokens(self):
async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""
def cull_expired_threepid_validation_tokens_txn(txn, ts):
@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
DELETE FROM threepid_validation_token WHERE
expires < ?
"""
return txn.execute(sql, (ts,))
txn.execute(sql, (ts,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),

View File

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id: str):
return self.db_pool.runInteraction(
async def get_users_in_room(self, room_id: str) -> List[str]:
return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
@ -180,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn]
@cached(max_entries=100000)
def get_room_summary(self, room_id: str):
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
""" Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
room_id: The room ID to query
Returns:
Deferred[dict[str, MemberSummary]:
dict of membership states, pointing to a MemberSummary named tuple.
dict of membership states, pointing to a MemberSummary named tuple.
"""
def _get_room_summary_txn(txn):
@ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
return await self.db_pool.runInteraction(
"get_room_summary", _get_room_summary_txn
)
@cached()
def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
"""Get all the rooms the *local* user is invited to.
Args:
user_id: The user ID.
Returns:
A awaitable list of RoomsForUser.
A list of RoomsForUser.
"""
return self.get_rooms_for_local_user_where_membership_is(
return await self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
@ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(max_entries=500000, iterable=True)
def get_rooms_for_user_with_stream_ordering(self, user_id: str):
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
@ -367,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id
Returns:
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
the rooms the user is in currently, along with the stream ordering
of the most recent join for that user and room.
Returns the rooms the user is in currently, along with the stream
ordering of the most recent join for that user and room.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn,
user_id,
)
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
def _get_rooms_for_user_with_stream_ordering_txn(
self, txn, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
@ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (user_id, Membership.JOIN))
results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
return results
return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
@ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return count == 0
@cached()
def get_forgotten_rooms_for_user(self, user_id: str):
async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
"""Gets all rooms the user has forgotten.
Args:
user_id
user_id: The user ID to query the rooms of.
Returns:
Deferred[set[str]]
The forgotten rooms.
"""
def _get_forgotten_rooms_for_user_txn(txn):
@ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,))
return {row[0] for row in txn if row[1] == 0}
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
def forget(self, user_id: str, room_id: str):
async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.db_pool.runInteraction("forget_membership", f)
await self.db_pool.runInteraction("forget_membership", f)
class _JoinedHostsCache(object):