Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

matrix-org-hotfixes
Erik Johnston 2023-09-14 16:21:58 +01:00
commit 1e0b96f1a4
43 changed files with 341 additions and 242 deletions

changelog.d/16288.bugfix (new file)

@@ -0,0 +1 @@
+Fix bug introduced in Synapse 1.49.0 when using dehydrated devices ([MSC2697](https://github.com/matrix-org/matrix-spec-proposals/pull/2697)) and refresh tokens. Contributed by Hanadi.

changelog.d/16301.misc (new file)

@@ -0,0 +1 @@
+Improve type hints.

changelog.d/16304.doc (new file)

@@ -0,0 +1 @@
+Link to the Alpine Linux community package for Synapse.

changelog.d/16313.misc (new file)

@@ -0,0 +1 @@
+Delete device messages asynchronously and in staged batches using the task scheduler.

changelog.d/16314.misc (new file)

@@ -0,0 +1 @@
+Remove a reference cycle in background processes.

changelog.d/16316.misc (new file)

@@ -0,0 +1 @@
+Refactor `get_user_by_id`.

changelog.d/16318.misc (new file)

@@ -0,0 +1 @@
+Speed up task to delete to-device messages.

@@ -155,6 +155,14 @@ sudo pip uninstall py-bcrypt
 sudo pip install py-bcrypt
 ```

+#### Alpine Linux
+
+6543 maintains [Synapse packages for Alpine Linux](https://pkgs.alpinelinux.org/packages?name=synapse&branch=edge) in the community repository. Install with:
+
+```sh
+sudo apk add synapse
+```
+
 #### Void Linux

 Synapse can be found in the void repositories as

@@ -268,7 +268,7 @@ class InternalAuth(BaseAuth):
         stored_user = await self.store.get_user_by_id(user_id)
         if not stored_user:
             raise InvalidClientTokenError("Unknown user_id %s" % user_id)
-        if not stored_user["is_guest"]:
+        if not stored_user.is_guest:
             raise InvalidClientTokenError(
                 "Guest access token used for regular user"
             )

@@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth):
         user_id = UserID(username, self._hostname)

         # First try to find a user from the username claim
-        user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
+        user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
         if user_info is None:
             # If the user does not exist, we should create it on the fly
             # TODO: we could use SCIM to provision users ahead of time and listen

@@ -27,9 +27,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     Dict,
-    Iterable,
     List,
     NoReturn,
     Optional,
@@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_
 from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
     load_legacy_third_party_event_rules,
 )
-from synapse.types import ISynapseReactor
+from synapse.types import ISynapseReactor, StrCollection
 from synapse.util import SYNAPSE_VERSION
 from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
 from synapse.util.daemonize import daemonize_process
@@ -278,7 +276,7 @@ def register_start(
     reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))


-def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
+def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
     """
     Start Prometheus metrics server.
     """
@@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:


 def listen_manhole(
-    bind_addresses: Collection[str],
+    bind_addresses: StrCollection,
     port: int,
     manhole_settings: ManholeConfig,
     manhole_globals: dict,
@@ -339,7 +337,7 @@ def listen_manhole(


 def listen_tcp(
-    bind_addresses: Collection[str],
+    bind_addresses: StrCollection,
     port: int,
     factory: ServerFactory,
     reactor: IReactorTCP = reactor,
@@ -448,7 +446,7 @@ def listen_http(


 def listen_ssl(
-    bind_addresses: Collection[str],
+    bind_addresses: StrCollection,
     port: int,
     factory: ServerFactory,
     context_factory: IOpenSSLContextFactory,
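
Most of the type-hint churn in this merge replaces `Collection[str]`/`Iterable[str]`/`Sequence[str]` with Synapse's `StrCollection`/`StrSequence` aliases. A minimal sketch of what those aliases look like (the names match `synapse/types`; treat the exact union members as an approximation):

```python
from typing import AbstractSet, List, Tuple, Union

# A bare str is itself a Sequence[str] (of its characters), so annotating a
# parameter as Collection[str] silently accepts a single string. Spelling out
# the concrete container types rules that out at type-check time.
StrCollection = Union[Tuple[str, ...], List[str], AbstractSet[str]]
StrSequence = Union[Tuple[str, ...], List[str]]
```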

@@ -26,7 +26,6 @@ from textwrap import dedent
 from typing import (
     Any,
     ClassVar,
-    Collection,
     Dict,
     Iterable,
     Iterator,
@@ -384,7 +383,7 @@ class RootConfig:

     config_classes: List[Type[Config]] = []

-    def __init__(self, config_files: Collection[str] = ()):
+    def __init__(self, config_files: StrSequence = ()):
         # Capture absolute paths here, so we can reload config after we daemonize.
         self.config_files = [os.path.abspath(path) for path in config_files]

@@ -25,7 +25,6 @@ from typing import (
     Iterable,
     List,
     Optional,
-    Sequence,
     Tuple,
     Type,
     TypeVar,
@@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta):
     def keys(self) -> Iterable[str]:
         return self._dict.keys()

-    def prev_event_ids(self) -> Sequence[str]:
+    def prev_event_ids(self) -> List[str]:
         """Returns the list of prev event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
@@ -553,7 +552,7 @@ class FrozenEventV2(EventBase):
             self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
         return self._event_id

-    def prev_event_ids(self) -> Sequence[str]:
+    def prev_event_ids(self) -> List[str]:
         """Returns the list of prev event IDs. The order matches the order
         specified in the event, though there is no meaning to it.

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import attr
 from signedjson.types import SigningKey
@@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
 from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
-from synapse.types import EventID, JsonDict
+from synapse.types import EventID, JsonDict, StrCollection
 from synapse.types.state import StateFilter
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
@@ -103,7 +103,7 @@ class EventBuilder:
     async def build(
         self,
-        prev_event_ids: Collection[str],
+        prev_event_ids: StrCollection,
         auth_event_ids: Optional[List[str]],
         depth: Optional[int] = None,
     ) -> EventBase:
@@ -136,7 +136,7 @@ class EventBuilder:
         format_version = self.room_version.event_format
         # The types of auth/prev events changes between event versions.
-        prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
+        prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
         auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
         if format_version == EventFormatVersions.ROOM_V1_V2:
             auth_events = await self._store.add_event_hashes(auth_event_ids)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections.abc
-from typing import Iterable, List, Type, Union, cast
+from typing import List, Type, Union, cast

 import jsonschema
 from pydantic import Field, StrictBool, StrictStr
@@ -36,7 +36,7 @@ from synapse.events.utils import (
 from synapse.federation.federation_server import server_matches_acl_event
 from synapse.http.servlet import validate_json_object
 from synapse.rest.models import RequestBodyModel
-from synapse.types import EventID, JsonDict, RoomID, UserID
+from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID


 class EventValidator:
@@ -225,7 +225,7 @@ class EventValidator:
         self._ensure_state_event(event)

-    def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
+    def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
         for s in keys:
             if s not in d:
                 raise SynapseError(400, "'%s' not in content" % (s,))

@@ -102,7 +102,7 @@ class AccountHandler:
         """
         status = {"exists": False}

-        userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
+        userinfo = await self._main_store.get_user_by_id(user_id.to_string())
         if userinfo is not None:
             status = {

@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set

 from synapse.api.constants import Direction, Membership
 from synapse.events import EventBase
-from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
+from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
 from synapse.visibility import filter_events_for_client

 if TYPE_CHECKING:
@@ -57,38 +57,30 @@ class AdminHandler:
     async def get_user(self, user: UserID) -> Optional[JsonDict]:
         """Function to get user details"""
-        user_info_dict = await self._store.get_user_by_id(user.to_string())
-        if user_info_dict is None:
+        user_info: Optional[UserInfo] = await self._store.get_user_by_id(
+            user.to_string()
+        )
+        if user_info is None:
             return None

-        # Restrict returned information to a known set of fields. This prevents additional
-        # fields added to get_user_by_id from modifying Synapse's external API surface.
-        user_info_to_return = {
-            "name",
-            "admin",
-            "deactivated",
-            "locked",
-            "shadow_banned",
-            "creation_ts",
-            "appservice_id",
-            "consent_server_notice_sent",
-            "consent_version",
-            "consent_ts",
-            "user_type",
-            "is_guest",
-            "last_seen_ts",
-        }
+        user_info_dict = {
+            "name": user.to_string(),
+            "admin": user_info.is_admin,
+            "deactivated": user_info.is_deactivated,
+            "locked": user_info.locked,
+            "shadow_banned": user_info.is_shadow_banned,
+            "creation_ts": user_info.creation_ts,
+            "appservice_id": user_info.appservice_id,
+            "consent_server_notice_sent": user_info.consent_server_notice_sent,
+            "consent_version": user_info.consent_version,
+            "consent_ts": user_info.consent_ts,
+            "user_type": user_info.user_type,
+            "is_guest": user_info.is_guest,
+        }

         if self._msc3866_enabled:
             # Only include the approved flag if support for MSC3866 is enabled.
-            user_info_to_return.add("approved")
-
-        # Restrict returned keys to a known set.
-        user_info_dict = {
-            key: value
-            for key, value in user_info_dict.items()
-            if key in user_info_to_return
-        }
+            user_info_dict["approved"] = user_info.approved

         # Add additional user metadata
         profile = await self._store.get_profileinfo(user)
@@ -105,6 +97,9 @@ class AdminHandler:
         user_info_dict["external_ids"] = external_ids
         user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())

+        last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
+        user_info_dict["last_seen_ts"] = last_seen_ts
+
         return user_info_dict

     async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:

@@ -388,7 +388,8 @@ class DeviceWorkerHandler:
             "Trying handling device list state for partial join: not supported on workers."
         )

-    DEVICE_MSGS_DELETE_BATCH_LIMIT = 100
+    DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000
+    DEVICE_MSGS_DELETE_SLEEP_MS = 1000

     async def _delete_device_messages(
         self,
@@ -400,6 +401,8 @@ class DeviceWorkerHandler:
         device_id = task.params["device_id"]
         up_to_stream_id = task.params["up_to_stream_id"]

+        # Delete the messages in batches to avoid too much DB load.
+        while True:
             res = await self.store.delete_messages_for_device(
                 user_id=user_id,
                 device_id=device_id,
@@ -409,10 +412,8 @@ class DeviceWorkerHandler:
             if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
                 return TaskStatus.COMPLETE, None, None
-            else:
-                # There is probably still device messages to be deleted, let's keep the task active and it will be run
-                # again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running).
-                return TaskStatus.ACTIVE, None, None
+
+            await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0)


 class DeviceHandler(DeviceWorkerHandler):
@@ -758,12 +759,13 @@ class DeviceHandler(DeviceWorkerHandler):
         # If the dehydrated device was successfully deleted (the device ID
         # matched the stored dehydrated device), then modify the access
-        # token to use the dehydrated device's ID and copy the old device
-        # display name to the dehydrated device, and destroy the old device
-        # ID
+        # token and refresh token to use the dehydrated device's ID and
+        # copy the old device display name to the dehydrated device,
+        # and destroy the old device ID
         old_device_id = await self.store.set_device_for_access_token(
             access_token, device_id
         )
+        await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id)
         old_device = await self.store.get_device(user_id, old_device_id)
         if old_device is None:
             raise errors.NotFoundError()
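
With the batch limit raised to 1000 and an explicit sleep between rounds, the task now drains the whole backlog itself instead of returning `ACTIVE` and waiting for the next scheduler pass. A generic sketch of the delete-sleep loop (the names here are illustrative, not Synapse's):

```python
import asyncio

BATCH_LIMIT = 1000
SLEEP_SECONDS = 1.0


async def drain_in_batches(delete_batch) -> None:
    """Delete rows in capped batches, pausing between rounds so one large
    backlog cannot monopolise the database."""
    while True:
        deleted = await delete_batch(limit=BATCH_LIMIT)
        if deleted < BATCH_LIMIT:
            return  # fewer rows than the cap: the backlog is drained
        await asyncio.sleep(SLEEP_SECONDS)
```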

@@ -828,13 +828,13 @@ class EventCreationHandler:
         u = await self.store.get_user_by_id(user_id)
         assert u is not None
-        if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
+        if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
             # support and bot users are not required to consent
             return
-        if u["appservice_id"] is not None:
+        if u.appservice_id is not None:
             # users registered by an appservice are exempt
             return
-        if u["consent_version"] == self.config.consent.user_consent_version:
+        if u.consent_version == self.config.consent.user_consent_version:
             return

         consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)

@@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent
 from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag, start_active_span, tags
-from synapse.types import ISynapseReactor
+from synapse.types import ISynapseReactor, StrSequence
 from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
@@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
 # the value actually has to be a List, but List is invariant so we can't specify that
 # the entries can either be Lists or bytes.
 RawHeaderValue = Union[
-    List[str],
+    StrSequence,
     List[bytes],
     List[Union[str, bytes]],
-    Tuple[str, ...],
     Tuple[bytes, ...],
     Tuple[Union[str, bytes], ...],
 ]

@@ -18,7 +18,6 @@ import logging
 from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
-    Iterable,
     List,
     Mapping,
     Optional,
@@ -38,7 +37,7 @@ from twisted.web.server import Request
 from synapse.api.errors import Codes, SynapseError
 from synapse.http import redact_uri
 from synapse.http.server import HttpServer
-from synapse.types import JsonDict, RoomAlias, RoomID
+from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
 from synapse.util import json_decoder

 if TYPE_CHECKING:
@@ -340,7 +339,7 @@ def parse_string(
     name: str,
     default: str,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -352,7 +351,7 @@ def parse_string(
     name: str,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -365,7 +364,7 @@ def parse_string(
     *,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -376,7 +375,7 @@ def parse_string(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     """
@@ -485,7 +484,7 @@ def parse_enum(

 def _parse_string_value(
     value: bytes,
-    allowed_values: Optional[Iterable[str]],
+    allowed_values: Optional[StrCollection],
     name: str,
     encoding: str,
 ) -> str:
@@ -511,7 +510,7 @@ def parse_strings_from_args(
     args: Mapping[bytes, Sequence[bytes]],
     name: str,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     ...
@@ -523,7 +522,7 @@ def parse_strings_from_args(
     name: str,
     default: List[str],
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> List[str]:
     ...
@@ -535,7 +534,7 @@ def parse_strings_from_args(
     name: str,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> List[str]:
     ...
@@ -548,7 +547,7 @@ def parse_strings_from_args(
     default: Optional[List[str]] = None,
     *,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     ...
@@ -559,7 +558,7 @@ def parse_strings_from_args(
     name: str,
     default: Optional[List[str]] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     """
@@ -610,7 +609,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -623,7 +622,7 @@ def parse_string_from_args(
     default: Optional[str] = None,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -635,7 +634,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -646,7 +645,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     """
@@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request(
     return validate_json_object(content, model_type)


-def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
+def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
     absent = []
     for k in required:
         if k not in body:
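
The `allowed_values` parameters are exactly where the bare-string pitfall bites: `value in allowed_values` still type-checks if a caller passes one string instead of a collection, but it silently becomes a substring test. A small demonstration of the bug class the `StrCollection` annotation rules out:

```python
def is_allowed(value: str, allowed_values) -> bool:
    # Membership test "works" for both argument shapes, but means
    # something very different for a bare string.
    return value in allowed_values


print(is_allowed("b", ("backward", "forward")))  # False: "b" is not an option
print(is_allowed("b", "backward"))               # True: substring match!
```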

@@ -25,7 +25,6 @@ from typing import (
     Iterable,
     Mapping,
     Optional,
-    Sequence,
     Set,
     Tuple,
     Type,
@@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics  # noqa: F401
 from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
 from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
 from synapse.metrics._types import Collector
+from synapse.types import StrSequence
 from synapse.util import SYNAPSE_VERSION

 logger = logging.getLogger(__name__)
@@ -81,7 +81,7 @@ class LaterGauge(Collector):

     name: str
     desc: str
-    labels: Optional[Sequence[str]] = attr.ib(hash=False)
+    labels: Optional[StrSequence] = attr.ib(hash=False)
     # callback: should either return a value (if there are no labels for this metric),
     # or dict mapping from a label tuple to a value
     caller: Callable[
@@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
         self,
         name: str,
         desc: str,
-        labels: Sequence[str],
-        sub_metrics: Sequence[str],
+        labels: StrSequence,
+        sub_metrics: StrSequence,
     ):
         self.name = name
         self.desc = desc

@@ -322,13 +322,21 @@ class BackgroundProcessLoggingContext(LoggingContext):
         if instance_id is None:
             instance_id = id(self)
         super().__init__("%s-%s" % (name, instance_id))
-        self._proc = _BackgroundProcess(name, self)
+        self._proc: Optional[_BackgroundProcess] = _BackgroundProcess(name, self)

     def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
         """Log context has started running (again)."""

         super().start(rusage)

+        if self._proc is None:
+            logger.error(
+                "Background process re-entered without a proc: %s",
+                self.name,
+                stack_info=True,
+            )
+            return
+
         # We've become active again so we make sure we're in the list of active
         # procs. (Note that "start" here means we've become active, as opposed
         # to starting for the first time.)
@@ -345,6 +353,14 @@ class BackgroundProcessLoggingContext(LoggingContext):

         super().__exit__(type, value, traceback)

+        if self._proc is None:
+            logger.error(
+                "Background process exited without a proc: %s",
+                self.name,
+                stack_info=True,
+            )
+            return
+
         # The background process has finished. We explicitly remove and manually
         # update the metrics here so that if nothing is scraping metrics the set
         # doesn't infinitely grow.
@@ -352,3 +368,6 @@ class BackgroundProcessLoggingContext(LoggingContext):
         _background_processes_active_since_last_scrape.discard(self._proc)

         self._proc.update_metrics()
+
+        # Set proc to None to break the reference cycle.
+        self._proc = None

@@ -572,7 +572,7 @@ class ModuleApi:
         Returns:
             UserInfo object if a user was found, otherwise None
         """
-        return await self._store.get_userinfo_by_id(user_id)
+        return await self._store.get_user_by_id(user_id)

     async def get_user_by_req(
         self,
@@ -1878,7 +1878,7 @@ class AccountDataManager:
             raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")

         # Ensure the user exists, so we don't just write to users that aren't there.
-        if await self._store.get_userinfo_by_id(user_id) is None:
+        if await self._store.get_user_by_id(user_id) is None:
             raise ValueError(f"User {user_id} does not exist on this server.")

         await self._handler.add_account_data_for_user(user_id, data_type, new_data)

@@ -104,7 +104,7 @@ class _NotifierUserStream:
     def __init__(
         self,
         user_id: str,
-        rooms: Collection[str],
+        rooms: StrCollection,
         current_token: StreamToken,
         time_now_ms: int,
     ):
@@ -457,7 +457,7 @@ class Notifier:
         stream_key: str,
         new_token: Union[int, RoomStreamToken],
         users: Optional[Collection[Union[str, UserID]]] = None,
-        rooms: Optional[Collection[str]] = None,
+        rooms: Optional[StrCollection] = None,
     ) -> None:
         """Used to inform listeners that something has happened event wise.
@@ -529,7 +529,7 @@ class Notifier:
         user_id: str,
         timeout: int,
         callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
-        room_ids: Optional[Collection[str]] = None,
+        room_ids: Optional[StrCollection] = None,
         from_token: StreamToken = StreamToken.START,
     ) -> T:
         """Wait until the callback returns a non empty response or the

@@ -672,14 +672,12 @@ class ReplicationCommandHandler:
             cmd.instance_name, cmd.lock_name, cmd.lock_key
         )

-    async def on_NEW_ACTIVE_TASK(
+    def on_NEW_ACTIVE_TASK(
         self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
     ) -> None:
         """Called when get a new NEW_ACTIVE_TASK command."""
         if self._task_scheduler:
-            task = await self._task_scheduler.get_task(cmd.data)
-            if task:
-                await self._task_scheduler._launch_task(task)
+            self._task_scheduler.launch_task_by_id(cmd.data)

     def new_connection(self, connection: IReplicationConnection) -> None:
         """Called when we have a new connection."""

@@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar,

 from synapse.api.errors import InteractiveAuthIncompleteError
 from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection

 logger = logging.getLogger(__name__)


 def client_patterns(
     path_regex: str,
-    releases: Iterable[str] = ("r0", "v3"),
+    releases: StrCollection = ("r0", "v3"),
     unstable: bool = True,
     v1: bool = False,
 ) -> Iterable[Pattern]:

@@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource):
         if u is None:
             raise NotFoundError("Unknown user")

-        has_consented = u["consent_version"] == version
+        has_consented = u.consent_version == version
         userhmac = userhmac_bytes.decode("ascii")

         try:

@@ -79,15 +79,15 @@ class ConsentServerNotices:
         if u is None:
             return

-        if u["is_guest"] and not self._send_to_guests:
+        if u.is_guest and not self._send_to_guests:
             # don't send to guests
             return

-        if u["consent_version"] == self._current_consent_version:
+        if u.consent_version == self._current_consent_version:
             # user has already consented
             return

-        if u["consent_server_notice_sent"] == self._current_consent_version:
+        if u.consent_server_notice_sent == self._current_consent_version:
             # we've already sent a notice to the user
             return

@@ -20,7 +20,6 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     DefaultDict,
     Dict,
     FrozenSet,
@@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
 from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import StateMap
+from synapse.types import StateMap, StrCollection
 from synapse.types.state import StateFilter
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -197,7 +196,7 @@ class StateHandler:
     async def compute_state_after_events(
         self,
         room_id: str,
-        event_ids: Collection[str],
+        event_ids: StrCollection,
         state_filter: Optional[StateFilter] = None,
         await_full_state: bool = True,
     ) -> StateMap[str]:
@@ -231,7 +230,7 @@ class StateHandler:
         return await ret.get_state(self._state_storage_controller, state_filter)

     async def get_current_user_ids_in_room(
-        self, room_id: str, latest_event_ids: Collection[str]
+        self, room_id: str, latest_event_ids: StrCollection
     ) -> Set[str]:
         """
         Get the users IDs who are currently in a room.
@@ -256,7 +255,7 @@ class StateHandler:
         return await self.store.get_joined_user_ids_from_state(room_id, state)

     async def get_hosts_in_room_at_events(
-        self, room_id: str, event_ids: Collection[str]
+        self, room_id: str, event_ids: StrCollection
     ) -> FrozenSet[str]:
         """Get the hosts that were in a room at the given event ids
@@ -470,7 +469,7 @@ class StateHandler:
     @trace
     @measure_func()
     async def resolve_state_groups_for_events(
-        self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
+        self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
     ) -> _StateCacheEntry:
         """Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
@@ -882,7 +881,7 @@ class StateResolutionStore:
     store: "DataStore"

     def get_events(
-        self, event_ids: Collection[str], allow_rejected: bool = False
+        self, event_ids: StrCollection, allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         """Get events from the database

@@ -17,7 +17,6 @@ import logging
 from typing import (
     Awaitable,
     Callable,
-    Collection,
     Dict,
     Iterable,
     List,
@@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection

 logger = logging.getLogger(__name__)

@@ -45,7 +44,7 @@ async def resolve_events_with_store(
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
+    state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]],
 ) -> StateMap[str]:
     """
     Args:

@@ -19,7 +19,6 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     Dict,
     Generator,
     Iterable,
@@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection

 logger = logging.getLogger(__name__)

@@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
     # This is usually synapse.state.StateResolutionStore, but it's replaced with a
     # TestStateResolutionStore in tests.
     def get_events(
-        self, event_ids: Collection[str], allow_rejected: bool = False
+        self, event_ids: StrCollection, allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         ...

@@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
         union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])

-        auth_difference_unpersisted_part: Collection[str] = union - intersection
+        auth_difference_unpersisted_part: StrCollection = union - intersection
     else:
         auth_difference_unpersisted_part = ()
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]

@@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
         }

         return list(results.values())
+
+    async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
+        """Get the last seen timestamp for a user, if we have it."""
+
+        return await self.db_pool.simple_select_one_onecol(
+            table="user_ips",
+            keyvalues={"user_id": user_id},
+            retcol="MAX(last_seen)",
+            allow_none=True,
+            desc="get_last_seen_for_user_id",
+        )

@@ -47,7 +47,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, StrCollection
+from synapse.types import JsonDict, StrCollection, StrSequence
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     )

     @cached(max_entries=5000, iterable=True)
-    async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
+    async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
         return await self.db_pool.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},

@@ -16,7 +16,7 @@
 import logging
 import random
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

 import attr
@@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     )

     @cached()
-    async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
-        """Deprecated: use get_userinfo_by_id instead"""
+    async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
+        """Returns info about the user account, if it exists."""

         def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
             # We could technically use simple_select_one here, but it would not perform
@@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             txn.execute(
                 """
                 SELECT
-                    name, password_hash, is_guest, admin, consent_version, consent_ts,
+                    name, is_guest, admin, consent_version, consent_ts,
                     consent_server_notice_sent, appservice_id, creation_ts, user_type,
                     deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
                     COALESCE(approved, TRUE) AS approved,
-                    COALESCE(locked, FALSE) AS locked, last_seen_ts
+                    COALESCE(locked, FALSE) AS locked
                 FROM users
-                LEFT JOIN (
-                    SELECT user_id, MAX(last_seen) AS last_seen_ts
-                    FROM user_ips GROUP BY user_id
-                ) ls ON users.name = ls.user_id
                 WHERE name = ?
                 """,
                 (user_id,),
@@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_user_by_id",
             func=get_user_by_id_txn,
         )
-
-        if row is not None:
-            # If we're using SQLite our boolean values will be integers. Because we
-            # present some of this data as is to e.g. server admins via REST APIs, we
-            # want to make sure we're returning the right type of data.
-            # Note: when adding a column name to this list, be wary of NULLable columns,
-            # since NULL values will be turned into False.
-            boolean_columns = [
-                "admin",
-                "deactivated",
-                "shadow_banned",
-                "approved",
-                "locked",
-            ]
-            for column in boolean_columns:
-                row[column] = bool(row[column])
-
-        return row
-
-    async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
-        """Get a UserInfo object for a user by user ID.
-
-        Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
-        this method should be cached.
-
-        Args:
-            user_id: The user to fetch user info for.
-        Returns:
-            `UserInfo` object if user found, otherwise `None`.
-        """
-        user_data = await self.get_user_by_id(user_id)
-        if not user_data:
+        if row is None:
             return None
+
         return UserInfo(
-            appservice_id=user_data["appservice_id"],
-            consent_server_notice_sent=user_data["consent_server_notice_sent"],
-            consent_version=user_data["consent_version"],
-            creation_ts=user_data["creation_ts"],
-            is_admin=bool(user_data["admin"]),
-            is_deactivated=bool(user_data["deactivated"]),
-            is_guest=bool(user_data["is_guest"]),
-            is_shadow_banned=bool(user_data["shadow_banned"]),
-            user_id=UserID.from_string(user_data["name"]),
-            user_type=user_data["user_type"],
-            last_seen_ts=user_data["last_seen_ts"],
+            appservice_id=row["appservice_id"],
+            consent_server_notice_sent=row["consent_server_notice_sent"],
+            consent_version=row["consent_version"],
+            consent_ts=row["consent_ts"],
+            creation_ts=row["creation_ts"],
+            is_admin=bool(row["admin"]),
+            is_deactivated=bool(row["deactivated"]),
+            is_guest=bool(row["is_guest"]),
+            is_shadow_banned=bool(row["shadow_banned"]),
+            user_id=UserID.from_string(row["name"]),
+            user_type=row["user_type"],
+            approved=bool(row["approved"]),
+            locked=bool(row["locked"]),
         )

     async def is_trial_user(self, user_id: str) -> bool:
@@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):

         now = self._clock.time_msec()
         days = self.config.server.mau_appservice_trial_days.get(
-            info["appservice_id"], self.config.server.mau_trial_days
+            info.appservice_id, self.config.server.mau_trial_days
         )
         trial_duration_ms = days * 24 * 60 * 60 * 1000
-        is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
+        is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
         return is_trial

     @cached()
@@ -2312,6 +2280,26 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):

         return next_id

+    async def set_device_for_refresh_token(
+        self, user_id: str, old_device_id: str, device_id: str
+    ) -> None:
+        """Moves refresh tokens from old device to current device
+
+        Args:
+            user_id: The user of the devices.
+            old_device_id: The old device.
+            device_id: The new device ID.
+        Returns:
+            None
+        """
+
+        await self.db_pool.simple_update(
+            "refresh_tokens",
+            keyvalues={"user_id": user_id, "device_id": old_device_id},
+            updatevalues={"device_id": device_id},
+            desc="set_device_for_refresh_token",
+        )
+
     def _set_device_for_access_token_txn(
         self, txn: LoggingTransaction, token: str, device_id: str
     ) -> str:

@@ -53,6 +53,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
         resource_id: Optional[str] = None,
         statuses: Optional[List[TaskStatus]] = None,
         max_timestamp: Optional[int] = None,
+        limit: Optional[int] = None,
     ) -> List[ScheduledTask]:
         """Get a list of scheduled tasks from the DB.

@@ -62,6 +63,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
             statuses: Limit the returned tasks to the specific statuses
             max_timestamp: Limit the returned tasks to the ones that have
                 a timestamp inferior to the specified one
+            limit: Only return `limit` number of rows if set.

         Returns: a list of `ScheduledTask`, ordered by increasing timestamps
         """
@@ -94,6 +96,10 @@ class TaskSchedulerWorkerStore(SQLBaseStore):

             sql = sql + " ORDER BY timestamp"

+            if limit is not None:
+                sql += " LIMIT ?"
+                args.append(limit)
+
             txn.execute(sql, args)
             return self.db_pool.cursor_to_dict(txn)

@@ -0,0 +1,16 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX IF NOT EXISTS scheduled_tasks_timestamp ON scheduled_tasks(timestamp);
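
This index exists to serve the scheduler's polling queries, which (per the store diff above) filter on status and timestamp, sort on `timestamp`, and now take a `LIMIT`. A sketch of the query shape it accelerates; with the index, the database can walk rows in timestamp order and stop after `limit` matches instead of scanning and sorting the whole table:

```python
# Shape of the scheduler's polling query (assembled in get_scheduled_tasks);
# the WHERE clauses shown are illustrative of one combination of filters.
sql = (
    "SELECT * FROM scheduled_tasks"
    " WHERE status IN (?) AND timestamp <= ?"
    " ORDER BY timestamp"
    " LIMIT ?"
)
```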

@@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key(

 @attr.s(auto_attribs=True, frozen=True, slots=True)
 class UserInfo:
-    """Holds information about a user. Result of get_userinfo_by_id.
+    """Holds information about a user. Result of get_user_by_id.

     Attributes:
         user_id: ID of the user.
         appservice_id: Application service ID that created this user.
         consent_server_notice_sent: Version of policy documents the user has been sent.
         consent_version: Version of policy documents the user has consented to.
+        consent_ts: Time the user consented
         creation_ts: Creation timestamp of the user.
         is_admin: True if the user is an admin.
         is_deactivated: True if the user has been deactivated.
         is_guest: True if the user is a guest user.
         is_shadow_banned: True if the user has been shadow-banned.
         user_type: User type (None for normal user, 'support' and 'bot' other options).
-        last_seen_ts: Last activity timestamp of the user.
+        approved: If the user has been "approved" to register on the server.
+        locked: Whether the user's account has been locked
     """

     user_id: UserID
     appservice_id: Optional[int]
     consent_server_notice_sent: Optional[str]
     consent_version: Optional[str]
+    consent_ts: Optional[int]
     user_type: Optional[str]
     creation_ts: int
     is_admin: bool
     is_deactivated: bool
     is_guest: bool
     is_shadow_banned: bool
-    last_seen_ts: Optional[int]
+    approved: bool
+    locked: bool


 class UserProfile(TypedDict):

View File

@ -15,12 +15,14 @@
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple
from prometheus_client import Gauge
from twisted.python.failure import Failure from twisted.python.failure import Failure
from synapse.logging.context import nested_logging_context from synapse.logging.context import nested_logging_context
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.types import JsonMapping, ScheduledTask, TaskStatus from synapse.types import JsonMapping, ScheduledTask, TaskStatus
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -30,12 +32,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
running_tasks_gauge = Gauge(
"synapse_scheduler_running_tasks",
"The number of concurrent running tasks handled by the TaskScheduler",
)
class TaskScheduler: class TaskScheduler:
""" """
This is a simple task sheduler aimed at resumable tasks: usually we use `run_in_background` This is a simple task sheduler aimed at resumable tasks: usually we use `run_in_background`
@ -70,6 +66,8 @@ class TaskScheduler:
# Precision of the scheduler, evaluation of tasks to run will only happen # Precision of the scheduler, evaluation of tasks to run will only happen
# every `SCHEDULE_INTERVAL_MS` ms # every `SCHEDULE_INTERVAL_MS` ms
SCHEDULE_INTERVAL_MS = 1 * 60 * 1000 # 1mn SCHEDULE_INTERVAL_MS = 1 * 60 * 1000 # 1mn
# How often to clean up old tasks.
CLEANUP_INTERVAL_MS = 30 * 60 * 1000
# Time before a complete or failed task is deleted from the DB # Time before a complete or failed task is deleted from the DB
KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000 # 1 week KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000 # 1 week
# Maximum number of tasks that can run at the same time # Maximum number of tasks that can run at the same time
@ -92,12 +90,24 @@ class TaskScheduler:
] = {} ] = {}
self._run_background_tasks = hs.config.worker.run_background_tasks self._run_background_tasks = hs.config.worker.run_background_tasks
# Flag to make sure we only try and launch new tasks once at a time.
self._launching_new_tasks = False
if self._run_background_tasks: if self._run_background_tasks:
self._clock.looping_call( self._clock.looping_call(
run_as_background_process, self._launch_scheduled_tasks,
TaskScheduler.SCHEDULE_INTERVAL_MS, TaskScheduler.SCHEDULE_INTERVAL_MS,
"handle_scheduled_tasks", )
self._handle_scheduled_tasks, self._clock.looping_call(
self._clean_scheduled_tasks,
TaskScheduler.SCHEDULE_INTERVAL_MS,
)
LaterGauge(
"synapse_scheduler_running_tasks",
"The number of concurrent running tasks handled by the TaskScheduler",
labels=None,
caller=lambda: len(self._running_tasks),
) )
def register_action( def register_action(
@ -234,6 +244,7 @@ class TaskScheduler:
resource_id: Optional[str] = None, resource_id: Optional[str] = None,
statuses: Optional[List[TaskStatus]] = None, statuses: Optional[List[TaskStatus]] = None,
max_timestamp: Optional[int] = None, max_timestamp: Optional[int] = None,
limit: Optional[int] = None,
) -> List[ScheduledTask]: ) -> List[ScheduledTask]:
"""Get a list of tasks. Returns all the tasks if no args is provided. """Get a list of tasks. Returns all the tasks if no args is provided.
@ -247,6 +258,7 @@ class TaskScheduler:
statuses: Limit the returned tasks to the specific statuses statuses: Limit the returned tasks to the specific statuses
max_timestamp: Limit the returned tasks to the ones that have max_timestamp: Limit the returned tasks to the ones that have
a timestamp inferior to the specified one a timestamp inferior to the specified one
limit: Only return `limit` number of rows if set.
Returns Returns
A list of `ScheduledTask`, ordered by increasing timestamps A list of `ScheduledTask`, ordered by increasing timestamps
@ -256,6 +268,7 @@ class TaskScheduler:
resource_id=resource_id, resource_id=resource_id,
statuses=statuses, statuses=statuses,
max_timestamp=max_timestamp, max_timestamp=max_timestamp,
limit=limit,
) )
async def delete_task(self, id: str) -> None: async def delete_task(self, id: str) -> None:
@@ -273,33 +286,57 @@ class TaskScheduler:
             raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
         await self._store.delete_scheduled_task(id)

-    async def _handle_scheduled_tasks(self) -> None:
-        """Main loop taking care of launching tasks and cleaning up old ones."""
-        await self._launch_scheduled_tasks()
-        await self._clean_scheduled_tasks()
+    def launch_task_by_id(self, id: str) -> None:
+        """Try launching the task with the given ID."""
+        # Don't bother trying to launch new tasks if we're already at capacity.
+        if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
+            return
+
+        run_as_background_process("launch_task_by_id", self._launch_task_by_id, id)
+
+    async def _launch_task_by_id(self, id: str) -> None:
+        """Helper async function for `launch_task_by_id`."""
+        task = await self.get_task(id)
+        if task:
+            await self._launch_task(task)

+    @wrap_as_background_process("launch_scheduled_tasks")
     async def _launch_scheduled_tasks(self) -> None:
         """Retrieve and launch scheduled tasks that should be running at that time."""
-        for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]):
-            await self._launch_task(task)
-        for task in await self.get_tasks(
-            statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec()
-        ):
-            await self._launch_task(task)
-
-        running_tasks_gauge.set(len(self._running_tasks))
+        # Don't bother trying to launch new tasks if we're already at capacity.
+        if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
+            return
+
+        if self._launching_new_tasks:
+            return
+
+        self._launching_new_tasks = True
+
+        try:
+            for task in await self.get_tasks(
+                statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS
+            ):
+                await self._launch_task(task)
+            for task in await self.get_tasks(
+                statuses=[TaskStatus.SCHEDULED],
+                max_timestamp=self._clock.time_msec(),
+                limit=self.MAX_CONCURRENT_RUNNING_TASKS,
+            ):
+                await self._launch_task(task)
+        finally:
+            self._launching_new_tasks = False
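The `_launching_new_tasks` flag is a single-flight guard: an invocation that overlaps a running one returns immediately instead of queueing behind it. A self-contained sketch of the pattern (names are illustrative, not Synapse API):

```python
from typing import Awaitable, Callable


class SingleFlight:
    """Run a coroutine, skipping invocations that overlap an in-flight one."""

    def __init__(self) -> None:
        self._running = False

    async def maybe_run(self, work: Callable[[], Awaitable[None]]) -> None:
        if self._running:
            # Another invocation is already doing the work; drop this one
            # instead of queueing behind it.
            return
        self._running = True
        try:
            await work()
        finally:
            # Always clear the flag, even if `work` raises.
            self._running = False
```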
+    @wrap_as_background_process("clean_scheduled_tasks")
     async def _clean_scheduled_tasks(self) -> None:
         """Clean old complete or failed jobs to avoid cluttering the DB."""
+        now = self._clock.time_msec()
         for task in await self._store.get_scheduled_tasks(
-            statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE]
+            statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE],
+            max_timestamp=now - TaskScheduler.KEEP_TASKS_FOR_MS,
         ):
             # FAILED and COMPLETE tasks should never be running
             assert task.id not in self._running_tasks
-            if (
-                self._clock.time_msec()
-                > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS
-            ):
-                await self._store.delete_scheduled_task(task.id)
+            await self._store.delete_scheduled_task(task.id)
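Why the in-Python age check can go: the new `max_timestamp` bound is the same inequality rearranged so the database can apply it. In short:

```python
# Both conditions select the same rows (all values in milliseconds):
#   now > task.timestamp + KEEP_TASKS_FOR_MS
#   <=> task.timestamp < now - KEEP_TASKS_FOR_MS
KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000
now = 1_700_000_000_000  # example clock reading
cutoff = now - KEEP_TASKS_FOR_MS  # passed to the query as max_timestamp
```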
     async def _launch_task(self, task: ScheduledTask) -> None:

@@ -339,6 +376,9 @@ class TaskScheduler:
             )
             self._running_tasks.remove(task.id)

+            # Try to launch a new task since we've finished with this one.
+            self._clock.call_later(1, self._launch_scheduled_tasks)
+
         if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
             return

@@ -355,4 +395,4 @@ class TaskScheduler:
         self._running_tasks.add(task.id)
         await self.update_task(task.id, status=TaskStatus.ACTIVE)

-        run_as_background_process(task.action, wrapper)
+        run_as_background_process(f"task-{task.action}", wrapper)
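With `run_as_background_process` now given a `task-` prefix, scheduler-driven work stands out in process metrics. Stepping back, the scheduler ends up being driven roughly as below; this is a sketch assuming `register_action` and `schedule_task` keep their existing signatures, and the action name and handler are made up:

```python
from synapse.types import TaskStatus


async def purge_example(task, first_launch=None):  # callback signature approximate
    # ... do a bounded chunk of work, persisting progress as needed ...
    return TaskStatus.COMPLETE, None, None


async def drive(scheduler) -> None:
    scheduler.register_action(purge_example, "example_purge")  # at startup
    task_id = await scheduler.schedule_task("example_purge")   # persist the task
    # New in this change: poke the scheduler immediately rather than waiting
    # for the next SCHEDULE_INTERVAL_MS tick.
    scheduler.launch_task_by_id(task_id)
```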
View File
@@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
 from synapse.logging.opentracing import trace
 from synapse.storage.controllers import StorageControllers
 from synapse.storage.databases.main import DataStore
-from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
+from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
 from synapse.types.state import StateFilter
 from synapse.util import Clock

@@ -150,12 +150,12 @@ async def filter_events_for_client(
 async def filter_event_for_clients_with_state(
     store: DataStore,
-    user_ids: Collection[str],
+    user_ids: StrCollection,
     event: EventBase,
     context: EventContext,
     is_peeking: bool = False,
     filter_send_to_client: bool = True,
-) -> Collection[str]:
+) -> StrCollection:
     """
     Checks to see if an event is visible to the users in the list at the time of
     the event.
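`StrCollection` is Synapse's alias for "any collection of strings". It is defined along these lines (reproduced from memory, so treat as approximate):

```python
from typing import AbstractSet, List, Tuple, Union

# Accepts lists, tuples and sets of strings, but deliberately not a bare
# `str` (a str is itself an iterable of str, a classic source of bugs).
StrCollection = Union[Tuple[str, ...], List[str], AbstractSet[str]]
```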
View File
@@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        # This just needs to return a truth-y value.
-        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+
+        class FakeUserInfo:
+            is_guest = False
+
+        self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
         self.store.get_user_by_access_token = AsyncMock(return_value=None)

         request = Mock(args={})

@@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )

     def test_get_guest_user_from_macaroon(self) -> None:
-        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
+        class FakeUserInfo:
+            is_guest = True
+
+        self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
         self.store.get_user_by_access_token = AsyncMock(return_value=None)

         user_id = "@baldrick:matrix.org"
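Since `get_user_by_id` now returns an object rather than a mapping, test stubs only need the right attributes. An equivalent, slightly terser sketch using `SimpleNamespace` (illustrative, not part of the change):

```python
from types import SimpleNamespace
from unittest.mock import AsyncMock

store = SimpleNamespace()  # stand-in for the datastore under test
store.get_user_by_id = AsyncMock(return_value=SimpleNamespace(is_guest=False))
# (await store.get_user_by_id("@user:example.org")).is_guest  -> False
```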
View File
@@ -461,6 +461,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         self.message_handler = hs.get_device_message_handler()
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
+        self.auth_handler = hs.get_auth_handler()
         self.store = hs.get_datastores().main
         return hs

@@ -487,11 +488,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(device_data, {"device_data": {"foo": "bar"}})

         # Create a new login for the user and dehydrate the device
-        device_id, access_token, _expiration_time, _refresh_token = self.get_success(
+        device_id, access_token, _expiration_time, refresh_token = self.get_success(
             self.registration.register_device(
                 user_id=user_id,
                 device_id=None,
                 initial_display_name="new device",
+                should_issue_refresh_token=True,
             )
         )

@@ -522,6 +524,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(user_info.device_id, retrieved_device_id)

+        # make sure the user device has the refresh token
+        assert refresh_token is not None
+        self.get_success(
+            self.auth_handler.refresh_token(refresh_token, 5 * 60 * 1000, 5 * 60 * 1000)
+        )
+
         # make sure the device has the display name that was set from the login
         res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
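Condensed, the flow the test now exercises looks like this; the two trailing arguments to `refresh_token` are assumed to be the new access- and refresh-token validity windows in milliseconds, and the surrounding names come from the test above:

```python
async def exercise_refresh(registration, auth_handler, user_id: str) -> None:
    # Register a new device and ask for a refresh token alongside the
    # access token (the 4-tuple shape matches the test above).
    device_id, access_token, _expiry, refresh_token = await registration.register_device(
        user_id=user_id,
        device_id=None,
        initial_display_name="new device",
        should_issue_refresh_token=True,
    )
    assert refresh_token is not None
    # Exchange the refresh token; 5 * 60 * 1000 ms = 5 minutes.
    await auth_handler.refresh_token(refresh_token, 5 * 60 * 1000, 5 * 60 * 1000)
```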
View File
@@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.constants import UserTypes
 from synapse.api.errors import ThreepidValidationError
 from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, UserInfo
 from synapse.util import Clock

 from tests.unittest import HomeserverTestCase, override_config

@@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
         self.get_success(self.store.register_user(self.user_id, self.pwhash))

         self.assertEqual(
-            {
+            UserInfo(
                 # TODO(paul): Surely this field should be 'user_id', not 'name'
-                "name": self.user_id,
-                "password_hash": self.pwhash,
-                "admin": 0,
-                "is_guest": 0,
-                "consent_version": None,
-                "consent_ts": None,
-                "consent_server_notice_sent": None,
-                "appservice_id": None,
-                "creation_ts": 0,
-                "user_type": None,
-                "deactivated": 0,
-                "locked": 0,
-                "shadow_banned": 0,
-                "approved": 1,
-                "last_seen_ts": None,
-            },
+                user_id=UserID.from_string(self.user_id),
+                is_admin=False,
+                is_guest=False,
+                consent_server_notice_sent=None,
+                consent_ts=None,
+                consent_version=None,
+                appservice_id=None,
+                creation_ts=0,
+                user_type=None,
+                is_deactivated=False,
+                locked=False,
+                is_shadow_banned=False,
+                approved=True,
+            ),
             (self.get_success(self.store.get_user_by_id(self.user_id))),
         )
@@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         assert user
-        self.assertEqual(user["consent_version"], "1")
-        self.assertGreater(user["consent_ts"], before_consent)
-        self.assertLess(user["consent_ts"], self.clock.time_msec())
+        self.assertEqual(user.consent_version, "1")
+        self.assertIsNotNone(user.consent_ts)
+        assert user.consent_ts is not None
+        self.assertGreater(user.consent_ts, before_consent)
+        self.assertLess(user.consent_ts, self.clock.time_msec())

     def test_add_tokens(self) -> None:
         self.get_success(self.store.register_user(self.user_id, self.pwhash))
@@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         assert user is not None
-        self.assertTrue(user["approved"])
+        self.assertTrue(user.approved)
         approved = self.get_success(self.store.is_user_approved(self.user_id))
         self.assertTrue(approved)

@@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         assert user is not None
-        self.assertFalse(user["approved"])
+        self.assertFalse(user.approved)
         approved = self.get_success(self.store.is_user_approved(self.user_id))
         self.assertFalse(approved)

@@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         self.assertIsNotNone(user)
         assert user is not None
-        self.assertEqual(user["approved"], 1)
+        self.assertEqual(user.approved, 1)
         approved = self.get_success(self.store.is_user_approved(self.user_id))
         self.assertTrue(approved)
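Call sites migrating to the new return type follow the same recipe: attribute access on the returned `UserInfo` replaces dict lookups. A hedged sketch:

```python
async def is_guest(store, user_id: str) -> bool:
    # `store` stands in for the main datastore; get_user_by_id returns
    # Optional[UserInfo] after this change.
    user = await store.get_user_by_id(user_id)
    assert user is not None
    # before: user["is_guest"]; now an attribute on the UserInfo object
    return user.is_guest
```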