Add some type hints to datastore (#12423)
* Add some type hints to datastore * newsfile * change `Collection` to `List` * refactor return type of `select_users_txn` * correct type hint in `stream.py` * Remove `Optional` in `select_users_txn` * remove not needed return type in `__init__` * Revert change in `get_stream_id_for_event_txn` * Remove import from `Literal`pull/12344/head
parent
4e13743738
commit
1783156dbc
|
@ -0,0 +1 @@
|
||||||
|
Add some type hints to datastore.
|
|
@ -180,9 +180,9 @@ class AccountValidityHandler:
|
||||||
expiring_users = await self.store.get_users_expiring_soon()
|
expiring_users = await self.store.get_users_expiring_soon()
|
||||||
|
|
||||||
if expiring_users:
|
if expiring_users:
|
||||||
for user in expiring_users:
|
for user_id, expiration_ts_ms in expiring_users:
|
||||||
await self._send_renewal_email(
|
await self._send_renewal_email(
|
||||||
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
|
user_id=user_id, expiration_ts=expiration_ts_ms
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_renewal_email_to_user(self, user_id: str) -> None:
|
async def send_renewal_email_to_user(self, user_id: str) -> None:
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
|
||||||
|
|
||||||
from synapse.appservice import (
|
from synapse.appservice import (
|
||||||
ApplicationService,
|
ApplicationService,
|
||||||
|
@ -26,7 +26,11 @@ from synapse.appservice import (
|
||||||
from synapse.config.appservice import load_appservices
|
from synapse.config.appservice import load_appservices
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.storage._base import db_to_json
|
from synapse.storage._base import db_to_json
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
)
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
|
@ -92,7 +96,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
|
||||||
|
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
def get_app_services(self):
|
def get_app_services(self) -> List[ApplicationService]:
|
||||||
return self.services_cache
|
return self.services_cache
|
||||||
|
|
||||||
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
|
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
|
||||||
|
@ -256,7 +260,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
A new transaction.
|
A new transaction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _create_appservice_txn(txn):
|
def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction:
|
||||||
new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)
|
new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)
|
||||||
|
|
||||||
# Insert new txn into txn table
|
# Insert new txn into txn table
|
||||||
|
@ -291,7 +295,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
service: The application service which was sent this transaction.
|
service: The application service which was sent this transaction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _complete_appservice_txn(txn):
|
def _complete_appservice_txn(txn: LoggingTransaction) -> None:
|
||||||
# Set current txn_id for AS to 'txn_id'
|
# Set current txn_id for AS to 'txn_id'
|
||||||
self.db_pool.simple_upsert_txn(
|
self.db_pool.simple_upsert_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -322,7 +326,9 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
An AppServiceTransaction or None.
|
An AppServiceTransaction or None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_oldest_unsent_txn(txn):
|
def _get_oldest_unsent_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
# Monotonically increasing txn ids, so just select the smallest
|
# Monotonically increasing txn ids, so just select the smallest
|
||||||
# one in the txns table (we delete them when they are sent)
|
# one in the txns table (we delete them when they are sent)
|
||||||
txn.execute(
|
txn.execute(
|
||||||
|
@ -364,7 +370,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_appservice_last_pos(self, pos: int) -> None:
|
async def set_appservice_last_pos(self, pos: int) -> None:
|
||||||
def set_appservice_last_pos_txn(txn):
|
def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||||
)
|
)
|
||||||
|
@ -378,7 +384,9 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
) -> Tuple[int, List[EventBase]]:
|
) -> Tuple[int, List[EventBase]]:
|
||||||
"""Get all new events for an appservice"""
|
"""Get all new events for an appservice"""
|
||||||
|
|
||||||
def get_new_events_for_appservice_txn(txn):
|
def get_new_events_for_appservice_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[int, List[str]]:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT e.stream_ordering, e.event_id"
|
"SELECT e.stream_ordering, e.event_id"
|
||||||
" FROM events AS e"
|
" FROM events AS e"
|
||||||
|
@ -416,7 +424,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
% (type,)
|
% (type,)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_type_stream_id_for_appservice_txn(txn):
|
def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int:
|
||||||
stream_id_type = "%s_stream_id" % type
|
stream_id_type = "%s_stream_id" % type
|
||||||
txn.execute(
|
txn.execute(
|
||||||
# We do NOT want to escape `stream_id_type`.
|
# We do NOT want to escape `stream_id_type`.
|
||||||
|
@ -444,7 +452,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||||
% (stream_type,)
|
% (stream_type,)
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_appservice_stream_type_pos_txn(txn):
|
def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
|
||||||
stream_id_type = "%s_stream_id" % stream_type
|
stream_id_type = "%s_stream_id" % stream_type
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
|
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
|
||||||
|
|
|
@ -34,7 +34,7 @@ from synapse.storage.databases.main.stats import StatsStore
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.id_generators import IdGenerator
|
from synapse.storage.util.id_generators import IdGenerator
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import UserID, UserInfo
|
from synapse.types import JsonDict, UserID, UserInfo
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -79,7 +79,7 @@ class TokenLookupResult:
|
||||||
|
|
||||||
# Make the token owner default to the user ID, which is the common case.
|
# Make the token owner default to the user ID, which is the common case.
|
||||||
@token_owner.default
|
@token_owner.default
|
||||||
def _default_token_owner(self):
|
def _default_token_owner(self) -> str:
|
||||||
return self.user_id
|
return self.user_id
|
||||||
|
|
||||||
|
|
||||||
|
@ -299,7 +299,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
the account.
|
the account.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def set_account_validity_for_user_txn(txn):
|
def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_update_txn(
|
self.db_pool.simple_update_txn(
|
||||||
txn=txn,
|
txn=txn,
|
||||||
table="account_validity",
|
table="account_validity",
|
||||||
|
@ -385,23 +385,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
desc="get_renewal_token_for_user",
|
desc="get_renewal_token_for_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
|
async def get_users_expiring_soon(self) -> List[Tuple[str, int]]:
|
||||||
"""Selects users whose account will expire in the [now, now + renew_at] time
|
"""Selects users whose account will expire in the [now, now + renew_at] time
|
||||||
window (see configuration for account_validity for information on what renew_at
|
window (see configuration for account_validity for information on what renew_at
|
||||||
refers to).
|
refers to).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of dictionaries, each with a user ID and expiration time (in milliseconds).
|
A list of tuples, each with a user ID and expiration time (in milliseconds).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def select_users_txn(txn, now_ms, renew_at):
|
def select_users_txn(
|
||||||
|
txn: LoggingTransaction, now_ms: int, renew_at: int
|
||||||
|
) -> List[Tuple[str, int]]:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT user_id, expiration_ts_ms FROM account_validity"
|
"SELECT user_id, expiration_ts_ms FROM account_validity"
|
||||||
" WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
|
" WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
|
||||||
)
|
)
|
||||||
values = [False, now_ms, renew_at]
|
values = [False, now_ms, renew_at]
|
||||||
txn.execute(sql, values)
|
txn.execute(sql, values)
|
||||||
return self.db_pool.cursor_to_dict(txn)
|
return cast(List[Tuple[str, int]], txn.fetchall())
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_users_expiring_soon",
|
"get_users_expiring_soon",
|
||||||
|
@ -466,7 +468,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
admin: true iff the user is to be a server admin, false otherwise.
|
admin: true iff the user is to be a server admin, false otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def set_server_admin_txn(txn):
|
def set_server_admin_txn(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_update_one_txn(
|
self.db_pool.simple_update_one_txn(
|
||||||
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
|
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
|
||||||
)
|
)
|
||||||
|
@ -515,7 +517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
user_type: type of the user or None for a user without a type.
|
user_type: type of the user or None for a user without a type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def set_user_type_txn(txn):
|
def set_user_type_txn(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_update_one_txn(
|
self.db_pool.simple_update_one_txn(
|
||||||
txn, "users", {"name": user.to_string()}, {"user_type": user_type}
|
txn, "users", {"name": user.to_string()}, {"user_type": user_type}
|
||||||
)
|
)
|
||||||
|
@ -525,7 +527,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
|
await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
|
||||||
|
|
||||||
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
|
def _query_for_auth(
|
||||||
|
self, txn: LoggingTransaction, token: str
|
||||||
|
) -> Optional[TokenLookupResult]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT users.name as user_id,
|
SELECT users.name as user_id,
|
||||||
users.is_guest,
|
users.is_guest,
|
||||||
|
@ -582,7 +586,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
"is_support_user", self.is_support_user_txn, user_id
|
"is_support_user", self.is_support_user_txn, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_real_user_txn(self, txn, user_id):
|
def is_real_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
|
||||||
res = self.db_pool.simple_select_one_onecol_txn(
|
res = self.db_pool.simple_select_one_onecol_txn(
|
||||||
txn=txn,
|
txn=txn,
|
||||||
table="users",
|
table="users",
|
||||||
|
@ -592,7 +596,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
)
|
)
|
||||||
return res is None
|
return res is None
|
||||||
|
|
||||||
def is_support_user_txn(self, txn, user_id):
|
def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
|
||||||
res = self.db_pool.simple_select_one_onecol_txn(
|
res = self.db_pool.simple_select_one_onecol_txn(
|
||||||
txn=txn,
|
txn=txn,
|
||||||
table="users",
|
table="users",
|
||||||
|
@ -609,10 +613,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
A mapping of user_id -> password_hash.
|
A mapping of user_id -> password_hash.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn: LoggingTransaction) -> Dict[str, str]:
|
||||||
sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
|
sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
|
||||||
txn.execute(sql, (user_id,))
|
txn.execute(sql, (user_id,))
|
||||||
return dict(txn)
|
result = cast(List[Tuple[str, str]], txn.fetchall())
|
||||||
|
return dict(result)
|
||||||
|
|
||||||
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
|
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
|
||||||
|
|
||||||
|
@ -734,7 +739,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
def _replace_user_external_id_txn(
|
def _replace_user_external_id_txn(
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
):
|
) -> None:
|
||||||
_remove_user_external_ids_txn(txn, user_id)
|
_remove_user_external_ids_txn(txn, user_id)
|
||||||
|
|
||||||
for auth_provider, external_id in record_external_ids:
|
for auth_provider, external_id in record_external_ids:
|
||||||
|
@ -790,10 +795,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
)
|
)
|
||||||
return [(r["auth_provider"], r["external_id"]) for r in res]
|
return [(r["auth_provider"], r["external_id"]) for r in res]
|
||||||
|
|
||||||
async def count_all_users(self):
|
async def count_all_users(self) -> int:
|
||||||
"""Counts all users registered on the homeserver."""
|
"""Counts all users registered on the homeserver."""
|
||||||
|
|
||||||
def _count_users(txn):
|
def _count_users(txn: LoggingTransaction) -> int:
|
||||||
txn.execute("SELECT COUNT(*) AS users FROM users")
|
txn.execute("SELECT COUNT(*) AS users FROM users")
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
rows = self.db_pool.cursor_to_dict(txn)
|
||||||
if rows:
|
if rows:
|
||||||
|
@ -810,7 +815,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
who registered on the homeserver in the past 24 hours
|
who registered on the homeserver in the past 24 hours
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _count_daily_user_type(txn):
|
def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]:
|
||||||
yesterday = int(self._clock.time()) - (60 * 60 * 24)
|
yesterday = int(self._clock.time()) - (60 * 60 * 24)
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
|
@ -835,23 +840,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
"count_daily_user_type", _count_daily_user_type
|
"count_daily_user_type", _count_daily_user_type
|
||||||
)
|
)
|
||||||
|
|
||||||
async def count_nonbridged_users(self):
|
async def count_nonbridged_users(self) -> int:
|
||||||
def _count_users(txn):
|
def _count_users(txn: LoggingTransaction) -> int:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT COUNT(*) FROM users
|
SELECT COUNT(*) FROM users
|
||||||
WHERE appservice_id IS NULL
|
WHERE appservice_id IS NULL
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
(count,) = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
return count
|
return count
|
||||||
|
|
||||||
return await self.db_pool.runInteraction("count_users", _count_users)
|
return await self.db_pool.runInteraction("count_users", _count_users)
|
||||||
|
|
||||||
async def count_real_users(self):
|
async def count_real_users(self) -> int:
|
||||||
"""Counts all users without a special user_type registered on the homeserver."""
|
"""Counts all users without a special user_type registered on the homeserver."""
|
||||||
|
|
||||||
def _count_users(txn):
|
def _count_users(txn: LoggingTransaction) -> int:
|
||||||
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
|
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
rows = self.db_pool.cursor_to_dict(txn)
|
||||||
if rows:
|
if rows:
|
||||||
|
@ -888,7 +893,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
def get_user_id_by_threepid_txn(
|
def get_user_id_by_threepid_txn(
|
||||||
self, txn, medium: str, address: str
|
self, txn: LoggingTransaction, medium: str, address: str
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Returns user id from threepid
|
"""Returns user id from threepid
|
||||||
|
|
||||||
|
@ -925,7 +930,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
|
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
|
async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
|
||||||
return await self.db_pool.simple_select_list(
|
return await self.db_pool.simple_select_list(
|
||||||
"user_threepids",
|
"user_threepids",
|
||||||
{"user_id": user_id},
|
{"user_id": user_id},
|
||||||
|
@ -957,7 +962,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
async def add_user_bound_threepid(
|
async def add_user_bound_threepid(
|
||||||
self, user_id: str, medium: str, address: str, id_server: str
|
self, user_id: str, medium: str, address: str, id_server: str
|
||||||
):
|
) -> None:
|
||||||
"""The server proxied a bind request to the given identity server on
|
"""The server proxied a bind request to the given identity server on
|
||||||
behalf of the given user. We need to remember this in case the user
|
behalf of the given user. We need to remember this in case the user
|
||||||
asks us to unbind the threepid.
|
asks us to unbind the threepid.
|
||||||
|
@ -1116,7 +1121,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
assert address or sid
|
assert address or sid
|
||||||
|
|
||||||
def get_threepid_validation_session_txn(txn):
|
def get_threepid_validation_session_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT address, session_id, medium, client_secret,
|
SELECT address, session_id, medium, client_secret,
|
||||||
last_send_attempt, validated_at
|
last_send_attempt, validated_at
|
||||||
|
@ -1150,7 +1157,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
session_id: The ID of the session to delete
|
session_id: The ID of the session to delete
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def delete_threepid_session_txn(txn):
|
def delete_threepid_session_txn(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
table="threepid_validation_token",
|
table="threepid_validation_token",
|
||||||
|
@ -1170,7 +1177,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
async def cull_expired_threepid_validation_tokens(self) -> None:
|
async def cull_expired_threepid_validation_tokens(self) -> None:
|
||||||
"""Remove threepid validation tokens with expiry dates that have passed"""
|
"""Remove threepid validation tokens with expiry dates that have passed"""
|
||||||
|
|
||||||
def cull_expired_threepid_validation_tokens_txn(txn, ts):
|
def cull_expired_threepid_validation_tokens_txn(
|
||||||
|
txn: LoggingTransaction, ts: int
|
||||||
|
) -> None:
|
||||||
sql = """
|
sql = """
|
||||||
DELETE FROM threepid_validation_token WHERE
|
DELETE FROM threepid_validation_token WHERE
|
||||||
expires < ?
|
expires < ?
|
||||||
|
@ -1184,13 +1193,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@wrap_as_background_process("account_validity_set_expiration_dates")
|
@wrap_as_background_process("account_validity_set_expiration_dates")
|
||||||
async def _set_expiration_date_when_missing(self):
|
async def _set_expiration_date_when_missing(self) -> None:
|
||||||
"""
|
"""
|
||||||
Retrieves the list of registered users that don't have an expiration date, and
|
Retrieves the list of registered users that don't have an expiration date, and
|
||||||
adds an expiration date for each of them.
|
adds an expiration date for each of them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def select_users_with_no_expiration_date_txn(txn):
|
def select_users_with_no_expiration_date_txn(txn: LoggingTransaction) -> None:
|
||||||
"""Retrieves the list of registered users with no expiration date from the
|
"""Retrieves the list of registered users with no expiration date from the
|
||||||
database, filtering out deactivated users.
|
database, filtering out deactivated users.
|
||||||
"""
|
"""
|
||||||
|
@ -1213,7 +1222,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
select_users_with_no_expiration_date_txn,
|
select_users_with_no_expiration_date_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
|
def set_expiration_date_for_user_txn(
|
||||||
|
self, txn: LoggingTransaction, user_id: str, use_delta: bool = False
|
||||||
|
) -> None:
|
||||||
"""Sets an expiration date to the account with the given user ID.
|
"""Sets an expiration date to the account with the given user ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1344,7 +1355,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
token: The registration token pending use
|
token: The registration token pending use
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _set_registration_token_pending_txn(txn):
|
def _set_registration_token_pending_txn(txn: LoggingTransaction) -> None:
|
||||||
pending = self.db_pool.simple_select_one_onecol_txn(
|
pending = self.db_pool.simple_select_one_onecol_txn(
|
||||||
txn,
|
txn,
|
||||||
"registration_tokens",
|
"registration_tokens",
|
||||||
|
@ -1358,7 +1369,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
updatevalues={"pending": pending + 1},
|
updatevalues={"pending": pending + 1},
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"set_registration_token_pending", _set_registration_token_pending_txn
|
"set_registration_token_pending", _set_registration_token_pending_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1372,7 +1383,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
token: The registration token to be 'used'
|
token: The registration token to be 'used'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _use_registration_token_txn(txn):
|
def _use_registration_token_txn(txn: LoggingTransaction) -> None:
|
||||||
# Normally, res is Optional[Dict[str, Any]].
|
# Normally, res is Optional[Dict[str, Any]].
|
||||||
# Override type because the return type is only optional if
|
# Override type because the return type is only optional if
|
||||||
# allow_none is True, and we don't want mypy throwing errors
|
# allow_none is True, and we don't want mypy throwing errors
|
||||||
|
@ -1398,7 +1409,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"use_registration_token", _use_registration_token_txn
|
"use_registration_token", _use_registration_token_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1416,7 +1427,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
A list of dicts, each containing details of a token.
|
A list of dicts, each containing details of a token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
|
def select_registration_tokens_txn(
|
||||||
|
txn: LoggingTransaction, now: int, valid: Optional[bool]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
if valid is None:
|
if valid is None:
|
||||||
# Return all tokens regardless of validity
|
# Return all tokens regardless of validity
|
||||||
txn.execute("SELECT * FROM registration_tokens")
|
txn.execute("SELECT * FROM registration_tokens")
|
||||||
|
@ -1523,7 +1536,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
Whether the row was inserted or not.
|
Whether the row was inserted or not.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _create_registration_token_txn(txn):
|
def _create_registration_token_txn(txn: LoggingTransaction) -> bool:
|
||||||
row = self.db_pool.simple_select_one_txn(
|
row = self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
"registration_tokens",
|
"registration_tokens",
|
||||||
|
@ -1570,7 +1583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
A dict with all info about the token, or None if token doesn't exist.
|
A dict with all info about the token, or None if token doesn't exist.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _update_registration_token_txn(txn):
|
def _update_registration_token_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
self.db_pool.simple_update_one_txn(
|
self.db_pool.simple_update_one_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -1651,7 +1666,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
) -> Optional[RefreshTokenLookupResult]:
|
) -> Optional[RefreshTokenLookupResult]:
|
||||||
"""Lookup a refresh token with hints about its validity."""
|
"""Lookup a refresh token with hints about its validity."""
|
||||||
|
|
||||||
def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
|
def _lookup_refresh_token_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Optional[RefreshTokenLookupResult]:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
|
@ -1807,14 +1824,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
unique=False,
|
unique=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _background_update_set_deactivated_flag(self, progress, batch_size):
|
async def _background_update_set_deactivated_flag(
|
||||||
|
self, progress: JsonDict, batch_size: int
|
||||||
|
) -> int:
|
||||||
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
||||||
for each of them.
|
for each of them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_user = progress.get("user_id", "")
|
last_user = progress.get("user_id", "")
|
||||||
|
|
||||||
def _background_update_set_deactivated_flag_txn(txn):
|
def _background_update_set_deactivated_flag_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[bool, int]:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
|
@ -1886,7 +1907,9 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
deactivated,
|
deactivated,
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
|
def set_user_deactivated_status_txn(
|
||||||
|
self, txn: LoggingTransaction, user_id: str, deactivated: bool
|
||||||
|
) -> None:
|
||||||
self.db_pool.simple_update_one_txn(
|
self.db_pool.simple_update_one_txn(
|
||||||
txn=txn,
|
txn=txn,
|
||||||
table="users",
|
table="users",
|
||||||
|
@ -2005,7 +2028,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
|
|
||||||
return next_id
|
return next_id
|
||||||
|
|
||||||
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
|
def _set_device_for_access_token_txn(
|
||||||
|
self, txn: LoggingTransaction, token: str, device_id: str
|
||||||
|
) -> str:
|
||||||
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
||||||
txn, "access_tokens", {"token": token}, "device_id"
|
txn, "access_tokens", {"token": token}, "device_id"
|
||||||
)
|
)
|
||||||
|
@ -2084,7 +2109,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
|
|
||||||
def _register_user(
|
def _register_user(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
password_hash: Optional[str],
|
password_hash: Optional[str],
|
||||||
was_guest: bool,
|
was_guest: bool,
|
||||||
|
@ -2094,7 +2119,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
admin: bool,
|
admin: bool,
|
||||||
user_type: Optional[str],
|
user_type: Optional[str],
|
||||||
shadow_banned: bool,
|
shadow_banned: bool,
|
||||||
):
|
) -> None:
|
||||||
user_id_obj = UserID.from_string(user_id)
|
user_id_obj = UserID.from_string(user_id)
|
||||||
|
|
||||||
now = int(self._clock.time())
|
now = int(self._clock.time())
|
||||||
|
@ -2181,7 +2206,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
pointless. Use flush_user separately.
|
pointless. Use flush_user separately.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def user_set_password_hash_txn(txn):
|
def user_set_password_hash_txn(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_update_one_txn(
|
self.db_pool.simple_update_one_txn(
|
||||||
txn, "users", {"name": user_id}, {"password_hash": password_hash}
|
txn, "users", {"name": user_id}, {"password_hash": password_hash}
|
||||||
)
|
)
|
||||||
|
@ -2204,7 +2229,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
StoreError(404) if user not found
|
StoreError(404) if user not found
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_update_one_txn(
|
self.db_pool.simple_update_one_txn(
|
||||||
txn,
|
txn,
|
||||||
table="users",
|
table="users",
|
||||||
|
@ -2229,7 +2254,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
StoreError(404) if user not found
|
StoreError(404) if user not found
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_update_one_txn(
|
self.db_pool.simple_update_one_txn(
|
||||||
txn,
|
txn,
|
||||||
table="users",
|
table="users",
|
||||||
|
@ -2259,7 +2284,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
A tuple of (token, token id, device id) for each of the deleted tokens
|
A tuple of (token, token id, device id) for each of the deleted tokens
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
|
||||||
keyvalues = {"user_id": user_id}
|
keyvalues = {"user_id": user_id}
|
||||||
if device_id is not None:
|
if device_id is not None:
|
||||||
keyvalues["device_id"] = device_id
|
keyvalues["device_id"] = device_id
|
||||||
|
@ -2301,7 +2326,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
|
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
|
||||||
|
|
||||||
async def delete_access_token(self, access_token: str) -> None:
|
async def delete_access_token(self, access_token: str) -> None:
|
||||||
def f(txn):
|
def f(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_delete_one_txn(
|
self.db_pool.simple_delete_one_txn(
|
||||||
txn, table="access_tokens", keyvalues={"token": access_token}
|
txn, table="access_tokens", keyvalues={"token": access_token}
|
||||||
)
|
)
|
||||||
|
@ -2313,7 +2338,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
await self.db_pool.runInteraction("delete_access_token", f)
|
await self.db_pool.runInteraction("delete_access_token", f)
|
||||||
|
|
||||||
async def delete_refresh_token(self, refresh_token: str) -> None:
|
async def delete_refresh_token(self, refresh_token: str) -> None:
|
||||||
def f(txn):
|
def f(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_delete_one_txn(
|
self.db_pool.simple_delete_one_txn(
|
||||||
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
|
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
|
||||||
)
|
)
|
||||||
|
@ -2353,7 +2378,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Insert everything into a transaction in order to run atomically
|
# Insert everything into a transaction in order to run atomically
|
||||||
def validate_threepid_session_txn(txn):
|
def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]:
|
||||||
row = self.db_pool.simple_select_one_txn(
|
row = self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
table="threepid_validation_session",
|
table="threepid_validation_session",
|
||||||
|
@ -2450,7 +2475,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
longer be valid
|
longer be valid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def start_or_continue_validation_session_txn(txn):
|
def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
|
||||||
# Create or update a validation session
|
# Create or update a validation session
|
||||||
self.db_pool.simple_upsert_txn(
|
self.db_pool.simple_upsert_txn(
|
||||||
txn,
|
txn,
|
||||||
|
|
|
@ -742,7 +742,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
%s;
|
%s;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_if_events_have_relations(txn) -> List[str]:
|
def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
|
||||||
clauses: List[str] = []
|
clauses: List[str] = []
|
||||||
clause, args = make_in_list_sql_clause(
|
clause, args = make_in_list_sql_clause(
|
||||||
txn.database_engine, "relates_to_id", parent_ids
|
txn.database_engine, "relates_to_id", parent_ids
|
||||||
|
|
|
@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
class SignatureWorkerStore(EventsWorkerStore):
|
class SignatureWorkerStore(EventsWorkerStore):
|
||||||
@cached()
|
@cached()
|
||||||
def get_event_reference_hash(self, event_id):
|
def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
|
||||||
# This is a dummy function to allow get_event_reference_hashes
|
# This is a dummy function to allow get_event_reference_hashes
|
||||||
# to use its cache
|
# to use its cache
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -204,7 +204,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
The current state of the room.
|
The current state of the room.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_current_state_ids_txn(txn):
|
def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""SELECT type, state_key, event_id FROM current_state_events
|
"""SELECT type, state_key, event_id FROM current_state_events
|
||||||
WHERE room_id = ?
|
WHERE room_id = ?
|
||||||
|
|
|
@ -36,7 +36,17 @@ what sort order was used:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
@ -732,7 +742,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
A tuple of (stream ordering, topological ordering, event_id)
|
A tuple of (stream ordering, topological ordering, event_id)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _f(txn):
|
def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT stream_ordering, topological_ordering, event_id"
|
"SELECT stream_ordering, topological_ordering, event_id"
|
||||||
" FROM events"
|
" FROM events"
|
||||||
|
@ -742,7 +752,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
" LIMIT 1"
|
" LIMIT 1"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (room_id, stream_ordering))
|
txn.execute(sql, (room_id, stream_ordering))
|
||||||
return txn.fetchone()
|
return cast(Optional[Tuple[int, int, str]], txn.fetchone())
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_room_event_before_stream_ordering", _f
|
"get_room_event_before_stream_ordering", _f
|
||||||
|
@ -839,7 +849,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _set_before_and_after(
|
def _set_before_and_after(
|
||||||
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
|
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
|
||||||
):
|
) -> None:
|
||||||
"""Inserts ordering information to events' internal metadata from
|
"""Inserts ordering information to events' internal metadata from
|
||||||
the DB rows.
|
the DB rows.
|
||||||
|
|
||||||
|
@ -985,7 +995,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
the `current_id`).
|
the `current_id`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_all_new_events_stream_txn(txn):
|
def get_all_new_events_stream_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[int, List[str]]:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT e.stream_ordering, e.event_id"
|
"SELECT e.stream_ordering, e.event_id"
|
||||||
" FROM events AS e"
|
" FROM events AS e"
|
||||||
|
@ -1331,7 +1343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
async def get_id_for_instance(self, instance_name: str) -> int:
|
async def get_id_for_instance(self, instance_name: str) -> int:
|
||||||
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
|
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
|
||||||
|
|
||||||
def _get_id_for_instance_txn(txn):
|
def _get_id_for_instance_txn(txn: LoggingTransaction) -> int:
|
||||||
instance_id = self.db_pool.simple_select_one_onecol_txn(
|
instance_id = self.db_pool.simple_select_one_onecol_txn(
|
||||||
txn,
|
txn,
|
||||||
table="instance_map",
|
table="instance_map",
|
||||||
|
|
|
@ -97,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tag_content(
|
def get_tag_content(
|
||||||
txn: LoggingTransaction, tag_ids
|
txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
|
||||||
) -> List[Tuple[int, Tuple[str, str, str]]]:
|
) -> List[Tuple[int, Tuple[str, str, str]]]:
|
||||||
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
|
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
|
||||||
results = []
|
results = []
|
||||||
|
@ -251,7 +251,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||||
return self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
|
|
||||||
def _update_revision_txn(
|
def _update_revision_txn(
|
||||||
self, txn, user_id: str, room_id: str, next_id: int
|
self, txn: LoggingTransaction, user_id: str, room_id: str, next_id: int
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update the latest revision of the tags for the given user and room.
|
"""Update the latest revision of the tags for the given user and room.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue