Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

matrix-org-hotfixes
Erik Johnston 2023-09-14 16:21:58 +01:00
commit 1e0b96f1a4
43 changed files with 341 additions and 242 deletions

changelog.d/16288.bugfix (new file)

@ -0,0 +1 @@
Fix bug introduced in Synapse 1.49.0 when using dehydrated devices ([MSC2697](https://github.com/matrix-org/matrix-spec-proposals/pull/2697)) and refresh tokens. Contributed by Hanadi.

changelog.d/16301.misc (new file)

@ -0,0 +1 @@
Improve type hints.

changelog.d/16304.doc (new file)

@ -0,0 +1 @@
Link to the Alpine Linux community package for Synapse.

changelog.d/16313.misc (new file)

@ -0,0 +1 @@
Delete device messages asynchronously and in staged batches using the task scheduler.

changelog.d/16314.misc (new file)

@ -0,0 +1 @@
Remove a reference cycle in background processes.

changelog.d/16316.misc (new file)

@ -0,0 +1 @@
Refactor `get_user_by_id`.

changelog.d/16318.misc (new file)

@ -0,0 +1 @@
Speed up the task that deletes to-device messages.


@ -155,6 +155,14 @@ sudo pip uninstall py-bcrypt
sudo pip install py-bcrypt
```
#### Alpine Linux
6543 maintains [Synapse packages for Alpine Linux](https://pkgs.alpinelinux.org/packages?name=synapse&branch=edge) in the community repository. Install with:
```sh
sudo apk add synapse
```
#### Void Linux
Synapse can be found in the void repositories as


@ -268,7 +268,7 @@ class InternalAuth(BaseAuth):
stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
if not stored_user.is_guest:
raise InvalidClientTokenError(
"Guest access token used for regular user"
)


@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth):
user_id = UserID(username, self._hostname)
# First try to find a user from the username claim
user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
if user_info is None:
# If the user does not exist, we should create it on the fly
# TODO: we could use SCIM to provision users ahead of time and listen


@ -27,9 +27,7 @@ from typing import (
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
NoReturn,
Optional,
@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
load_legacy_third_party_event_rules,
)
from synapse.types import ISynapseReactor
from synapse.types import ISynapseReactor, StrCollection
from synapse.util import SYNAPSE_VERSION
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
from synapse.util.daemonize import daemonize_process
@ -278,7 +276,7 @@ def register_start(
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
"""
Start Prometheus metrics server.
"""
@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:
def listen_manhole(
bind_addresses: Collection[str],
bind_addresses: StrCollection,
port: int,
manhole_settings: ManholeConfig,
manhole_globals: dict,
@ -339,7 +337,7 @@ def listen_manhole(
def listen_tcp(
bind_addresses: Collection[str],
bind_addresses: StrCollection,
port: int,
factory: ServerFactory,
reactor: IReactorTCP = reactor,
@ -448,7 +446,7 @@ def listen_http(
def listen_ssl(
bind_addresses: Collection[str],
bind_addresses: StrCollection,
port: int,
factory: ServerFactory,
context_factory: IOpenSSLContextFactory,
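
For reference, `StrCollection` and `StrSequence` are aliases in `synapse.types` that deliberately exclude bare `str` (a `str` is itself a `Sequence[str]`, which invites character-by-character iteration bugs). Roughly, as of this commit:

```python
from typing import AbstractSet, List, Tuple, Union

# Collection[str] that does not include str itself.
StrCollection = Union[Tuple[str, ...], List[str], AbstractSet[str]]
# Sequence[str] that does not include str itself.
StrSequence = Union[Tuple[str, ...], List[str]]
```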


@ -26,7 +26,6 @@ from textwrap import dedent
from typing import (
Any,
ClassVar,
Collection,
Dict,
Iterable,
Iterator,
@ -384,7 +383,7 @@ class RootConfig:
config_classes: List[Type[Config]] = []
def __init__(self, config_files: Collection[str] = ()):
def __init__(self, config_files: StrSequence = ()):
# Capture absolute paths here, so we can reload config after we daemonize.
self.config_files = [os.path.abspath(path) for path in config_files]


@ -25,7 +25,6 @@ from typing import (
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta):
def keys(self) -> Iterable[str]:
return self._dict.keys()
def prev_event_ids(self) -> Sequence[str]:
def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
@ -553,7 +552,7 @@ class FrozenEventV2(EventBase):
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
def prev_event_ids(self) -> Sequence[str]:
def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.


@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
from signedjson.types import SigningKey
@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import EventID, JsonDict
from synapse.types import EventID, JsonDict, StrCollection
from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.stringutils import random_string
@ -103,7 +103,7 @@ class EventBuilder:
async def build(
self,
prev_event_ids: Collection[str],
prev_event_ids: StrCollection,
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
@ -136,7 +136,7 @@ class EventBuilder:
format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)


@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
from typing import Iterable, List, Type, Union, cast
from typing import List, Type, Union, cast
import jsonschema
from pydantic import Field, StrictBool, StrictStr
@ -36,7 +36,7 @@ from synapse.events.utils import (
from synapse.federation.federation_server import server_matches_acl_event
from synapse.http.servlet import validate_json_object
from synapse.rest.models import RequestBodyModel
from synapse.types import EventID, JsonDict, RoomID, UserID
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
class EventValidator:
@ -225,7 +225,7 @@ class EventValidator:
self._ensure_state_event(event)
def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))


@ -102,7 +102,7 @@ class AccountHandler:
"""
status = {"exists": False}
userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
userinfo = await self._main_store.get_user_by_id(user_id.to_string())
if userinfo is not None:
status = {


@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@ -57,38 +57,30 @@ class AdminHandler:
async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
user_info_dict = await self._store.get_user_by_id(user.to_string())
if user_info_dict is None:
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
user.to_string()
)
if user_info is None:
return None
# Restrict returned information to a known set of fields. This prevents additional
# fields added to get_user_by_id from modifying Synapse's external API surface.
user_info_to_return = {
"name",
"admin",
"deactivated",
"locked",
"shadow_banned",
"creation_ts",
"appservice_id",
"consent_server_notice_sent",
"consent_version",
"consent_ts",
"user_type",
"is_guest",
"last_seen_ts",
user_info_dict = {
"name": user.to_string(),
"admin": user_info.is_admin,
"deactivated": user_info.is_deactivated,
"locked": user_info.locked,
"shadow_banned": user_info.is_shadow_banned,
"creation_ts": user_info.creation_ts,
"appservice_id": user_info.appservice_id,
"consent_server_notice_sent": user_info.consent_server_notice_sent,
"consent_version": user_info.consent_version,
"consent_ts": user_info.consent_ts,
"user_type": user_info.user_type,
"is_guest": user_info.is_guest,
}
if self._msc3866_enabled:
# Only include the approved flag if support for MSC3866 is enabled.
user_info_to_return.add("approved")
# Restrict returned keys to a known set.
user_info_dict = {
key: value
for key, value in user_info_dict.items()
if key in user_info_to_return
}
user_info_dict["approved"] = user_info.approved
# Add additional user metadata
profile = await self._store.get_profileinfo(user)
@ -105,6 +97,9 @@ class AdminHandler:
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
user_info_dict["last_seen_ts"] = last_seen_ts
return user_info_dict
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:


@ -388,7 +388,8 @@ class DeviceWorkerHandler:
"Trying handling device list state for partial join: not supported on workers."
)
DEVICE_MSGS_DELETE_BATCH_LIMIT = 100
DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000
DEVICE_MSGS_DELETE_SLEEP_MS = 1000
async def _delete_device_messages(
self,
@ -400,6 +401,8 @@ class DeviceWorkerHandler:
device_id = task.params["device_id"]
up_to_stream_id = task.params["up_to_stream_id"]
# Delete the messages in batches to avoid too much DB load.
while True:
res = await self.store.delete_messages_for_device(
user_id=user_id,
device_id=device_id,
@ -409,10 +412,8 @@ class DeviceWorkerHandler:
if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
return TaskStatus.COMPLETE, None, None
else:
# There are probably still device messages to be deleted; keep the task active and it will be run
# again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running).
return TaskStatus.ACTIVE, None, None
await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0)
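
The loop above replaces the old re-queueing approach with in-task batching. A minimal sketch of the general pattern, with hypothetical helper names:

```python
from typing import Awaitable, Callable

async def delete_in_batches(
    delete_batch: Callable[[int], Awaitable[int]],  # returns rows deleted
    batch_size: int,
    sleep_fn: Callable[[float], Awaitable[None]],
    sleep_s: float,
) -> None:
    """Delete rows in bounded batches, sleeping between rounds so a large
    backlog doesn't hog the database."""
    while True:
        deleted = await delete_batch(batch_size)
        if deleted < batch_size:
            return  # short batch: nothing (or little) left to delete
        await sleep_fn(sleep_s)
```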
class DeviceHandler(DeviceWorkerHandler):
@ -758,12 +759,13 @@ class DeviceHandler(DeviceWorkerHandler):
# If the dehydrated device was successfully deleted (the device ID
# matched the stored dehydrated device), then modify the access
# token to use the dehydrated device's ID and copy the old device
# display name to the dehydrated device, and destroy the old device
# ID
# token and refresh token to use the dehydrated device's ID and
# copy the old device display name to the dehydrated device,
# and destroy the old device ID
old_device_id = await self.store.set_device_for_access_token(
access_token, device_id
)
await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id)
old_device = await self.store.get_device(user_id, old_device_id)
if old_device is None:
raise errors.NotFoundError()


@ -828,13 +828,13 @@ class EventCreationHandler:
u = await self.store.get_user_by_id(user_id)
assert u is not None
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
# support and bot users are not required to consent
return
if u["appservice_id"] is not None:
if u.appservice_id is not None:
# users registered by an appservice are exempt
return
if u["consent_version"] == self.config.consent.user_consent_version:
if u.consent_version == self.config.consent.user_consent_version:
return
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)


@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor
from synapse.types import ISynapseReactor, StrSequence
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
# the value actually has to be a List, but List is invariant so we can't specify that
# the entries can either be Lists or bytes.
RawHeaderValue = Union[
List[str],
StrSequence,
List[bytes],
List[Union[str, bytes]],
Tuple[str, ...],
Tuple[bytes, ...],
Tuple[Union[str, bytes], ...],
]
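
The invariance problem the comment above describes, illustrated with hypothetical names: a `List[str]` is not accepted where a `List[Union[str, bytes]]` is expected, hence the union of concrete list and tuple types.

```python
from typing import List, Union

def set_header(values: List[Union[str, bytes]]) -> None:
    ...

names: List[str] = ["text/html"]
# set_header(names)  # rejected by mypy: List is invariant, so a List[str]
# is not a List[Union[str, bytes]]. Enumerating the concrete element types,
# as RawHeaderValue does above, sidesteps this.
set_header(["text/html", b"utf-8"])  # fine: inferred as List[Union[str, bytes]]
```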


@ -18,7 +18,6 @@ import logging
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Iterable,
List,
Mapping,
Optional,
@ -38,7 +37,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http import redact_uri
from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID
from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
from synapse.util import json_decoder
if TYPE_CHECKING:
@ -340,7 +339,7 @@ def parse_string(
name: str,
default: str,
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@ -352,7 +351,7 @@ def parse_string(
name: str,
*,
required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@ -365,7 +364,7 @@ def parse_string(
*,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@ -376,7 +375,7 @@ def parse_string(
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
@ -485,7 +484,7 @@ def parse_enum(
def _parse_string_value(
value: bytes,
allowed_values: Optional[Iterable[str]],
allowed_values: Optional[StrCollection],
name: str,
encoding: str,
) -> str:
@ -511,7 +510,7 @@ def parse_strings_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
@ -523,7 +522,7 @@ def parse_strings_from_args(
name: str,
default: List[str],
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
@ -535,7 +534,7 @@ def parse_strings_from_args(
name: str,
*,
required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
@ -548,7 +547,7 @@ def parse_strings_from_args(
default: Optional[List[str]] = None,
*,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
@ -559,7 +558,7 @@ def parse_strings_from_args(
name: str,
default: Optional[List[str]] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
"""
@ -610,7 +609,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@ -623,7 +622,7 @@ def parse_string_from_args(
default: Optional[str] = None,
*,
required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@ -635,7 +634,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@ -646,7 +645,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request(
return validate_json_object(content, model_type)
def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
absent = []
for k in required:
if k not in body:


@ -25,7 +25,6 @@ from typing import (
Iterable,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
from synapse.metrics._types import Collector
from synapse.types import StrSequence
from synapse.util import SYNAPSE_VERSION
logger = logging.getLogger(__name__)
@ -81,7 +81,7 @@ class LaterGauge(Collector):
name: str
desc: str
labels: Optional[Sequence[str]] = attr.ib(hash=False)
labels: Optional[StrSequence] = attr.ib(hash=False)
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
caller: Callable[
@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
self,
name: str,
desc: str,
labels: Sequence[str],
sub_metrics: Sequence[str],
labels: StrSequence,
sub_metrics: StrSequence,
):
self.name = name
self.desc = desc


@ -322,13 +322,21 @@ class BackgroundProcessLoggingContext(LoggingContext):
if instance_id is None:
instance_id = id(self)
super().__init__("%s-%s" % (name, instance_id))
self._proc = _BackgroundProcess(name, self)
self._proc: Optional[_BackgroundProcess] = _BackgroundProcess(name, self)
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""Log context has started running (again)."""
super().start(rusage)
if self._proc is None:
logger.error(
"Background process re-entered without a proc: %s",
self.name,
stack_info=True,
)
return
# We've become active again so we make sure we're in the list of active
# procs. (Note that "start" here means we've become active, as opposed
# to starting for the first time.)
@ -345,6 +353,14 @@ class BackgroundProcessLoggingContext(LoggingContext):
super().__exit__(type, value, traceback)
if self._proc is None:
logger.error(
"Background process exited without a proc: %s",
self.name,
stack_info=True,
)
return
# The background process has finished. We explicitly remove and manually
# update the metrics here so that if nothing is scraping metrics the set
# doesn't infinitely grow.
@ -352,3 +368,6 @@ class BackgroundProcessLoggingContext(LoggingContext):
_background_processes_active_since_last_scrape.discard(self._proc)
self._proc.update_metrics()
# Set proc to None to break the reference cycle.
self._proc = None
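
For background, this is the usual shape of such a cycle and of the fix; a hypothetical minimal example:

```python
from typing import Optional

class _Process:
    def __init__(self, ctx: "Context") -> None:
        self.ctx = ctx  # back-reference: _Process -> Context

class Context:
    def __init__(self) -> None:
        # Context -> _Process -> Context forms a reference cycle.
        self.proc: Optional[_Process] = _Process(self)

    def finish(self) -> None:
        # Dropping one edge breaks the cycle, so plain reference counting can
        # reclaim both objects promptly instead of waiting for the cyclic GC.
        self.proc = None
```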


@ -572,7 +572,7 @@ class ModuleApi:
Returns:
UserInfo object if a user was found, otherwise None
"""
return await self._store.get_userinfo_by_id(user_id)
return await self._store.get_user_by_id(user_id)
async def get_user_by_req(
self,
@ -1878,7 +1878,7 @@ class AccountDataManager:
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")
# Ensure the user exists, so we don't just write to users that aren't there.
if await self._store.get_userinfo_by_id(user_id) is None:
if await self._store.get_user_by_id(user_id) is None:
raise ValueError(f"User {user_id} does not exist on this server.")
await self._handler.add_account_data_for_user(user_id, data_type, new_data)


@ -104,7 +104,7 @@ class _NotifierUserStream:
def __init__(
self,
user_id: str,
rooms: Collection[str],
rooms: StrCollection,
current_token: StreamToken,
time_now_ms: int,
):
@ -457,7 +457,7 @@ class Notifier:
stream_key: str,
new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[Collection[str]] = None,
rooms: Optional[StrCollection] = None,
) -> None:
"""Used to inform listeners that something has happened event wise.
@ -529,7 +529,7 @@ class Notifier:
user_id: str,
timeout: int,
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
room_ids: Optional[Collection[str]] = None,
room_ids: Optional[StrCollection] = None,
from_token: StreamToken = StreamToken.START,
) -> T:
"""Wait until the callback returns a non empty response or the


@ -672,14 +672,12 @@ class ReplicationCommandHandler:
cmd.instance_name, cmd.lock_name, cmd.lock_key
)
async def on_NEW_ACTIVE_TASK(
def on_NEW_ACTIVE_TASK(
self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
) -> None:
"""Called when get a new NEW_ACTIVE_TASK command."""
if self._task_scheduler:
task = await self._task_scheduler.get_task(cmd.data)
if task:
await self._task_scheduler._launch_task(task)
self._task_scheduler.launch_task_by_id(cmd.data)
def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection."""


@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar,
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.types import JsonDict
from synapse.types import JsonDict, StrCollection
logger = logging.getLogger(__name__)
def client_patterns(
path_regex: str,
releases: Iterable[str] = ("r0", "v3"),
releases: StrCollection = ("r0", "v3"),
unstable: bool = True,
v1: bool = False,
) -> Iterable[Pattern]:


@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource):
if u is None:
raise NotFoundError("Unknown user")
has_consented = u["consent_version"] == version
has_consented = u.consent_version == version
userhmac = userhmac_bytes.decode("ascii")
try:


@ -79,15 +79,15 @@ class ConsentServerNotices:
if u is None:
return
if u["is_guest"] and not self._send_to_guests:
if u.is_guest and not self._send_to_guests:
# don't send to guests
return
if u["consent_version"] == self._current_consent_version:
if u.consent_version == self._current_consent_version:
# user has already consented
return
if u["consent_server_notice_sent"] == self._current_consent_version:
if u.consent_server_notice_sent == self._current_consent_version:
# we've already sent a notice to the user
return


@ -20,7 +20,6 @@ from typing import (
Any,
Awaitable,
Callable,
Collection,
DefaultDict,
Dict,
FrozenSet,
@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import StateMap
from synapse.types import StateMap, StrCollection
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@ -197,7 +196,7 @@ class StateHandler:
async def compute_state_after_events(
self,
room_id: str,
event_ids: Collection[str],
event_ids: StrCollection,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
@ -231,7 +230,7 @@ class StateHandler:
return await ret.get_state(self._state_storage_controller, state_filter)
async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: Collection[str]
self, room_id: str, latest_event_ids: StrCollection
) -> Set[str]:
"""
Get the user IDs of those currently in a room.
@ -256,7 +255,7 @@ class StateHandler:
return await self.store.get_joined_user_ids_from_state(room_id, state)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
self, room_id: str, event_ids: StrCollection
) -> FrozenSet[str]:
"""Get the hosts that were in a room at the given event ids
@ -470,7 +469,7 @@ class StateHandler:
@trace
@measure_func()
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@ -882,7 +881,7 @@ class StateResolutionStore:
store: "DataStore"
def get_events(
self, event_ids: Collection[str], allow_rejected: bool = False
self, event_ids: StrCollection, allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database


@ -17,7 +17,6 @@ import logging
from typing import (
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.types import MutableStateMap, StateMap, StrCollection
logger = logging.getLogger(__name__)
@ -45,7 +44,7 @@ async def resolve_events_with_store(
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:


@ -19,7 +19,6 @@ from typing import (
Any,
Awaitable,
Callable,
Collection,
Dict,
Generator,
Iterable,
@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.types import MutableStateMap, StateMap, StrCollection
logger = logging.getLogger(__name__)
@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
# TestStateResolutionStore in tests.
def get_events(
self, event_ids: Collection[str], allow_rejected: bool = False
self, event_ids: StrCollection, allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
...
@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
auth_difference_unpersisted_part: Collection[str] = union - intersection
auth_difference_unpersisted_part: StrCollection = union - intersection
else:
auth_difference_unpersisted_part = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets]


@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
}
return list(results.values())
async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
"""Get the last seen timestamp for a user, if we have it."""
return await self.db_pool.simple_select_one_onecol(
table="user_ips",
keyvalues={"user_id": user_id},
retcol="MAX(last_seen)",
allow_none=True,
desc="get_last_seen_for_user_id",
)
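
`MAX(...)` over zero rows yields NULL, which is why the query passes `allow_none=True` and the method returns `Optional[int]`. A quick self-contained demonstration:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_ips (user_id TEXT, last_seen BIGINT)")
row = conn.execute(
    "SELECT MAX(last_seen) FROM user_ips WHERE user_id = ?",
    ("@alice:example.com",),
).fetchone()
print(row[0])  # None: the aggregate returns one row even with no matches
```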


@ -47,7 +47,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict, StrCollection
from synapse.types import JsonDict, StrCollection, StrSequence
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@cached(max_entries=5000, iterable=True)
async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},


@ -16,7 +16,7 @@
import logging
import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import attr
@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# We could technically use simple_select_one here, but it would not perform
@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
txn.execute(
"""
SELECT
name, password_hash, is_guest, admin, consent_version, consent_ts,
name, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved,
COALESCE(locked, FALSE) AS locked, last_seen_ts
COALESCE(locked, FALSE) AS locked
FROM users
LEFT JOIN (
SELECT user_id, MAX(last_seen) AS last_seen_ts
FROM user_ips GROUP BY user_id
) ls ON users.name = ls.user_id
WHERE name = ?
""",
(user_id,),
@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_id",
func=get_user_by_id_txn,
)
if row is not None:
# If we're using SQLite our boolean values will be integers. Because we
# present some of this data as is to e.g. server admins via REST APIs, we
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
boolean_columns = [
"admin",
"deactivated",
"shadow_banned",
"approved",
"locked",
]
for column in boolean_columns:
row[column] = bool(row[column])
return row
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get a UserInfo object for a user by user ID.
Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
this method should be cached.
Args:
user_id: The user to fetch user info for.
Returns:
`UserInfo` object if user found, otherwise `None`.
"""
user_data = await self.get_user_by_id(user_id)
if not user_data:
if row is None:
return None
return UserInfo(
appservice_id=user_data["appservice_id"],
consent_server_notice_sent=user_data["consent_server_notice_sent"],
consent_version=user_data["consent_version"],
creation_ts=user_data["creation_ts"],
is_admin=bool(user_data["admin"]),
is_deactivated=bool(user_data["deactivated"]),
is_guest=bool(user_data["is_guest"]),
is_shadow_banned=bool(user_data["shadow_banned"]),
user_id=UserID.from_string(user_data["name"]),
user_type=user_data["user_type"],
last_seen_ts=user_data["last_seen_ts"],
appservice_id=row["appservice_id"],
consent_server_notice_sent=row["consent_server_notice_sent"],
consent_version=row["consent_version"],
consent_ts=row["consent_ts"],
creation_ts=row["creation_ts"],
is_admin=bool(row["admin"]),
is_deactivated=bool(row["deactivated"]),
is_guest=bool(row["is_guest"]),
is_shadow_banned=bool(row["shadow_banned"]),
user_id=UserID.from_string(row["name"]),
user_type=row["user_type"],
approved=bool(row["approved"]),
locked=bool(row["locked"]),
)
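
Wrapping each flag in `bool(...)` matters because SQLite has no native boolean type and hands back integers; for example:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (name TEXT, admin BOOLEAN)")
conn.execute("INSERT INTO users VALUES ('@a:example.com', 1)")
value = conn.execute("SELECT admin FROM users").fetchone()[0]
print(type(value))  # <class 'int'> — hence bool(row["admin"]) above
```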
async def is_trial_user(self, user_id: str) -> bool:
@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
now = self._clock.time_msec()
days = self.config.server.mau_appservice_trial_days.get(
info["appservice_id"], self.config.server.mau_trial_days
info.appservice_id, self.config.server.mau_trial_days
)
trial_duration_ms = days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
return is_trial
@cached()
@ -2312,6 +2280,26 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return next_id
async def set_device_for_refresh_token(
self, user_id: str, old_device_id: str, device_id: str
) -> None:
"""Moves refresh tokens from old device to current device
Args:
user_id: The user of the devices.
old_device_id: The old device.
device_id: The new device ID.
Returns:
None
"""
await self.db_pool.simple_update(
"refresh_tokens",
keyvalues={"user_id": user_id, "device_id": old_device_id},
updatevalues={"device_id": device_id},
desc="set_device_for_refresh_token",
)
def _set_device_for_access_token_txn(
self, txn: LoggingTransaction, token: str, device_id: str
) -> str:


@ -53,6 +53,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
resource_id: Optional[str] = None,
statuses: Optional[List[TaskStatus]] = None,
max_timestamp: Optional[int] = None,
limit: Optional[int] = None,
) -> List[ScheduledTask]:
"""Get a list of scheduled tasks from the DB.
@ -62,6 +63,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
statuses: Limit the returned tasks to the specific statuses
max_timestamp: Limit the returned tasks to the ones that have
a timestamp lower than the specified one
limit: Only return up to `limit` rows, if set.
Returns: a list of `ScheduledTask`, ordered by increasing timestamps
"""
@ -94,6 +96,10 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
sql = sql + " ORDER BY timestamp"
if limit is not None:
sql += " LIMIT ?"
args.append(limit)
txn.execute(sql, args)
return self.db_pool.cursor_to_dict(txn)


@ -0,0 +1,16 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX IF NOT EXISTS scheduled_tasks_timestamp ON scheduled_tasks(timestamp);
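
This index matches the scheduler's fetch pattern (a `max_timestamp` filter plus `ORDER BY timestamp`), so rows come back pre-sorted. One way to check, using SQLite's query planner on a toy table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE scheduled_tasks (id TEXT, timestamp BIGINT)")
conn.execute(
    "CREATE INDEX scheduled_tasks_timestamp ON scheduled_tasks(timestamp)"
)
plan = conn.execute(
    "EXPLAIN QUERY PLAN SELECT id FROM scheduled_tasks "
    "WHERE timestamp < ? ORDER BY timestamp",
    (0,),
).fetchall()
print(plan)  # SEARCH ... USING INDEX scheduled_tasks_timestamp; no sort step
```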


@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key(
@attr.s(auto_attribs=True, frozen=True, slots=True)
class UserInfo:
"""Holds information about a user. Result of get_userinfo_by_id.
"""Holds information about a user. Result of get_user_by_id.
Attributes:
user_id: ID of the user.
appservice_id: Application service ID that created this user.
consent_server_notice_sent: Version of policy documents the user has been sent.
consent_version: Version of policy documents the user has consented to.
consent_ts: Time the user consented.
creation_ts: Creation timestamp of the user.
is_admin: True if the user is an admin.
is_deactivated: True if the user has been deactivated.
is_guest: True if the user is a guest user.
is_shadow_banned: True if the user has been shadow-banned.
user_type: User type (None for a normal user; 'support' and 'bot' are the other options).
last_seen_ts: Last activity timestamp of the user.
approved: If the user has been "approved" to register on the server.
locked: Whether the user's account has been locked.
"""
user_id: UserID
appservice_id: Optional[int]
consent_server_notice_sent: Optional[str]
consent_version: Optional[str]
consent_ts: Optional[int]
user_type: Optional[str]
creation_ts: int
is_admin: bool
is_deactivated: bool
is_guest: bool
is_shadow_banned: bool
last_seen_ts: Optional[int]
approved: bool
locked: bool
class UserProfile(TypedDict):


@ -15,12 +15,14 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple
from prometheus_client import Gauge
from twisted.python.failure import Failure
from synapse.logging.context import nested_logging_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.types import JsonMapping, ScheduledTask, TaskStatus
from synapse.util.stringutils import random_string
@ -30,12 +32,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
running_tasks_gauge = Gauge(
"synapse_scheduler_running_tasks",
"The number of concurrent running tasks handled by the TaskScheduler",
)
class TaskScheduler:
"""
This is a simple task scheduler aimed at resumable tasks: usually we use `run_in_background`
@ -70,6 +66,8 @@ class TaskScheduler:
# Precision of the scheduler, evaluation of tasks to run will only happen
# every `SCHEDULE_INTERVAL_MS` ms
SCHEDULE_INTERVAL_MS = 1 * 60 * 1000  # 1 min
# How often to clean up old tasks.
CLEANUP_INTERVAL_MS = 30 * 60 * 1000
# Time before a complete or failed task is deleted from the DB
KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000 # 1 week
# Maximum number of tasks that can run at the same time
@ -92,12 +90,24 @@ class TaskScheduler:
] = {}
self._run_background_tasks = hs.config.worker.run_background_tasks
# Flag to make sure we only try and launch new tasks once at a time.
self._launching_new_tasks = False
if self._run_background_tasks:
self._clock.looping_call(
run_as_background_process,
self._launch_scheduled_tasks,
TaskScheduler.SCHEDULE_INTERVAL_MS,
"handle_scheduled_tasks",
self._handle_scheduled_tasks,
)
self._clock.looping_call(
self._clean_scheduled_tasks,
TaskScheduler.SCHEDULE_INTERVAL_MS,
)
LaterGauge(
"synapse_scheduler_running_tasks",
"The number of concurrent running tasks handled by the TaskScheduler",
labels=None,
caller=lambda: len(self._running_tasks),
)
def register_action(
@ -234,6 +244,7 @@ class TaskScheduler:
resource_id: Optional[str] = None,
statuses: Optional[List[TaskStatus]] = None,
max_timestamp: Optional[int] = None,
limit: Optional[int] = None,
) -> List[ScheduledTask]:
"""Get a list of tasks. Returns all the tasks if no args is provided.
@ -247,6 +258,7 @@ class TaskScheduler:
statuses: Limit the returned tasks to the specific statuses
max_timestamp: Limit the returned tasks to the ones that have
a timestamp lower than the specified one
limit: Only return up to `limit` rows, if set.
Returns
A list of `ScheduledTask`, ordered by increasing timestamps
@ -256,6 +268,7 @@ class TaskScheduler:
resource_id=resource_id,
statuses=statuses,
max_timestamp=max_timestamp,
limit=limit,
)
async def delete_task(self, id: str) -> None:
@ -273,33 +286,57 @@ class TaskScheduler:
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
await self._store.delete_scheduled_task(id)
async def _handle_scheduled_tasks(self) -> None:
"""Main loop taking care of launching tasks and cleaning up old ones."""
await self._launch_scheduled_tasks()
await self._clean_scheduled_tasks()
def launch_task_by_id(self, id: str) -> None:
"""Try launching the task with the given ID."""
# Don't bother trying to launch new tasks if we're already at capacity.
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
return
run_as_background_process("launch_task_by_id", self._launch_task_by_id, id)
async def _launch_task_by_id(self, id: str) -> None:
"""Helper async function for `launch_task_by_id`."""
task = await self.get_task(id)
if task:
await self._launch_task(task)
@wrap_as_background_process("launch_scheduled_tasks")
async def _launch_scheduled_tasks(self) -> None:
"""Retrieve and launch scheduled tasks that should be running at that time."""
for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]):
# Don't bother trying to launch new tasks if we're already at capacity.
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
return
if self._launching_new_tasks:
return
self._launching_new_tasks = True
try:
for task in await self.get_tasks(
statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS
):
await self._launch_task(task)
for task in await self.get_tasks(
statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec()
statuses=[TaskStatus.SCHEDULED],
max_timestamp=self._clock.time_msec(),
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
):
await self._launch_task(task)
running_tasks_gauge.set(len(self._running_tasks))
finally:
self._launching_new_tasks = False
@wrap_as_background_process("clean_scheduled_tasks")
async def _clean_scheduled_tasks(self) -> None:
"""Clean old complete or failed jobs to avoid clutter the DB."""
now = self._clock.time_msec()
for task in await self._store.get_scheduled_tasks(
statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE]
statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE],
max_timestamp=now - TaskScheduler.KEEP_TASKS_FOR_MS,
):
# FAILED and COMPLETE tasks should never be running
assert task.id not in self._running_tasks
if (
self._clock.time_msec()
> task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS
):
await self._store.delete_scheduled_task(task.id)
async def _launch_task(self, task: ScheduledTask) -> None:
@ -339,6 +376,9 @@ class TaskScheduler:
)
self._running_tasks.remove(task.id)
# Try launch a new task since we've finished with this one.
self._clock.call_later(1, self._launch_scheduled_tasks)
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
return
@ -355,4 +395,4 @@ class TaskScheduler:
self._running_tasks.add(task.id)
await self.update_task(task.id, status=TaskStatus.ACTIVE)
run_as_background_process(task.action, wrapper)
run_as_background_process(f"task-{task.action}", wrapper)


@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
from synapse.logging.opentracing import trace
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util import Clock
@ -150,12 +150,12 @@ async def filter_events_for_client(
async def filter_event_for_clients_with_state(
store: DataStore,
user_ids: Collection[str],
user_ids: StrCollection,
event: EventBase,
context: EventContext,
is_peeking: bool = False,
filter_send_to_client: bool = True,
) -> Collection[str]:
) -> StrCollection:
"""
Checks to see if an event is visible to the users in the list at the time of
the event.


@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
class FakeUserInfo:
is_guest = False
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
class FakeUserInfo:
is_guest = True
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)
user_id = "@baldrick:matrix.org"


@ -461,6 +461,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.message_handler = hs.get_device_message_handler()
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.store = hs.get_datastores().main
return hs
@ -487,11 +488,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
# Create a new login for the user and dehydrate the device
device_id, access_token, _expiration_time, _refresh_token = self.get_success(
device_id, access_token, _expiration_time, refresh_token = self.get_success(
self.registration.register_device(
user_id=user_id,
device_id=None,
initial_display_name="new device",
should_issue_refresh_token=True,
)
)
@ -522,6 +524,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(user_info.device_id, retrieved_device_id)
# make sure the user device has the refresh token
assert refresh_token is not None
self.get_success(
self.auth_handler.refresh_token(refresh_token, 5 * 60 * 1000, 5 * 60 * 1000)
)
# make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))


@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, UserID, UserInfo
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEqual(
{
UserInfo(
# TODO(paul): Surely this field should be 'user_id', not 'name'
"name": self.user_id,
"password_hash": self.pwhash,
"admin": 0,
"is_guest": 0,
"consent_version": None,
"consent_ts": None,
"consent_server_notice_sent": None,
"appservice_id": None,
"creation_ts": 0,
"user_type": None,
"deactivated": 0,
"locked": 0,
"shadow_banned": 0,
"approved": 1,
"last_seen_ts": None,
},
user_id=UserID.from_string(self.user_id),
is_admin=False,
is_guest=False,
consent_server_notice_sent=None,
consent_ts=None,
consent_version=None,
appservice_id=None,
creation_ts=0,
user_type=None,
is_deactivated=False,
locked=False,
is_shadow_banned=False,
approved=True,
),
(self.get_success(self.store.get_user_by_id(self.user_id))),
)
@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user
self.assertEqual(user["consent_version"], "1")
self.assertGreater(user["consent_ts"], before_consent)
self.assertLess(user["consent_ts"], self.clock.time_msec())
self.assertEqual(user.consent_version, "1")
self.assertIsNotNone(user.consent_ts)
assert user.consent_ts is not None
self.assertGreater(user.consent_ts, before_consent)
self.assertLess(user.consent_ts, self.clock.time_msec())
def test_add_tokens(self) -> None:
self.get_success(self.store.register_user(self.user_id, self.pwhash))
@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
self.assertTrue(user["approved"])
self.assertTrue(user.approved)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)
@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
self.assertFalse(user["approved"])
self.assertFalse(user.approved)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertFalse(approved)
@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
self.assertIsNotNone(user)
assert user is not None
self.assertEqual(user["approved"], 1)
self.assertEqual(user.approved, 1)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)