Converts event_federation and registration databases to async/await (#8061)
parent
61d8ff0d44
commit
a0acdfa9e9
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -15,9 +15,7 @@
|
|||
import itertools
|
||||
import logging
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
|
@ -286,17 +284,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
|
||||
return dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_max_depth_of(self, event_ids):
|
||||
async def get_max_depth_of(self, event_ids: List[str]) -> int:
|
||||
"""Returns the max depth of a set of event IDs
|
||||
|
||||
Args:
|
||||
event_ids (list[str])
|
||||
|
||||
Returns
|
||||
Deferred[int]
|
||||
event_ids: The event IDs to calculate the max depth of.
|
||||
"""
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
|
@ -550,9 +544,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
|
||||
return event_results
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_missing_events(self, room_id, earliest_events, latest_events, limit):
|
||||
ids = yield self.db_pool.runInteraction(
|
||||
async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
|
||||
ids = await self.db_pool.runInteraction(
|
||||
"get_missing_events",
|
||||
self._get_missing_events,
|
||||
room_id,
|
||||
|
@ -560,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
latest_events,
|
||||
limit,
|
||||
)
|
||||
events = yield self.get_events_as_list(ids)
|
||||
events = await self.get_events_as_list(ids)
|
||||
return events
|
||||
|
||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
|
||||
|
@ -595,17 +588,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
event_results.reverse()
|
||||
return event_results
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_successor_events(self, event_ids):
|
||||
async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]:
|
||||
"""Fetch all events that have the given events as a prev event
|
||||
|
||||
Args:
|
||||
event_ids (iterable[str])
|
||||
|
||||
Returns:
|
||||
Deferred[list[str]]
|
||||
event_ids: The events to use as the previous events.
|
||||
"""
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="event_edges",
|
||||
column="prev_event_id",
|
||||
iterable=event_ids,
|
||||
|
@ -674,8 +663,7 @@ class EventFederationStore(EventFederationWorkerStore):
|
|||
txn.execute(query, (room_id,))
|
||||
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_delete_non_state_event_auth(self, progress, batch_size):
|
||||
async def _background_delete_non_state_event_auth(self, progress, batch_size):
|
||||
def delete_event_auth(txn):
|
||||
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
|
||||
max_stream_id = progress.get("max_stream_id_exclusive")
|
||||
|
@ -714,12 +702,12 @@ class EventFederationStore(EventFederationWorkerStore):
|
|||
|
||||
return min_stream_id >= target_min_stream_id
|
||||
|
||||
result = yield self.db_pool.runInteraction(
|
||||
result = await self.db_pool.runInteraction(
|
||||
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
|
||||
)
|
||||
|
||||
if not result:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.EVENT_AUTH_STATE_ONLY
|
||||
)
|
||||
|
||||
|
|
|
@ -17,9 +17,8 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
|
@ -30,7 +29,7 @@ from synapse.storage.database import DatabasePool
|
|||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.sequence import build_sequence_generator
|
||||
from synapse.types import UserID
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
|
||||
|
||||
|
@ -69,19 +68,15 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
desc="get_user_by_id",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_trial_user(self, user_id):
|
||||
async def is_trial_user(self, user_id: str) -> bool:
|
||||
"""Checks if user is in the "trial" period, i.e. within the first
|
||||
N days of registration defined by `mau_trial_days` config
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
|
||||
Returns:
|
||||
Deferred[bool]
|
||||
user_id: The user to check for trial status.
|
||||
"""
|
||||
|
||||
info = yield self.get_user_by_id(user_id)
|
||||
info = await self.get_user_by_id(user_id)
|
||||
if not info:
|
||||
return False
|
||||
|
||||
|
@ -105,41 +100,42 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
"get_user_by_access_token", self._query_for_auth, token
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_expiration_ts_for_user(self, user_id):
|
||||
@cached()
|
||||
async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
|
||||
"""Get the expiration timestamp for the account bearing a given user ID.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user.
|
||||
user_id: The ID of the user.
|
||||
Returns:
|
||||
defer.Deferred: None, if the account has no expiration timestamp,
|
||||
otherwise int representation of the timestamp (as a number of
|
||||
milliseconds since epoch).
|
||||
None, if the account has no expiration timestamp, otherwise int
|
||||
representation of the timestamp (as a number of milliseconds since epoch).
|
||||
"""
|
||||
res = yield self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="account_validity",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="expiration_ts_ms",
|
||||
allow_none=True,
|
||||
desc="get_expiration_ts_for_user",
|
||||
)
|
||||
return res
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_account_validity_for_user(
|
||||
self, user_id, expiration_ts, email_sent, renewal_token=None
|
||||
):
|
||||
async def set_account_validity_for_user(
|
||||
self,
|
||||
user_id: str,
|
||||
expiration_ts: int,
|
||||
email_sent: bool,
|
||||
renewal_token: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Updates the account validity properties of the given account, with the
|
||||
given values.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the account to update properties for.
|
||||
expiration_ts (int): New expiration date, as a timestamp in milliseconds
|
||||
user_id: ID of the account to update properties for.
|
||||
expiration_ts: New expiration date, as a timestamp in milliseconds
|
||||
since epoch.
|
||||
email_sent (bool): True means a renewal email has been sent for this
|
||||
account and there's no need to send another one for the current validity
|
||||
email_sent: True means a renewal email has been sent for this account
|
||||
and there's no need to send another one for the current validity
|
||||
period.
|
||||
renewal_token (str): Renewal token the user can use to extend the validity
|
||||
renewal_token: Renewal token the user can use to extend the validity
|
||||
of their account. Defaults to no token.
|
||||
"""
|
||||
|
||||
|
@ -158,75 +154,69 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
txn, self.get_expiration_ts_for_user, (user_id,)
|
||||
)
|
||||
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"set_account_validity_for_user", set_account_validity_for_user_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_renewal_token_for_user(self, user_id, renewal_token):
|
||||
async def set_renewal_token_for_user(
|
||||
self, user_id: str, renewal_token: str
|
||||
) -> None:
|
||||
"""Defines a renewal token for a given user.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to set the renewal token for.
|
||||
renewal_token (str): Random unique string that will be used to renew the
|
||||
user_id: ID of the user to set the renewal token for.
|
||||
renewal_token: Random unique string that will be used to renew the
|
||||
user's account.
|
||||
|
||||
Raises:
|
||||
StoreError: The provided token is already set for another user.
|
||||
"""
|
||||
yield self.db_pool.simple_update_one(
|
||||
await self.db_pool.simple_update_one(
|
||||
table="account_validity",
|
||||
keyvalues={"user_id": user_id},
|
||||
updatevalues={"renewal_token": renewal_token},
|
||||
desc="set_renewal_token_for_user",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_from_renewal_token(self, renewal_token):
|
||||
async def get_user_from_renewal_token(self, renewal_token: str) -> str:
|
||||
"""Get a user ID from a renewal token.
|
||||
|
||||
Args:
|
||||
renewal_token (str): The renewal token to perform the lookup with.
|
||||
renewal_token: The renewal token to perform the lookup with.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[str]: The ID of the user to which the token belongs.
|
||||
The ID of the user to which the token belongs.
|
||||
"""
|
||||
res = yield self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="account_validity",
|
||||
keyvalues={"renewal_token": renewal_token},
|
||||
retcol="user_id",
|
||||
desc="get_user_from_renewal_token",
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_renewal_token_for_user(self, user_id):
|
||||
async def get_renewal_token_for_user(self, user_id: str) -> str:
|
||||
"""Get the renewal token associated with a given user ID.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID to lookup a token for.
|
||||
user_id: The user ID to lookup a token for.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[str]: The renewal token associated with this user ID.
|
||||
The renewal token associated with this user ID.
|
||||
"""
|
||||
res = yield self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="account_validity",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="renewal_token",
|
||||
desc="get_renewal_token_for_user",
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_users_expiring_soon(self):
|
||||
async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
|
||||
"""Selects users whose account will expire in the [now, now + renew_at] time
|
||||
window (see configuration for account_validity for information on what renew_at
|
||||
refers to).
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
|
||||
A list of dictionaries mapping user ID to expiration time (in milliseconds).
|
||||
"""
|
||||
|
||||
def select_users_txn(txn, now_ms, renew_at):
|
||||
|
@ -238,53 +228,49 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, values)
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
res = yield self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_expiring_soon",
|
||||
select_users_txn,
|
||||
self.clock.time_msec(),
|
||||
self.config.account_validity.renew_at,
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_renewal_mail_status(self, user_id, email_sent):
|
||||
async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
|
||||
"""Sets or unsets the flag that indicates whether a renewal email has been sent
|
||||
to the user (and the user hasn't renewed their account yet).
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to set/unset the flag for.
|
||||
email_sent (bool): Flag which indicates whether a renewal email has been sent
|
||||
user_id: ID of the user to set/unset the flag for.
|
||||
email_sent: Flag which indicates whether a renewal email has been sent
|
||||
to this user.
|
||||
"""
|
||||
yield self.db_pool.simple_update_one(
|
||||
await self.db_pool.simple_update_one(
|
||||
table="account_validity",
|
||||
keyvalues={"user_id": user_id},
|
||||
updatevalues={"email_sent": email_sent},
|
||||
desc="set_renewal_mail_status",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_account_validity_for_user(self, user_id):
|
||||
async def delete_account_validity_for_user(self, user_id: str) -> None:
|
||||
"""Deletes the entry for the given user in the account validity table, removing
|
||||
their expiration date and renewal token.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to remove from the account validity table.
|
||||
user_id: ID of the user to remove from the account validity table.
|
||||
"""
|
||||
yield self.db_pool.simple_delete_one(
|
||||
await self.db_pool.simple_delete_one(
|
||||
table="account_validity",
|
||||
keyvalues={"user_id": user_id},
|
||||
desc="delete_account_validity_for_user",
|
||||
)
|
||||
|
||||
async def is_server_admin(self, user):
|
||||
async def is_server_admin(self, user: UserID) -> bool:
|
||||
"""Determines if a user is an admin of this homeserver.
|
||||
|
||||
Args:
|
||||
user (UserID): user ID of the user to test
|
||||
user: user ID of the user to test
|
||||
|
||||
Returns (bool):
|
||||
Returns:
|
||||
true iff the user is a server admin, false otherwise.
|
||||
"""
|
||||
res = await self.db_pool.simple_select_one_onecol(
|
||||
|
@ -332,32 +318,31 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
|
||||
return None
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def is_real_user(self, user_id):
|
||||
@cached()
|
||||
async def is_real_user(self, user_id: str) -> bool:
|
||||
"""Determines if the user is a real user, ie does not have a 'user_type'.
|
||||
|
||||
Args:
|
||||
user_id (str): user id to test
|
||||
user_id: user id to test
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: True if user 'user_type' is null or empty string
|
||||
True if user 'user_type' is null or empty string
|
||||
"""
|
||||
res = yield self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"is_real_user", self.is_real_user_txn, user_id
|
||||
)
|
||||
return res
|
||||
|
||||
@cached()
|
||||
def is_support_user(self, user_id):
|
||||
async def is_support_user(self, user_id: str) -> bool:
|
||||
"""Determines if the user is of type UserTypes.SUPPORT
|
||||
|
||||
Args:
|
||||
user_id (str): user id to test
|
||||
user_id: user id to test
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: True if user is of type UserTypes.SUPPORT
|
||||
True if user is of type UserTypes.SUPPORT
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"is_support_user", self.is_support_user_txn, user_id
|
||||
)
|
||||
|
||||
|
@ -413,8 +398,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
desc="get_user_by_external_id",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def count_all_users(self):
|
||||
async def count_all_users(self):
|
||||
"""Counts all users registered on the homeserver."""
|
||||
|
||||
def _count_users(txn):
|
||||
|
@ -424,8 +408,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
return rows[0]["users"]
|
||||
return 0
|
||||
|
||||
ret = yield self.db_pool.runInteraction("count_users", _count_users)
|
||||
return ret
|
||||
return await self.db_pool.runInteraction("count_users", _count_users)
|
||||
|
||||
def count_daily_user_type(self):
|
||||
"""
|
||||
|
@ -460,8 +443,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
"count_daily_user_type", _count_daily_user_type
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def count_nonbridged_users(self):
|
||||
async def count_nonbridged_users(self):
|
||||
def _count_users(txn):
|
||||
txn.execute(
|
||||
"""
|
||||
|
@ -472,11 +454,9 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
ret = yield self.db_pool.runInteraction("count_users", _count_users)
|
||||
return ret
|
||||
return await self.db_pool.runInteraction("count_users", _count_users)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def count_real_users(self):
|
||||
async def count_real_users(self):
|
||||
"""Counts all users without a special user_type registered on the homeserver."""
|
||||
|
||||
def _count_users(txn):
|
||||
|
@ -486,8 +466,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
return rows[0]["users"]
|
||||
return 0
|
||||
|
||||
ret = yield self.db_pool.runInteraction("count_real_users", _count_users)
|
||||
return ret
|
||||
return await self.db_pool.runInteraction("count_real_users", _count_users)
|
||||
|
||||
async def generate_user_id(self) -> str:
|
||||
"""Generate a suitable localpart for a guest user
|
||||
|
@ -537,23 +516,20 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
return ret["user_id"]
|
||||
return None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
||||
yield self.db_pool.simple_upsert(
|
||||
async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
||||
await self.db_pool.simple_upsert(
|
||||
"user_threepids",
|
||||
{"medium": medium, "address": address},
|
||||
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_get_threepids(self, user_id):
|
||||
ret = yield self.db_pool.simple_select_list(
|
||||
async def user_get_threepids(self, user_id):
|
||||
return await self.db_pool.simple_select_list(
|
||||
"user_threepids",
|
||||
{"user_id": user_id},
|
||||
["medium", "address", "validated_at", "added_at"],
|
||||
"user_get_threepids",
|
||||
)
|
||||
return ret
|
||||
|
||||
def user_delete_threepid(self, user_id, medium, address):
|
||||
return self.db_pool.simple_delete(
|
||||
|
@ -668,18 +644,18 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
desc="get_id_servers_user_bound",
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_user_deactivated_status(self, user_id):
|
||||
@cached()
|
||||
async def get_user_deactivated_status(self, user_id: str) -> bool:
|
||||
"""Retrieve the value for the `deactivated` property for the provided user.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user to retrieve the status for.
|
||||
user_id: The ID of the user to retrieve the status for.
|
||||
|
||||
Returns:
|
||||
defer.Deferred(bool): The requested value.
|
||||
True if the user was deactivated, false if the user is still active.
|
||||
"""
|
||||
|
||||
res = yield self.db_pool.simple_select_one_onecol(
|
||||
res = await self.db_pool.simple_select_one_onecol(
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
retcol="deactivated",
|
||||
|
@ -818,8 +794,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_update_set_deactivated_flag(self, progress, batch_size):
|
||||
async def _background_update_set_deactivated_flag(self, progress, batch_size):
|
||||
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
||||
for each of them.
|
||||
"""
|
||||
|
@ -870,19 +845,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
else:
|
||||
return False, len(rows)
|
||||
|
||||
end, nb_processed = yield self.db_pool.runInteraction(
|
||||
end, nb_processed = await self.db_pool.runInteraction(
|
||||
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
|
||||
)
|
||||
|
||||
if end:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"users_set_deactivated_flag"
|
||||
)
|
||||
|
||||
return nb_processed
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _bg_user_threepids_grandfather(self, progress, batch_size):
|
||||
async def _bg_user_threepids_grandfather(self, progress, batch_size):
|
||||
"""We now track which identity servers a user binds their 3PID to, so
|
||||
we need to handle the case of existing bindings where we didn't track
|
||||
this.
|
||||
|
@ -903,11 +877,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
txn.executemany(sql, [(id_server,) for id_server in id_servers])
|
||||
|
||||
if id_servers:
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
|
||||
)
|
||||
|
||||
yield self.db_pool.updates._end_background_update("user_threepids_grandfather")
|
||||
await self.db_pool.updates._end_background_update("user_threepids_grandfather")
|
||||
|
||||
return 1
|
||||
|
||||
|
@ -937,23 +911,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
|
||||
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
|
||||
async def add_access_token_to_user(
|
||||
self,
|
||||
user_id: str,
|
||||
token: str,
|
||||
device_id: Optional[str],
|
||||
valid_until_ms: Optional[int],
|
||||
) -> None:
|
||||
"""Adds an access token for the given user.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
token (str): The new access token to add.
|
||||
device_id (str): ID of the device to associate with the access
|
||||
token
|
||||
valid_until_ms (int|None): when the token is valid until. None for
|
||||
no expiry.
|
||||
user_id: The user ID.
|
||||
token: The new access token to add.
|
||||
device_id: ID of the device to associate with the access token
|
||||
valid_until_ms: when the token is valid until. None for no expiry.
|
||||
Raises:
|
||||
StoreError if there was a problem adding this.
|
||||
"""
|
||||
next_id = self._access_tokens_id_gen.get_next()
|
||||
|
||||
yield self.db_pool.simple_insert(
|
||||
await self.db_pool.simple_insert(
|
||||
"access_tokens",
|
||||
{
|
||||
"id": next_id,
|
||||
|
@ -1097,7 +1074,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
)
|
||||
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||
|
||||
def record_user_external_id(
|
||||
self, auth_provider: str, external_id: str, user_id: str
|
||||
|
@ -1241,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
|
||||
return self.db_pool.runInteraction("delete_access_token", f)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def is_guest(self, user_id):
|
||||
res = yield self.db_pool.simple_select_one_onecol(
|
||||
@cached()
|
||||
async def is_guest(self, user_id: str) -> bool:
|
||||
res = await self.db_pool.simple_select_one_onecol(
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
retcol="is_guest",
|
||||
|
@ -1481,16 +1457,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
self.clock.time_msec(),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_user_deactivated_status(self, user_id, deactivated):
|
||||
async def set_user_deactivated_status(
|
||||
self, user_id: str, deactivated: bool
|
||||
) -> None:
|
||||
"""Set the `deactivated` property for the provided user to the provided value.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user to set the status for.
|
||||
deactivated (bool): The value to set for `deactivated`.
|
||||
user_id: The ID of the user to set the status for.
|
||||
deactivated: The value to set for `deactivated`.
|
||||
"""
|
||||
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"set_user_deactivated_status",
|
||||
self.set_user_deactivated_status_txn,
|
||||
user_id,
|
||||
|
@ -1507,9 +1484,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_deactivated_status, (user_id,)
|
||||
)
|
||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _set_expiration_date_when_missing(self):
|
||||
async def _set_expiration_date_when_missing(self):
|
||||
"""
|
||||
Retrieves the list of registered users that don't have an expiration date, and
|
||||
adds an expiration date for each of them.
|
||||
|
@ -1533,7 +1510,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
txn, user["name"], use_delta=True
|
||||
)
|
||||
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"get_users_with_no_expiration_date",
|
||||
select_users_with_no_expiration_date_txn,
|
||||
)
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
|
@ -198,8 +196,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
|
|||
columns=["room_id"],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_deduplicate_state(self, progress, batch_size):
|
||||
async def _background_deduplicate_state(self, progress, batch_size):
|
||||
"""This background update will slowly deduplicate state by reencoding
|
||||
them as deltas.
|
||||
"""
|
||||
|
@ -212,7 +209,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
|
|||
batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
|
||||
|
||||
if max_group is None:
|
||||
rows = yield self.db_pool.execute(
|
||||
rows = await self.db_pool.execute(
|
||||
"_background_deduplicate_state",
|
||||
None,
|
||||
"SELECT coalesce(max(id), 0) FROM state_groups",
|
||||
|
@ -330,19 +327,18 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
|
|||
|
||||
return False, batch_size
|
||||
|
||||
finished, result = yield self.db_pool.runInteraction(
|
||||
finished, result = await self.db_pool.runInteraction(
|
||||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
|
||||
)
|
||||
|
||||
if finished:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
|
||||
)
|
||||
|
||||
return result * BATCH_SIZE_SCALE_FACTOR
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_index_state(self, progress, batch_size):
|
||||
async def _background_index_state(self, progress, batch_size):
|
||||
def reindex_txn(conn):
|
||||
conn.rollback()
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
|
@ -365,9 +361,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
|
|||
)
|
||||
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
|
||||
|
||||
yield self.db_pool.runWithConnection(reindex_txn)
|
||||
await self.db_pool.runWithConnection(reindex_txn)
|
||||
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.STATE_GROUP_INDEX_UPDATE_NAME
|
||||
)
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ from synapse.api.errors import Codes, ResourceLimitError, SynapseError
|
|||
from synapse.handlers.register import RegistrationHandler
|
||||
from synapse.types import RoomAlias, UserID, create_requester
|
||||
|
||||
from tests.test_utils import make_awaitable
|
||||
from tests.unittest import override_config
|
||||
|
||||
from .. import unittest
|
||||
|
@ -187,7 +188,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
room_alias_str = "#room:test"
|
||||
self.hs.config.auto_join_rooms = [room_alias_str]
|
||||
|
||||
self.store.is_real_user = Mock(return_value=defer.succeed(False))
|
||||
self.store.is_real_user = Mock(return_value=make_awaitable(False))
|
||||
user_id = self.get_success(self.handler.register_user(localpart="support"))
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||
self.assertEqual(len(rooms), 0)
|
||||
|
@ -199,8 +200,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
|
||||
room_alias_str = "#room:test"
|
||||
|
||||
self.store.count_real_users = Mock(return_value=defer.succeed(1))
|
||||
self.store.is_real_user = Mock(return_value=defer.succeed(True))
|
||||
self.store.count_real_users = Mock(return_value=make_awaitable(1))
|
||||
self.store.is_real_user = Mock(return_value=make_awaitable(True))
|
||||
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||
directory_handler = self.hs.get_handlers().directory_handler
|
||||
|
@ -214,8 +215,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
room_alias_str = "#room:test"
|
||||
self.hs.config.auto_join_rooms = [room_alias_str]
|
||||
|
||||
self.store.count_real_users = Mock(return_value=defer.succeed(2))
|
||||
self.store.is_real_user = Mock(return_value=defer.succeed(True))
|
||||
self.store.count_real_users = Mock(return_value=make_awaitable(2))
|
||||
self.store.is_real_user = Mock(return_value=make_awaitable(True))
|
||||
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||
self.assertEqual(len(rooms), 0)
|
||||
|
|
|
@ -300,8 +300,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
|||
self.get_success(self.store.register_user(user_id=user2, password_hash=None))
|
||||
|
||||
now = int(self.hs.get_clock().time_msec())
|
||||
self.store.user_add_threepid(user1, "email", user1_email, now, now)
|
||||
self.store.user_add_threepid(user2, "email", user2_email, now, now)
|
||||
self.get_success(
|
||||
self.store.user_add_threepid(user1, "email", user1_email, now, now)
|
||||
)
|
||||
self.get_success(
|
||||
self.store.user_add_threepid(user2, "email", user2_email, now, now)
|
||||
)
|
||||
|
||||
users = self.get_success(self.store.get_registered_reserved_users())
|
||||
self.assertEqual(len(users), len(threepids))
|
||||
|
|
|
@ -58,8 +58,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_add_tokens(self):
|
||||
yield self.store.register_user(self.user_id, self.pwhash)
|
||||
yield self.store.add_access_token_to_user(
|
||||
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
|
||||
yield defer.ensureDeferred(
|
||||
self.store.add_access_token_to_user(
|
||||
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
|
||||
result = yield self.store.get_user_by_access_token(self.tokens[1])
|
||||
|
@ -74,11 +76,15 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||
def test_user_delete_access_tokens(self):
|
||||
# add some tokens
|
||||
yield self.store.register_user(self.user_id, self.pwhash)
|
||||
yield self.store.add_access_token_to_user(
|
||||
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
|
||||
yield defer.ensureDeferred(
|
||||
self.store.add_access_token_to_user(
|
||||
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
yield self.store.add_access_token_to_user(
|
||||
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
|
||||
yield defer.ensureDeferred(
|
||||
self.store.add_access_token_to_user(
|
||||
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
|
||||
# now delete some
|
||||
|
|
Loading…
Reference in New Issue