Merge branch 'release-v1.43' of github.com:matrix-org/synapse into matrix-org-hotfixes

anoa/log_11772
Andrew Morgan 2021-09-14 11:02:37 +01:00
commit 003c2ab629
119 changed files with 743 additions and 452 deletions

View File

@@ -1,10 +1,2 @@
# This file serves as a blacklist for SyTest tests that we expect will fail in
# Synapse when run under worker mode. For more details, see sytest-blacklist.
-Can re-join room if re-invited
-# new failures as of https://github.com/matrix-org/sytest/pull/732
-Device list doesn't change if remote server is down
-# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5
-Server correctly handles incoming m.device_list_update

changelog.d/10601.misc Normal file
View File

@@ -0,0 +1 @@
+Add type annotations to the synapse.util package.

changelog.d/10774.bugfix Normal file
View File

@@ -0,0 +1 @@
+Properly handle room upgrades of spaces.

changelog.d/10788.misc Normal file
View File

@@ -0,0 +1 @@
+Remove fixed and flakey tests from the sytest-blacklist.

changelog.d/10789.misc Normal file
View File

@@ -0,0 +1 @@
+Improve internal details of the user directory code.

changelog.d/10795.doc Normal file
View File

@@ -0,0 +1 @@
+Correct 2 typographical errors in the *Log Contexts* documentation.

changelog.d/10798.misc Normal file
View File

@@ -0,0 +1 @@
+Use direct references to config flags.

changelog.d/10799.misc Normal file
View File

@@ -0,0 +1 @@
+Add a max version for the `jaeger-client` dependency for an incompatibility with the rust reporter.

changelog.d/10804.doc Normal file
View File

@@ -0,0 +1 @@
+Fixed a wording mistake in the sample configuration. Contributed by @bramvdnheuvel:nltrix.net.

View File

@@ -10,7 +10,7 @@ Logcontexts are also used for CPU and database accounting, so that we
can track which requests were responsible for high CPU use or database
activity.
-The `synapse.logging.context` module provides a facilities for managing
+The `synapse.logging.context` module provides facilities for managing
the current log context (as well as providing the `LoggingContextFilter`
class).
@@ -351,7 +351,7 @@ and the awaitable chain is now orphaned, and will be garbage-collected at
some point. Note that `await_something_interesting` is a coroutine,
which Python implements as a generator function. When Python
garbage-collects generator functions, it gives them a chance to
-clean up by making the `async` (or `yield`) raise a `GeneratorExit`
+clean up by making the `await` (or `yield`) raise a `GeneratorExit`
exception. In our case, that means that the `__exit__` handler of
`PreserveLoggingContext` will carefully restore the request context, but
there is now nothing waiting for its return, so the request context is
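The failure mode this hunk documents can be reproduced in isolation. The sketch below is illustrative only: `PreserveContext`, `current_context` and `never_finishes` are invented stand-ins rather than Synapse's real logcontext machinery, but they show how closing an orphaned coroutine raises `GeneratorExit` at the `await`, runs the context manager's `__exit__`, and lets it "restore" stale state.

```python
import gc

current_context = "sentinel"           # stand-in for the thread-local logcontext

class PreserveContext:                 # hypothetical analogue of PreserveLoggingContext
    def __enter__(self):
        global current_context
        self._saved, current_context = current_context, "sentinel"

    def __exit__(self, *exc):
        global current_context
        current_context = self._saved  # "carefully restore" the saved context
        return False

class Forever:
    def __await__(self):
        yield                          # suspend; nobody ever resumes us

async def never_finishes():
    with PreserveContext():
        await Forever()

coro = never_finishes()
current_context = "request-1"
coro.send(None)                        # run up to the await; __enter__ saves "request-1"
current_context = "sentinel"           # the request finishes and resets the context
del coro                               # orphan the coroutine: closing it raises
gc.collect()                           # GeneratorExit at the await, so __exit__ runs
print(current_context)                 # -> request-1, leaked back over our reset
```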

View File

@@ -2086,7 +2086,7 @@ password_config:
#
#require_lowercase: true
-# Whether a password must contain at least one lowercase letter.
+# Whether a password must contain at least one uppercase letter.
# Defaults to 'false'.
#
#require_uppercase: true

View File

@@ -10,3 +10,40 @@ DB corruption) get stale or out of sync. If this happens, for now the
solution to fix it is to execute the SQL [here](https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/main/delta/53/user_dir_populate.sql)
and then restart synapse. This should then start a background task to
flush the current tables and regenerate the directory.
+
+Data model
+----------
+
+There are five relevant tables that collectively form the "user directory".
+Three of them track a master list of all the users we could search for.
+The last two (collectively called the "search tables") track who can
+see who.
+
+From all of these tables we exclude three types of local user:
+  - support users
+  - appservice users
+  - deactivated users
+
+* `user_directory`. This contains the user_id, display name and avatar we'll
+  return when you search the directory.
+  - Because there's only one directory entry per user, it's important that we only
+    ever put publicly visible names here. Otherwise we might leak a private
+    nickname or avatar used in a private room.
+  - Indexed on rooms. Indexed on users.
+
+* `user_directory_search`. To be joined to `user_directory`. It contains an extra
+  column that enables full text search based on user ids and display names.
+  Different schemas for SQLite and Postgres with different code paths to match.
+  - Indexed on the full text search data. Indexed on users.
+
+* `user_directory_stream_pos`. When the initial background update to populate
+  the directory is complete, we record a stream position here. This indicates
+  that synapse should now listen for room changes and incrementally update
+  the directory where necessary.
+
+* `users_in_public_rooms`. Contains associations between users and the public rooms they're in.
+  Used to determine which users are in public rooms and should be publicly visible in the directory.
+
+* `users_who_share_private_rooms`. Rows are triples `(L, M, room id)` where `L`
+  is a local user and `M` is a local or remote user. `L` and `M` should be
+  different, but this isn't enforced by a constraint.
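For orientation, here is a rough sketch of the visibility rule the two search tables encode. The in-memory sets and the `visible_to` helper are invented for illustration; the real handler expresses the same rule in SQL against the tables described above.

```python
from typing import Set, Tuple

users_in_public_rooms: Set[str] = {"@alice:example.org"}
users_who_share_private_rooms: Set[Tuple[str, str]] = {
    ("@searcher:example.org", "@bob:example.org"),
}

def visible_to(searcher: str, candidate: str) -> bool:
    """A candidate is searchable if they are in a public room, or if the
    searcher shares a private room with them."""
    return (
        candidate in users_in_public_rooms
        or (searcher, candidate) in users_who_share_private_rooms
    )

assert visible_to("@searcher:example.org", "@alice:example.org")      # public room
assert visible_to("@searcher:example.org", "@bob:example.org")        # shared private room
assert not visible_to("@searcher:example.org", "@carol:example.org")  # not visible
```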

View File

@@ -74,17 +74,7 @@ files =
synapse/storage/util,
synapse/streams,
synapse/types.py,
-synapse/util/async_helpers.py,
-synapse/util/caches,
-synapse/util/daemonize.py,
-synapse/util/hash.py,
-synapse/util/iterutils.py,
-synapse/util/linked_list.py,
-synapse/util/metrics.py,
-synapse/util/macaroons.py,
-synapse/util/module_loader.py,
-synapse/util/msisdn.py,
-synapse/util/stringutils.py,
+synapse/util,
synapse/visibility.py,
tests/replication,
tests/test_event_auth.py,
@@ -102,6 +92,69 @@ files =
[mypy-synapse.rest.client.*]
disallow_untyped_defs = True

+[mypy-synapse.util.batching_queue]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.dictionary_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.file_consumer]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.frozenutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.hash]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.httpresourcetree]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.iterutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.linked_list]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.logcontext]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.logformatter]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.macaroons]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.manhole]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.module_loader]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.msisdn]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.ratelimitutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.retryutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.rlimit]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.stringutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.templates]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.threepids]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.wheel_timer]
+disallow_untyped_defs = True
+
[mypy-pymacaroons.*]
ignore_missing_imports = True

View File

@@ -73,4 +73,4 @@ class RedisFactory(protocol.ReconnectingClientFactory):
def buildProtocol(self, addr) -> RedisProtocol: ...
class SubscriberFactory(RedisFactory):
-def __init__(self): ...
+def __init__(self) -> None: ...

View File

@@ -46,7 +46,7 @@ class Ratelimiter:
# * How many times an action has occurred since a point in time
# * The point in time
# * The rate_hz of this particular entry. This can vary per request
-self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict()
+self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
async def can_do_action(
self,
@@ -56,7 +56,7 @@
burst_count: Optional[int] = None,
update: bool = True,
n_actions: int = 1,
-_time_now_s: Optional[int] = None,
+_time_now_s: Optional[float] = None,
) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
@@ -160,7 +160,7 @@
return allowed, time_allowed
-def _prune_message_counts(self, time_now_s: int):
+def _prune_message_counts(self, time_now_s: float):
"""Remove message count entries that have not exceeded their defined
rate_hz limit
@@ -188,7 +188,7 @@
burst_count: Optional[int] = None,
update: bool = True,
n_actions: int = 1,
-_time_now_s: Optional[int] = None,
+_time_now_s: Optional[float] = None,
):
"""Checks if an action can be performed. If not, raises a LimitExceededError

View File

@@ -41,11 +41,11 @@ class ConsentURIBuilder:
"""
if hs_config.form_secret is None:
raise ConfigError("form_secret not set in config")
-if hs_config.public_baseurl is None:
+if hs_config.server.public_baseurl is None:
raise ConfigError("public_baseurl not set in config")
self._hmac_secret = hs_config.form_secret.encode("utf-8")
-self._public_baseurl = hs_config.public_baseurl
+self._public_baseurl = hs_config.server.public_baseurl
def build_user_consent_uri(self, user_id):
"""Build a URI which we can give to the user to do their privacy

View File

@@ -82,7 +82,7 @@ def start_worker_reactor(appname, config, run_command=reactor.run):
run_command (Callable[]): callable that actually runs the reactor
"""
-logger = logging.getLogger(config.worker_app)
+logger = logging.getLogger(config.worker.worker_app)
start_reactor(
appname,
@@ -398,7 +398,7 @@ async def start(hs: "HomeServer"):
# If background tasks are running on the main process, start collecting the
# phone home stats.
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
start_phone_stats_home(hs)
# We now freeze all allocated objects in the hopes that (almost)
@@ -433,9 +433,13 @@ def setup_sentry(hs):
# We set some default tags that give some context to this instance
with sentry_sdk.configure_scope() as scope:
-scope.set_tag("matrix_server_name", hs.config.server_name)
+scope.set_tag("matrix_server_name", hs.config.server.server_name)
-app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver"
+app = (
+hs.config.worker.worker_app
+if hs.config.worker.worker_app
+else "synapse.app.homeserver"
+)
name = hs.get_instance_name()
scope.set_tag("worker_app", app)
scope.set_tag("worker_name", name)

View File

@@ -178,12 +178,12 @@ def start(config_options):
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
-if config.worker_app is not None:
-assert config.worker_app == "synapse.app.admin_cmd"
+if config.worker.worker_app is not None:
+assert config.worker.worker_app == "synapse.app.admin_cmd"
# Update the config with some basic overrides so that don't have to specify
# a full worker config.
-config.worker_app = "synapse.app.admin_cmd"
+config.worker.worker_app = "synapse.app.admin_cmd"
if (
not config.worker_daemonize
@@ -196,7 +196,7 @@ def start(config_options):
# Explicitly disable background processes
config.update_user_directory = False
-config.run_background_tasks = False
+config.worker.run_background_tasks = False
config.start_pushers = False
config.pusher_shard_config.instances = []
config.send_federation = False
@@ -205,7 +205,7 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
ss = AdminCmdServer(
-config.server_name,
+config.server.server_name,
config=config,
version_string="Synapse/" + get_version_string(synapse),
)

View File

@@ -416,7 +416,7 @@ def start(config_options):
sys.exit(1)
# For backwards compatibility let any of the old app names.
-assert config.worker_app in (
+assert config.worker.worker_app in (
"synapse.app.appservice",
"synapse.app.client_reader",
"synapse.app.event_creator",
@@ -430,7 +430,7 @@ def start(config_options):
"synapse.app.user_dir",
)
-if config.worker_app == "synapse.app.appservice":
+if config.worker.worker_app == "synapse.app.appservice":
if config.appservice.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
@@ -446,7 +446,7 @@ def start(config_options):
# For other worker types we force this to off.
config.appservice.notify_appservices = False
-if config.worker_app == "synapse.app.user_dir":
+if config.worker.worker_app == "synapse.app.user_dir":
if config.server.update_user_directory:
sys.stderr.write(
"\nThe update_user_directory must be disabled in the main synapse process"
@@ -469,7 +469,7 @@ def start(config_options):
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
hs = GenericWorkerServer(
-config.server_name,
+config.server.server_name,
config=config,
version_string="Synapse/" + get_version_string(synapse),
)

View File

@@ -350,7 +350,7 @@ def setup(config_options):
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
hs = SynapseHomeServer(
-config.server_name,
+config.server.server_name,
config=config,
version_string="Synapse/" + get_version_string(synapse),
)

View File

@@ -73,7 +73,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
store = hs.get_datastore()
-stats["homeserver"] = hs.config.server_name
+stats["homeserver"] = hs.config.server.server_name
stats["server_context"] = hs.config.server_context
stats["timestamp"] = now
stats["uptime_seconds"] = uptime

View File

@@ -88,7 +88,7 @@ class AuthConfig(Config):
#
#require_lowercase: true
-# Whether a password must contain at least one lowercase letter.
+# Whether a password must contain at least one uppercase letter.
# Defaults to 'false'.
#
#require_uppercase: true

View File

@@ -223,7 +223,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
# writes.
log_context_filter = LoggingContextFilter()
-log_metadata_filter = MetadataFilter({"server_name": config.server_name})
+log_metadata_filter = MetadataFilter({"server_name": config.server.server_name})
old_factory = logging.getLogRecordFactory()
def factory(*args, **kwargs):
@@ -335,5 +335,5 @@ def setup_logging(
# Log immediately so we can grep backwards.
logging.warning("***** STARTING SERVER *****")
logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
-logging.info("Server hostname: %s", config.server_name)
+logging.info("Server hostname: %s", config.server.server_name)
logging.info("Instance name: %s", hs.get_instance_name())

View File

@@ -14,6 +14,8 @@
from typing import Dict, Optional

+import attr
+
from ._base import Config
@@ -29,18 +31,13 @@ class RateLimitConfig:
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))

+@attr.s(auto_attribs=True)
class FederationRateLimitConfig:
-_items_and_default = {
-"window_size": 1000,
-"sleep_limit": 10,
-"sleep_delay": 500,
-"reject_limit": 50,
-"concurrent": 3,
-}
-def __init__(self, **kwargs):
-for i in self._items_and_default.keys():
-setattr(self, i, kwargs.get(i) or self._items_and_default[i])
+window_size: int = 1000
+sleep_limit: int = 10
+sleep_delay: int = 500
+reject_limit: int = 50
+concurrent: int = 3

class RatelimitConfig(Config):
@@ -69,11 +66,15 @@ class RatelimitConfig(Config):
else:
self.rc_federation = FederationRateLimitConfig(
**{
-"window_size": config.get("federation_rc_window_size"),
-"sleep_limit": config.get("federation_rc_sleep_limit"),
-"sleep_delay": config.get("federation_rc_sleep_delay"),
-"reject_limit": config.get("federation_rc_reject_limit"),
-"concurrent": config.get("federation_rc_concurrent"),
+k: v
+for k, v in {
+"window_size": config.get("federation_rc_window_size"),
+"sleep_limit": config.get("federation_rc_sleep_limit"),
+"sleep_delay": config.get("federation_rc_sleep_delay"),
+"reject_limit": config.get("federation_rc_reject_limit"),
+"concurrent": config.get("federation_rc_concurrent"),
+}.items()
+if v is not None
}
)
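A self-contained sketch of the pattern this file moves to (the `raw` dict is a toy stand-in for values read from the homeserver config): the attrs class carries the defaults, and the `if v is not None` filter drops unset options so those defaults apply.

```python
import attr

@attr.s(auto_attribs=True)
class FederationRateLimitConfig:
    window_size: int = 1000
    sleep_limit: int = 10
    sleep_delay: int = 500
    reject_limit: int = 50
    concurrent: int = 3

raw = {"window_size": 500, "sleep_limit": None, "concurrent": None}
cfg = FederationRateLimitConfig(**{k: v for k, v in raw.items() if v is not None})
print(cfg)
# FederationRateLimitConfig(window_size=500, sleep_limit=10, sleep_delay=500,
#                           reject_limit=50, concurrent=3)
```

Compared with the old `kwargs.get(i) or default` loop, filtering on `is not None` also stops an explicitly configured falsy value such as `0` from being silently replaced by the default.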

View File

@@ -88,7 +88,7 @@ class EventValidator:
self._validate_retention(event)
if event.type == EventTypes.ServerACL:
-if not server_matches_acl_event(config.server_name, event):
+if not server_matches_acl_event(config.server.server_name, event):
raise SynapseError(
400, "Can't create an ACL event that denies the local server"
)

View File

@@ -22,6 +22,7 @@ from prometheus_client import Counter
from typing_extensions import Literal
from twisted.internet import defer
+from twisted.internet.interfaces import IDelayedCall
import synapse.metrics
from synapse.api.presence import UserPresenceState
@@ -280,11 +281,14 @@ class FederationSender(AbstractFederationSender):
self._queues_awaiting_rr_flush_by_room: Dict[str, Set[PerDestinationQueue]] = {}
self._rr_txn_interval_per_room_ms = (
-1000.0 / hs.config.federation_rr_transactions_per_room_per_second
+1000.0
+/ hs.config.ratelimiting.federation_rr_transactions_per_room_per_second
)
# wake up destinations that have outstanding PDUs to be caught up
-self._catchup_after_startup_timer = self.clock.call_later(
+self._catchup_after_startup_timer: Optional[
+IDelayedCall
+] = self.clock.call_later(
CATCH_UP_STARTUP_DELAY_SEC,
run_as_background_process,
"wake_destinations_needing_catchup",
@@ -406,7 +410,7 @@ class FederationSender(AbstractFederationSender):
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
+assert ts is not None
synapse.metrics.event_processing_lag_by_event.labels(
"federation_sender"
).observe((now - ts) / 1000)
@@ -435,6 +439,7 @@ class FederationSender(AbstractFederationSender):
if events:
now = self.clock.time_msec()
ts = await self.store.get_received_ts(events[-1].event_id)
+assert ts is not None
synapse.metrics.event_processing_lag.labels(
"federation_sender"

View File

@@ -144,7 +144,7 @@ class GroupAttestionRenewer:
self.is_mine_id = hs.is_mine_id
self.attestations = hs.get_groups_attestation_signing()
-if not hs.config.worker_app:
+if not hs.config.worker.worker_app:
self._renew_attestations_loop = self.clock.looping_call(
self._start_renew_attestations, 30 * 60 * 1000
)

View File

@@ -45,16 +45,16 @@ class BaseHandler:
self.request_ratelimiter = Ratelimiter(
store=self.store, clock=self.clock, rate_hz=0, burst_count=0
)
-self._rc_message = self.hs.config.rc_message
+self._rc_message = self.hs.config.ratelimiting.rc_message
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
-if self.hs.config.rc_admin_redaction:
+if self.hs.config.ratelimiting.rc_admin_redaction:
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store,
clock=self.clock,
-rate_hz=self.hs.config.rc_admin_redaction.per_second,
-burst_count=self.hs.config.rc_admin_redaction.burst_count,
+rate_hz=self.hs.config.ratelimiting.rc_admin_redaction.per_second,
+burst_count=self.hs.config.ratelimiting.rc_admin_redaction.burst_count,
)
else:
self.admin_redaction_ratelimiter = None

View File

@@ -78,7 +78,7 @@ class AccountValidityHandler:
)
# Check the renewal emails to send and send them every 30min.
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
@@ -249,7 +249,7 @@ class AccountValidityHandler:
renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
-self.hs.config.public_baseurl,
+self.hs.config.server.public_baseurl,
renewal_token,
)
@@ -398,6 +398,7 @@ class AccountValidityHandler:
"""
now = self.clock.time_msec()
if expiration_ts is None:
+assert self._account_validity_period is not None
expiration_ts = now + self._account_validity_period
await self.store.set_account_validity_for_user(

View File

@@ -131,6 +131,8 @@ class ApplicationServicesHandler:
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
+assert ts is not None
synapse.metrics.event_processing_lag_by_event.labels(
"appservice_sender"
).observe((now - ts) / 1000)
@@ -166,6 +168,7 @@ class ApplicationServicesHandler:
if events:
now = self.clock.time_msec()
ts = await self.store.get_received_ts(events[-1].event_id)
+assert ts is not None
synapse.metrics.event_processing_lag.labels(
"appservice_sender"

View File

@@ -244,8 +244,8 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=self.clock,
-rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
-burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second,
+burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
)
# The number of seconds to keep a UI auth session active.
@@ -255,14 +255,14 @@ class AuthHandler(BaseHandler):
self._failed_login_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
-rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
-burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second,
+burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
)
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
@@ -289,7 +289,7 @@ class AuthHandler(BaseHandler):
hs.config.sso_account_deactivated_template
)
-self._server_name = hs.config.server_name
+self._server_name = hs.config.server.server_name
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
@@ -749,7 +749,7 @@ class AuthHandler(BaseHandler):
"name": self.hs.config.user_consent_policy_name,
"url": "%s_matrix/consent?v=%s"
% (
-self.hs.config.public_baseurl,
+self.hs.config.server.public_baseurl,
self.hs.config.user_consent_version,
),
},
@@ -1799,7 +1799,7 @@ class MacaroonGenerator:
def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
macaroon = pymacaroons.Macaroon(
-location=self.hs.config.server_name,
+location=self.hs.config.server.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key,
)

View File

@@ -46,7 +46,7 @@ class DeactivateAccountHandler(BaseHandler):
# Start the user parter loop so it can resume parting users from rooms where
# it left off (if it has work left to do).
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
hs.get_reactor().callWhenRunning(self._start_user_parting)
self._account_validity_enabled = (
@@ -131,7 +131,7 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.add_user_pending_deactivation(user_id)
# delete from user directory
-await self.user_directory_handler.handle_user_deactivated(user_id)
+await self.user_directory_handler.handle_local_user_deactivated(user_id)
# Mark the user as erased, if they asked for that
if erase_data:

View File

@@ -84,8 +84,8 @@ class DeviceMessageHandler:
self._ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
-rate_hz=hs.config.rc_key_requests.per_second,
-burst_count=hs.config.rc_key_requests.burst_count,
+rate_hz=hs.config.ratelimiting.rc_key_requests.per_second,
+burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
)
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:

View File

@@ -57,7 +57,7 @@ class E2eKeysHandler:
federation_registry = hs.get_federation_registry()
-self._is_master = hs.config.worker_app is None
+self._is_master = hs.config.worker.worker_app is None
if not self._is_master:
self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)

View File

@@ -101,7 +101,7 @@ class FederationHandler(BaseHandler):
hs
)
-if hs.config.worker_app:
+if hs.config.worker.worker_app:
self._maybe_store_room_on_outlier_membership = (
ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
)
@@ -1614,7 +1614,7 @@ class FederationHandler(BaseHandler):
Args:
room_id
"""
-if self.config.worker_app:
+if self.config.worker.worker_app:
await self._clean_room_for_join_client(room_id)
else:
await self.store.clean_room_for_join(room_id)

View File

@@ -149,7 +149,7 @@ class FederationEventHandler:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
-if hs.config.worker_app:
+if hs.config.worker.worker_app:
self._user_device_resync = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
@@ -1009,7 +1009,7 @@ class FederationEventHandler:
await self._store.mark_remote_user_device_cache_as_stale(sender)
# Immediately attempt a resync in the background
-if self._config.worker_app:
+if self._config.worker.worker_app:
await self._user_device_resync(user_id=sender)
else:
await self._device_list_updater.user_device_resync(sender)

View File

@@ -540,13 +540,13 @@ class IdentityHandler(BaseHandler):
# It is already checked that public_baseurl is configured since this code
# should only be used if account_threepid_delegate_msisdn is true.
-assert self.hs.config.public_baseurl
+assert self.hs.config.server.public_baseurl
# we need to tell the client to send the token back to us, since it doesn't
# otherwise know where to send it, so add submit_url response parameter
# (see also MSC2078)
data["submit_url"] = (
-self.hs.config.public_baseurl
+self.hs.config.server.public_baseurl
+ "_matrix/client/unstable/add_threepid/msisdn/submit_token"
)
return data

View File

@@ -84,7 +84,7 @@ class MessageHandler:
# scheduled.
self._scheduled_expiry: Optional[IDelayedCall] = None
-if not hs.config.worker_app:
+if not hs.config.worker.worker_app:
run_as_background_process(
"_schedule_next_expiry", self._schedule_next_expiry
)
@@ -461,7 +461,7 @@ class EventCreationHandler:
self._dummy_events_threshold = hs.config.dummy_events_threshold
if (
-self.config.run_background_tasks
+self.config.worker.run_background_tasks
and self.config.cleanup_extremities_with_dummy_events
):
self.clock.looping_call(

View File

@@ -324,7 +324,7 @@ class OidcProvider:
self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
-self._server_name: str = hs.config.server_name
+self._server_name: str = hs.config.server.server_name
# identifier for the external_ids table
self.idp_id = provider.idp_id

View File

@@ -91,7 +91,7 @@ class PaginationHandler:
self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
-if hs.config.run_background_tasks and hs.config.retention_enabled:
+if hs.config.worker.run_background_tasks and hs.config.retention_enabled:
# Run the purge jobs described in the configuration file.
for job in hs.config.retention_purge_jobs:
logger.info("Setting up purge job with config: %s", job)

View File

@@ -28,6 +28,7 @@ from bisect import bisect
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
+Any,
Callable,
Collection,
Dict,
@@ -615,7 +616,7 @@ class PresenceHandler(BasePresenceHandler):
super().__init__(hs)
self.hs = hs
self.server_name = hs.hostname
-self.wheel_timer = WheelTimer()
+self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()
self._presence_enabled = hs.config.use_presence
@@ -924,7 +925,7 @@ class PresenceHandler(BasePresenceHandler):
prev_state = await self.current_state_for_user(user_id)
-new_fields = {"last_active_ts": self.clock.time_msec()}
+new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()}
if prev_state.state == PresenceState.UNAVAILABLE:
new_fields["state"] = PresenceState.ONLINE

View File

@@ -63,7 +63,7 @@ class ProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
self.clock.looping_call(
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
)

View File

@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
class ReadMarkerHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
-self.server_name = hs.config.server_name
+self.server_name = hs.config.server.server_name
self.store = hs.get_datastore()
self.account_data_handler = hs.get_account_data_handler()
self.read_marker_linearizer = Linearizer(name="read_marker")

View File

@@ -29,7 +29,7 @@ class ReceiptsHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
-self.server_name = hs.config.server_name
+self.server_name = hs.config.server.server_name
self.store = hs.get_datastore()
self.event_auth_handler = hs.get_event_auth_handler()

View File

@@ -102,7 +102,7 @@ class RegistrationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
-if hs.config.worker_app:
+if hs.config.worker.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
hs
@@ -696,7 +696,7 @@ class RegistrationHandler(BaseHandler):
address: the IP address used to perform the registration.
shadow_banned: Whether to shadow-ban the user
"""
-if self.hs.config.worker_app:
+if self.hs.config.worker.worker_app:
await self._register_client(
user_id=user_id,
password_hash=password_hash,
@@ -786,7 +786,7 @@ class RegistrationHandler(BaseHandler):
Does the bits that need doing on the main process. Not for use outside this
class and RegisterDeviceReplicationServlet.
"""
-assert not self.hs.config.worker_app
+assert not self.hs.config.worker.worker_app
valid_until_ms = None
if self.session_lifetime is not None:
if is_guest:
@@ -843,7 +843,7 @@ class RegistrationHandler(BaseHandler):
"""
# TODO: 3pid registration can actually happen on the workers. Consider
# refactoring it.
-if self.hs.config.worker_app:
+if self.hs.config.worker.worker_app:
await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token
)

View File

@@ -33,6 +33,7 @@ from synapse.api.constants import (
Membership,
RoomCreationPreset,
RoomEncryptionAlgorithms,
+RoomTypes,
)
from synapse.api.errors import (
AuthError,
@@ -397,7 +398,7 @@ class RoomCreationHandler(BaseHandler):
initial_state = {}
# Replicate relevant room events
-types_to_copy = (
+types_to_copy: List[Tuple[str, Optional[str]]] = [
(EventTypes.JoinRules, ""),
(EventTypes.Name, ""),
(EventTypes.Topic, ""),
@@ -408,7 +409,16 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.ServerACL, ""),
(EventTypes.RelatedGroups, ""),
(EventTypes.PowerLevels, ""),
-)
+]
+
+# If the old room was a space, copy over the room type and the rooms in
+# the space.
+if (
+old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
+== RoomTypes.SPACE
+):
+creation_content[EventContentFields.ROOM_TYPE] = RoomTypes.SPACE
+types_to_copy.append((EventTypes.SpaceChild, None))
old_room_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy)
@@ -419,6 +429,11 @@ class RoomCreationHandler(BaseHandler):
for k, old_event_id in old_room_state_ids.items():
old_event = old_room_state_events.get(old_event_id)
if old_event:
+# If the event is an space child event with empty content, it was
+# removed from the space and should be ignored.
+if k[0] == EventTypes.SpaceChild and not old_event.content:
+continue
+
initial_state[k] = old_event.content
# deep-copy the power-levels event before we start modifying it
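A toy sketch of the space-upgrade rule added above, using invented event dictionaries rather than Synapse's real event objects: the room type is carried over, and `m.space.child` state is copied except for children whose content has been emptied (rooms removed from the space).

```python
old_create_content = {"type": "m.space"}
old_state = {
    ("m.space.child", "!roomA:example.org"): {"via": ["example.org"]},
    ("m.space.child", "!roomB:example.org"): {},   # removed from the space
    ("m.room.name", ""): {"name": "My space"},
}

creation_content = {}
types_to_copy = [("m.room.name", "")]
if old_create_content.get("type") == "m.space":
    creation_content["type"] = "m.space"
    types_to_copy.append(("m.space.child", None))  # None means: any state_key

def should_copy(key, content):
    etype, state_key = key
    wanted = any(etype == t and sk in (None, state_key) for t, sk in types_to_copy)
    # Skip space-child events whose content was blanked out.
    return wanted and not (etype == "m.space.child" and not content)

initial_state = {k: v for k, v in old_state.items() if should_copy(k, v)}
print(initial_state)
# keeps the space name and roomA's child event; drops roomB's emptied child
```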

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import logging
+from enum import Enum, auto
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
@@ -21,6 +22,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class MatchChange(Enum):
+no_change = auto()
+now_true = auto()
+now_false = auto()
+
class StateDeltasHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -31,18 +38,12 @@ class StateDeltasHandler:
event_id: Optional[str],
key_name: str,
public_value: str,
-) -> Optional[bool]:
+) -> MatchChange:
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.
For example, check if `history_visibility` (`key_name`) changed from
`shared` to `world_readable` (`public_value`).
-Returns:
-None if the field in the events either both match `public_value`
-or if neither do, i.e. there has been no change.
-True if it didn't match `public_value` but now does
-False if it did match `public_value` but now doesn't
"""
prev_event = None
event = None
@@ -54,7 +55,7 @@
if not event and not prev_event:
logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
-return None
+return MatchChange.no_change
prev_value = None
value = None
@@ -68,8 +69,8 @@
logger.debug("prev_value: %r -> value: %r", prev_value, value)
if value == public_value and prev_value != public_value:
-return True
+return MatchChange.now_true
elif value != public_value and prev_value == public_value:
-return False
+return MatchChange.now_false
else:
-return None
+return MatchChange.no_change
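A small sketch of why the return type moves from `Optional[bool]` to an enum (the `key_change` function is a toy, not the real handler): with `Optional[bool]`, a bare `if change:` silently lumps "no change" together with "became false", whereas the enum forces callers to name the case they mean.

```python
from enum import Enum, auto

class MatchChange(Enum):
    no_change = auto()
    now_true = auto()
    now_false = auto()

def key_change(prev: str, new: str, public_value: str) -> MatchChange:
    if new == public_value and prev != public_value:
        return MatchChange.now_true
    if new != public_value and prev == public_value:
        return MatchChange.now_false
    return MatchChange.no_change

change = key_change("shared", "world_readable", public_value="world_readable")
if change is MatchChange.now_true:   # explicit, unlike a bare `if change:`
    print("room became world-readable")
```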

View File

@@ -54,7 +54,7 @@ class StatsHandler:
# Guard to ensure we only process deltas one at a time
self._is_processing = False
-if self.stats_enabled and hs.config.run_background_tasks:
+if self.stats_enabled and hs.config.worker.run_background_tasks:
self.notifier.add_replication_callback(self.notify_new_event)
# We kick this off so that we don't have to wait for a change before

View File

@@ -53,7 +53,7 @@ class FollowerTypingHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
-self.server_name = hs.config.server_name
+self.server_name = hs.config.server.server_name
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
@@ -73,7 +73,7 @@ class FollowerTypingHandler:
self._room_typing: Dict[str, Set[str]] = {}
self._member_last_federation_poke: Dict[RoomMember, int] = {}
-self.wheel_timer = WheelTimer(bucket_size=5000)
+self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
self.clock.looping_call(self._handle_timeouts, 5000)

View File

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
import synapse.metrics
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
-from synapse.handlers.state_deltas import StateDeltasHandler
+from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict
@@ -30,14 +30,26 @@ logger = logging.getLogger(__name__)
class UserDirectoryHandler(StateDeltasHandler):
-"""Handles querying of and keeping updated the user_directory.
+"""Handles queries and updates for the user_directory.
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
-The user directory is filled with users who this server can see are joined to a
-world_readable or publicly joinable room. We keep a database table up to date
-by streaming changes of the current state and recalculating whether users should
-be in the directory or not when necessary.
+When a local user searches the user_directory, we report two kinds of users:
+
+- users this server can see are joined to a world_readable or publicly
+joinable room, and
+- users belonging to a private room shared by that local user.
+
+The two cases are tracked separately in the `users_in_public_rooms` and
+`users_who_share_private_rooms` tables. Both kinds of users have their
+username and avatar tracked in a `user_directory` table.
+
+This handler has three responsibilities:
+1. Forwarding requests to `/user_directory/search` to the UserDirectoryStore.
+2. Providing hooks for the application to call when local users are added,
+removed, or have their profile changed.
+3. Listening for room state changes that indicate remote users have
+joined or left a room, or that their profile has changed.
"""
def __init__(self, hs: "HomeServer"):
@@ -130,7 +142,7 @@ class UserDirectoryHandler(StateDeltasHandler):
user_id, profile.display_name, profile.avatar_url
)
-async def handle_user_deactivated(self, user_id: str) -> None:
+async def handle_local_user_deactivated(self, user_id: str) -> None:
"""Called when a user ID is deactivated"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
@@ -196,7 +208,7 @@ class UserDirectoryHandler(StateDeltasHandler):
public_value=Membership.JOIN,
)
-if change is False:
+if change is MatchChange.now_false:
# Need to check if the server left the room entirely, if so
# we might need to remove all the users in that room
is_in_room = await self.store.is_host_joined(
@@ -219,14 +231,14 @@ class UserDirectoryHandler(StateDeltasHandler):
is_support = await self.store.is_support_user(state_key)
if not is_support:
-if change is None:
+if change is MatchChange.no_change:
# Handle any profile changes
await self._handle_profile_change(
state_key, room_id, prev_event_id, event_id
)
continue
-if change: # The user joined
+if change is MatchChange.now_true: # The user joined
event = await self.store.get_event(event_id, allow_none=True)
# It isn't expected for this event to not exist, but we
# don't want the entire background process to break.
@@ -263,14 +275,14 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug("Handling change for %s: %s", typ, room_id)
if typ == EventTypes.RoomHistoryVisibility:
-change = await self._get_key_change(
+publicness = await self._get_key_change(
prev_event_id,
event_id,
key_name="history_visibility",
public_value=HistoryVisibility.WORLD_READABLE,
)
elif typ == EventTypes.JoinRules:
-change = await self._get_key_change(
+publicness = await self._get_key_change(
prev_event_id,
event_id,
key_name="join_rule",
@@ -278,9 +290,7 @@
)
else:
raise Exception("Invalid event type")
-# If change is None, no change. True => become world_readable/public,
-# False => was world_readable/public
-if change is None:
+if publicness is MatchChange.no_change:
logger.debug("No change")
return
@@ -290,13 +300,13 @@
room_id
)
-logger.debug("Change: %r, is_public: %r", change, is_public)
+logger.debug("Change: %r, publicness: %r", publicness, is_public)
-if change and not is_public:
+if publicness is MatchChange.now_true and not is_public:
# If we became world readable but room isn't currently public then
# we ignore the change
return
-elif not change and is_public:
+elif publicness is MatchChange.now_false and is_public:
# If we stopped being world readable but are still public,
# ignore the change
return

View File

@@ -236,8 +236,17 @@ except ImportError:
try:
from rust_python_jaeger_reporter import Reporter
+# jaeger-client 4.7.0 requires that reporters inherit from BaseReporter, which
+# didn't exist before that version.
+try:
+from jaeger_client.reporter import BaseReporter
+except ImportError:
+class BaseReporter: # type: ignore[no-redef]
+pass
+
@attr.s(slots=True, frozen=True)
-class _WrappedRustReporter:
+class _WrappedRustReporter(BaseReporter):
"""Wrap the reporter to ensure `report_span` never throws."""
_reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
@@ -374,7 +383,7 @@ def init_tracer(hs: "HomeServer"):
config = JaegerConfig(
config=hs.config.jaeger_config,
-service_name=f"{hs.config.server_name} {hs.get_instance_name()}",
+service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}",
scope_manager=LogContextScopeManager(hs.config),
metrics_factory=PrometheusMetricsFactory(),
)
@@ -382,6 +391,7 @@ def init_tracer(hs: "HomeServer"):
# If we have the rust jaeger reporter available let's use that.
if RustReporter:
logger.info("Using rust_python_jaeger_reporter library")
+assert config.sampler is not None
tracer = config.create_tracer(RustReporter(), config.sampler)
opentracing.set_global_tracer(tracer)
else:

View File

@@ -178,7 +178,7 @@ class ModuleApi:
@property
def public_baseurl(self) -> str:
"""The configured public base URL for this homeserver."""
-return self._hs.config.public_baseurl
+return self._hs.config.server.public_baseurl
@property
def email_app_name(self) -> str:
@@ -640,7 +640,7 @@ class ModuleApi:
if desc is None:
desc = f.__name__
-if self._hs.config.run_background_tasks or run_on_all_instances:
+if self._hs.config.worker.run_background_tasks or run_on_all_instances:
self._clock.looping_call(
run_as_background_process,
msec,

View File

@@ -130,7 +130,7 @@ class Mailer:
 """
 params = {"token": token, "client_secret": client_secret, "sid": sid}
 link = (
-self.hs.config.public_baseurl
+self.hs.config.server.public_baseurl
 + "_synapse/client/password_reset/email/submit_token?%s"
 % urllib.parse.urlencode(params)
 )
@@ -140,7 +140,7 @@ class Mailer:
 await self.send_email(
 email_address,
 self.email_subjects.password_reset
-% {"server_name": self.hs.config.server_name},
+% {"server_name": self.hs.config.server.server_name},
 template_vars,
 )
@@ -160,7 +160,7 @@ class Mailer:
 """
 params = {"token": token, "client_secret": client_secret, "sid": sid}
 link = (
-self.hs.config.public_baseurl
+self.hs.config.server.public_baseurl
 + "_matrix/client/unstable/registration/email/submit_token?%s"
 % urllib.parse.urlencode(params)
 )
@@ -170,7 +170,7 @@ class Mailer:
 await self.send_email(
 email_address,
 self.email_subjects.email_validation
-% {"server_name": self.hs.config.server_name},
+% {"server_name": self.hs.config.server.server_name},
 template_vars,
 )
@@ -191,7 +191,7 @@ class Mailer:
 """
 params = {"token": token, "client_secret": client_secret, "sid": sid}
 link = (
-self.hs.config.public_baseurl
+self.hs.config.server.public_baseurl
 + "_matrix/client/unstable/add_threepid/email/submit_token?%s"
 % urllib.parse.urlencode(params)
 )
@@ -201,7 +201,7 @@ class Mailer:
 await self.send_email(
 email_address,
 self.email_subjects.email_validation
-% {"server_name": self.hs.config.server_name},
+% {"server_name": self.hs.config.server.server_name},
 template_vars,
 )
@@ -852,7 +852,7 @@ class Mailer:
 # XXX: make r0 once API is stable
 return "%s_matrix/client/unstable/pushers/remove?%s" % (
-self.hs.config.public_baseurl,
+self.hs.config.server.public_baseurl,
 urllib.parse.urlencode(params),
 )

@@ -73,7 +73,7 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
 ):
 self.client_name = client_name
 self.command_handler = command_handler
-self.server_name = hs.config.server_name
+self.server_name = hs.config.server.server_name
 self.hs = hs
 self._clock = hs.get_clock() # As self.clock is defined in super class

@@ -168,7 +168,7 @@ class ReplicationCommandHandler:
 continue
 # Only add any other streams if we're on master.
-if hs.config.worker_app is not None:
+if hs.config.worker.worker_app is not None:
 continue
 if stream.NAME == FederationStream.NAME and hs.config.send_federation:
@@ -222,7 +222,7 @@ class ReplicationCommandHandler:
 },
 )
-self._is_master = hs.config.worker_app is None
+self._is_master = hs.config.worker.worker_app is None
 self._federation_sender = None
 if self._is_master and not hs.config.send_federation:

@@ -40,7 +40,7 @@ class ReplicationStreamProtocolFactory(Factory):
 def __init__(self, hs):
 self.command_handler = hs.get_tcp_replication()
 self.clock = hs.get_clock()
-self.server_name = hs.config.server_name
+self.server_name = hs.config.server.server_name
 # If we've created a `ReplicationStreamProtocolFactory` then we're
 # almost certainly registering a replication listener, so let's ensure

@@ -42,7 +42,7 @@ class FederationStream(Stream):
 ROW_TYPE = FederationStreamRow
 def __init__(self, hs: "HomeServer"):
-if hs.config.worker_app is None:
+if hs.config.worker.worker_app is None:
 # master process: get updates from the FederationRemoteSendQueue.
 # (if the master is configured to send federation itself, federation_sender
 # will be a real FederationSender, which has stubs for current_token and

@@ -247,7 +247,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
 RegistrationTokenRestServlet(hs).register(http_server)
 # Some servlets only get registered for the main process.
-if hs.config.worker_app is None:
+if hs.config.worker.worker_app is None:
 SendServerNoticeServlet(hs).register(http_server)

@@ -68,7 +68,10 @@ class AuthRestServlet(RestServlet):
 html = self.terms_template.render(
 session=session,
 terms_url="%s_matrix/consent?v=%s"
-% (self.hs.config.public_baseurl, self.hs.config.user_consent_version),
+% (
+self.hs.config.server.public_baseurl,
+self.hs.config.user_consent_version,
+),
 myurl="%s/r0/auth/%s/fallback/web"
 % (CLIENT_API_PREFIX, LoginType.TERMS),
 )
@@ -135,7 +138,7 @@ class AuthRestServlet(RestServlet):
 session=session,
 terms_url="%s_matrix/consent?v=%s"
 % (
-self.hs.config.public_baseurl,
+self.hs.config.server.public_baseurl,
 self.hs.config.user_consent_version,
 ),
 myurl="%s/r0/auth/%s/fallback/web"

@@ -93,14 +93,14 @@ class LoginRestServlet(RestServlet):
 self._address_ratelimiter = Ratelimiter(
 store=hs.get_datastore(),
 clock=hs.get_clock(),
-rate_hz=self.hs.config.rc_login_address.per_second,
+rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second,
-burst_count=self.hs.config.rc_login_address.burst_count,
+burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count,
 )
 self._account_ratelimiter = Ratelimiter(
 store=hs.get_datastore(),
 clock=hs.get_clock(),
-rate_hz=self.hs.config.rc_login_account.per_second,
+rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second,
-burst_count=self.hs.config.rc_login_account.burst_count,
+burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count,
 )
 # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
@@ -486,7 +486,7 @@ class SsoRedirectServlet(RestServlet):
 # register themselves with the main SSOHandler.
 _load_sso_handlers(hs)
 self._sso_handler = hs.get_sso_handler()
-self._public_baseurl = hs.config.public_baseurl
+self._public_baseurl = hs.config.server.public_baseurl
 async def on_GET(
 self, request: SynapseRequest, idp_id: Optional[str] = None

@@ -69,7 +69,7 @@ class IdTokenServlet(RestServlet):
 self.auth = hs.get_auth()
 self.store = hs.get_datastore()
 self.clock = hs.get_clock()
-self.server_name = hs.config.server_name
+self.server_name = hs.config.server.server_name
 async def on_POST(
 self, request: SynapseRequest, user_id: str

@@ -59,7 +59,7 @@ class PushRuleRestServlet(RestServlet):
 self.auth = hs.get_auth()
 self.store = hs.get_datastore()
 self.notifier = hs.get_notifier()
-self._is_worker = hs.config.worker_app is not None
+self._is_worker = hs.config.worker.worker_app is not None
 self._users_new_default_push_rules = hs.config.users_new_default_push_rules

@@ -330,11 +330,11 @@ class UsernameAvailabilityRestServlet(RestServlet):
 # Artificially delay requests if rate > sleep_limit/window_size
 sleep_limit=1,
 # Amount of artificial delay to apply
-sleep_msec=1000,
+sleep_delay=1000,
 # Error with 429 if more than reject_limit requests are queued
 reject_limit=1,
 # Allow 1 request at a time
-concurrent_requests=1,
+concurrent=1,
 ),
 )
@@ -763,7 +763,10 @@ class RegisterRestServlet(RestServlet):
 Returns:
 dictionary for response from /register
 """
-result = {"user_id": user_id, "home_server": self.hs.hostname}
+result: JsonDict = {
+"user_id": user_id,
+"home_server": self.hs.hostname,
+}
 if not params.get("inhibit_login", False):
 device_id = params.get("device_id")
 initial_display_name = params.get("initial_device_display_name")
@@ -814,7 +817,7 @@ class RegisterRestServlet(RestServlet):
 user_id, device_id, initial_display_name, is_guest=True
 )
-result = {
+result: JsonDict = {
 "user_id": user_id,
 "device_id": device_id,
 "access_token": access_token,

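The explicit `result: JsonDict` annotation above matters because mypy would otherwise infer a narrower value type from the initial literal (e.g. Dict[str, str]) and reject later assignments of non-string values. A hedged standalone sketch; `build_response` and the extra key are illustrative, while `JsonDict = Dict[str, Any]` matches the alias shown in the synapse.types hunk further down:

from typing import Any, Dict

JsonDict = Dict[str, Any]

def build_response(user_id: str, hostname: str) -> JsonDict:
    result: JsonDict = {"user_id": user_id, "home_server": hostname}
    # Thanks to the annotation, adding a non-string value still type-checks.
    result["inhibit_login"] = False  # illustrative key
    return result
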
@@ -388,7 +388,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 limit = None
 handler = self.hs.get_room_list_handler()
-if server and server != self.hs.config.server_name:
+if server and server != self.hs.config.server.server_name:
 # Ensure the server is valid.
 try:
 parse_and_validate_server_name(server)
@@ -438,7 +438,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 limit = None
 handler = self.hs.get_room_list_handler()
-if server and server != self.hs.config.server_name:
+if server and server != self.hs.config.server.server_name:
 # Ensure the server is valid.
 try:
 parse_and_validate_server_name(server)

@@ -86,12 +86,12 @@ class LocalKey(Resource):
 json_object = {
 "valid_until_ts": self.valid_until_ts,
-"server_name": self.config.server_name,
+"server_name": self.config.server.server_name,
 "verify_keys": verify_keys,
 "old_verify_keys": old_verify_keys,
 }
 for key in self.config.signing_key:
-json_object = sign_json(json_object, self.config.server_name, key)
+json_object = sign_json(json_object, self.config.server.server_name, key)
 return json_object
 def render_GET(self, request):

@@ -224,7 +224,9 @@ class RemoteKey(DirectServeJsonResource):
 for key_json in json_results:
 key_json = json_decoder.decode(key_json.decode("utf-8"))
 for signing_key in self.config.key_server_signing_keys:
-key_json = sign_json(key_json, self.config.server_name, signing_key)
+key_json = sign_json(
+key_json, self.config.server.server_name, signing_key
+)
 signed_keys.append(key_json)

@@ -52,7 +52,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
 yield hs.config.sso.sso_template_dir
 yield hs.config.sso.default_template_dir
-self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
 async def _async_render_GET(self, request: Request) -> None:
 try:

@@ -80,7 +80,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
 yield hs.config.sso.sso_template_dir
 yield hs.config.sso.default_template_dir
-self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
 async def _async_render_GET(self, request: Request) -> None:
 try:

@@ -34,10 +34,10 @@ class WellKnownBuilder:
 def get_well_known(self):
 # if we don't have a public_baseurl, we can't help much here.
-if self._config.public_baseurl is None:
+if self._config.server.public_baseurl is None:
 return None
-result = {"m.homeserver": {"base_url": self._config.public_baseurl}}
+result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}}
 if self._config.default_identity_server:
 result["m.identity_server"] = {

@@ -313,7 +313,7 @@ class HomeServer(metaclass=abc.ABCMeta):
 # Register background tasks required by this server. This must be done
 # somewhat manually due to the background tasks not being registered
 # unless handlers are instantiated.
-if self.config.run_background_tasks:
+if self.config.worker.run_background_tasks:
 self.setup_background_tasks()
 def start_listening(self) -> None:
@@ -370,8 +370,8 @@ class HomeServer(metaclass=abc.ABCMeta):
 return Ratelimiter(
 store=self.get_datastore(),
 clock=self.get_clock(),
-rate_hz=self.config.rc_registration.per_second,
+rate_hz=self.config.ratelimiting.rc_registration.per_second,
-burst_count=self.config.rc_registration.burst_count,
+burst_count=self.config.ratelimiting.rc_registration.burst_count,
 )
 @cache_in_self
@@ -498,7 +498,7 @@ class HomeServer(metaclass=abc.ABCMeta):
 @cache_in_self
 def get_device_handler(self):
-if self.config.worker_app:
+if self.config.worker.worker_app:
 return DeviceWorkerHandler(self)
 else:
 return DeviceHandler(self)
@@ -621,7 +621,7 @@ class HomeServer(metaclass=abc.ABCMeta):
 def get_federation_sender(self) -> AbstractFederationSender:
 if self.should_send_federation():
 return FederationSender(self)
-elif not self.config.worker_app:
+elif not self.config.worker.worker_app:
 return FederationRemoteSendQueue(self)
 else:
 raise Exception("Workers cannot send federation traffic")
@@ -650,14 +650,14 @@ class HomeServer(metaclass=abc.ABCMeta):
 def get_groups_local_handler(
 self,
 ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
-if self.config.worker_app:
+if self.config.worker.worker_app:
 return GroupsLocalWorkerHandler(self)
 else:
 return GroupsLocalHandler(self)
 @cache_in_self
 def get_groups_server_handler(self):
-if self.config.worker_app:
+if self.config.worker.worker_app:
 return GroupsServerWorkerHandler(self)
 else:
 return GroupsServerHandler(self)
@@ -684,7 +684,7 @@ class HomeServer(metaclass=abc.ABCMeta):
 @cache_in_self
 def get_room_member_handler(self) -> RoomMemberHandler:
-if self.config.worker_app:
+if self.config.worker.worker_app:
 return RoomMemberWorkerHandler(self)
 return RoomMemberMasterHandler(self)
@@ -694,13 +694,13 @@ class HomeServer(metaclass=abc.ABCMeta):
 @cache_in_self
 def get_server_notices_manager(self) -> ServerNoticesManager:
-if self.config.worker_app:
+if self.config.worker.worker_app:
 raise Exception("Workers cannot send server notices")
 return ServerNoticesManager(self)
 @cache_in_self
 def get_server_notices_sender(self) -> WorkerServerNoticesSender:
-if self.config.worker_app:
+if self.config.worker.worker_app:
 return WorkerServerNoticesSender(self)
 return ServerNoticesSender(self)
@@ -766,7 +766,9 @@ class HomeServer(metaclass=abc.ABCMeta):
 @cache_in_self
 def get_federation_ratelimiter(self) -> FederationRateLimiter:
-return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation)
+return FederationRateLimiter(
+self.get_clock(), config=self.config.ratelimiting.rc_federation
+)
 @cache_in_self
 def get_module_api(self) -> ModuleApi:

@@ -271,7 +271,7 @@ class DataStore(
 def get_users_paginate_txn(txn):
 filters = []
-args = [self.hs.config.server_name]
+args = [self.hs.config.server.server_name]
 # Set ordering
 order_by_column = UserSortOrder(order_by).value
@@ -356,13 +356,13 @@ def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig
 return
 user_domain = get_domain_from_id(rows[0][0])
-if user_domain == config.server_name:
+if user_domain == config.server.server_name:
 return
 raise Exception(
 "Found users in database not native to %s!\n"
 "You cannot change a synapse server_name after it's been configured"
-% (config.server_name,)
+% (config.server.server_name,)
 )

@@ -35,7 +35,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
 super().__init__(database, db_conn, hs)
 if (
-hs.config.run_background_tasks
+hs.config.worker.run_background_tasks
 and self.hs.config.redaction_retention_period is not None
 ):
 hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)

@@ -355,7 +355,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 self.user_ips_max_age = hs.config.user_ips_max_age
-if hs.config.run_background_tasks and self.user_ips_max_age:
+if hs.config.worker.run_background_tasks and self.user_ips_max_age:
 self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
 @wrap_as_background_process("prune_old_user_ips")

@@ -51,7 +51,7 @@ class DeviceWorkerStore(SQLBaseStore):
 def __init__(self, database: DatabasePool, db_conn, hs):
 super().__init__(database, db_conn, hs)
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
 self._clock.looping_call(
 self._prune_old_outbound_device_pokes, 60 * 60 * 1000
 )

@@ -62,7 +62,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 def __init__(self, database: DatabasePool, db_conn, hs):
 super().__init__(database, db_conn, hs)
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
 hs.get_clock().looping_call(
 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
 )

@@ -82,7 +82,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 self._rotate_delay = 3
 self._rotate_count = 10000
 self._doing_notif_rotation = False
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
 self._rotate_notif_loop = self._clock.looping_call(
 self._rotate_notifs, 30 * 60 * 1000
 )

@@ -158,7 +158,7 @@ class EventsWorkerStore(SQLBaseStore):
 db_conn, "events", "stream_ordering", step=-1
 )
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
 # We periodically clean out old transaction ID mappings
 self._clock.looping_call(
 self._cleanup_old_transaction_ids,

@@ -56,7 +56,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 super().__init__(database, db_conn, hs)
 # Read the extrems every 60 minutes
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
 self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000)
 # Used in _generate_user_daily_visits to keep track of progress

@@ -132,14 +132,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 hs.config.account_validity.account_validity_startup_job_max_delta
 )
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
 self._clock.call_later(
 0.0,
 self._set_expiration_date_when_missing,
 )
 # Create a background job for culling expired 3PID validity tokens
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
 self._clock.looping_call(
 self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
 )
@@ -1091,6 +1091,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 delta equal to 10% of the validity period.
 """
 now_ms = self._clock.time_msec()
+assert self._account_validity_period is not None
 expiration_ts = now_ms + self._account_validity_period
 if use_delta:

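The added assert above is the standard way to narrow an Optional value for mypy before doing arithmetic on it. A hedged, self-contained sketch of the same pattern; the helper name and values are hypothetical:

from typing import Optional

def expiration_ts(now_ms: int, validity_period_ms: Optional[int]) -> int:
    # The assert narrows Optional[int] to int, so the addition type-checks
    # (and fails loudly at runtime if the config value is unexpectedly unset).
    assert validity_period_ms is not None
    return now_ms + validity_period_ms

print(expiration_ts(1_000, 86_400_000))
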
@@ -815,7 +815,7 @@ class RoomWorkerStore(SQLBaseStore):
 If it is `None` media will be removed from quarantine
 """
 logger.info("Quarantining media: %s/%s", server_name, media_id)
-is_local = server_name == self.config.server_name
+is_local = server_name == self.config.server.server_name
 def _quarantine_media_by_id_txn(txn):
 local_mxcs = [media_id] if is_local else []

@@ -81,7 +81,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 txn.close()
 if (
-self.hs.config.run_background_tasks
+self.hs.config.worker.run_background_tasks
 and self.hs.config.metrics_flags.known_servers
 ):
 self._known_servers_count = 1
@@ -196,6 +196,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 ) -> Dict[str, ProfileInfo]:
 """Get a mapping from user ID to profile information for all users in a given room.
+The profile information comes directly from this room's `m.room.member`
+events, and so may be specific to this room rather than part of a user's
+global profile. To avoid privacy leaks, the profile data should only be
+revealed to users who are already in this room.
 Args:
 room_id: The ID of the room to retrieve the users of.

@@ -48,7 +48,7 @@ class SessionStore(SQLBaseStore):
 super().__init__(database, db_conn, hs)
 # Create a background job for culling expired sessions.
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
 self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
 async def create_session(

@@ -672,7 +672,7 @@ class StatsStore(StateDeltasStore):
 def get_users_media_usage_paginate_txn(txn):
 filters = []
-args = [self.hs.config.server_name]
+args = [self.hs.config.server.server_name]
 if search_term:
 filters.append("(lmr.user_id LIKE ? OR displayname LIKE ?)")

@@ -60,7 +60,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
 def __init__(self, database: DatabasePool, db_conn, hs):
 super().__init__(database, db_conn, hs)
-if hs.config.run_background_tasks:
+if hs.config.worker.run_background_tasks:
 self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
 @wrap_as_background_process("cleanup_transactions")

@@ -196,7 +196,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 )
 users_with_profile = await self.get_users_in_room_with_profiles(room_id)
-user_ids = set(users_with_profile)
 # Update each user in the user directory.
 for user_id, profile in users_with_profile.items():
@@ -207,7 +206,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 to_insert = set()
 if is_public:
-for user_id in user_ids:
+for user_id in users_with_profile:
 if self.get_if_app_services_interested_in_user(user_id):
 continue
@@ -217,14 +216,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 await self.add_users_in_public_rooms(room_id, to_insert)
 to_insert.clear()
 else:
-for user_id in user_ids:
+for user_id in users_with_profile:
 if not self.hs.is_mine_id(user_id):
 continue
 if self.get_if_app_services_interested_in_user(user_id):
 continue
-for other_user_id in user_ids:
+for other_user_id in users_with_profile:
 if user_id == other_user_id:
 continue
@@ -511,7 +510,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
 self._prefer_local_users_in_search = (
 hs.config.user_directory_search_prefer_local_users
 )
-self._server_name = hs.config.server_name
+self._server_name = hs.config.server.server_name
 async def remove_from_user_dir(self, user_id: str) -> None:
 def _remove_from_user_dir_txn(txn):

@@ -134,7 +134,7 @@ def prepare_database(
 # if it's a worker app, refuse to upgrade the database, to avoid multiple
 # workers doing it at once.
 if (
-config.worker_app is not None
+config.worker.worker_app is not None
 and version_info.current_version != SCHEMA_VERSION
 ):
 raise UpgradeDatabaseException(
@@ -154,7 +154,7 @@ def prepare_database(
 # if it's a worker app, refuse to upgrade the database, to avoid multiple
 # workers doing it at once.
-if config and config.worker_app is not None:
+if config and config.worker.worker_app is not None:
 raise UpgradeDatabaseException(EMPTY_DATABASE_ON_WORKER_ERROR)
 _setup_new_database(cur, database_engine, databases=databases)
@@ -355,7 +355,7 @@ def _upgrade_existing_database(
 else:
 assert config
-is_worker = config and config.worker_app is not None
+is_worker = config and config.worker.worker_app is not None
 if (
 current_schema_state.compat_version is not None

@@ -38,7 +38,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
 logger.warning("Could not get app_service_config_files from config")
 pass
-appservices = load_appservices(config.server_name, config_files)
+appservices = load_appservices(config.server.server_name, config_files)
 owned = {}

@@ -67,7 +67,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
 INNER JOIN room_memberships AS r USING (event_id)
 WHERE type = 'm.room.member' AND state_key LIKE ?
 """
-cur.execute(sql, ("%:" + config.server_name,))
+cur.execute(sql, ("%:" + config.server.server_name,))
 cur.execute(
 "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"

@@ -38,6 +38,7 @@ from twisted.internet.interfaces import (
 IReactorCore,
 IReactorPluggableNameResolver,
 IReactorTCP,
+IReactorThreads,
 IReactorTime,
 )
@@ -63,7 +64,12 @@ JsonDict = Dict[str, Any]
 # Note that this seems to require inheriting *directly* from Interface in order
 # for mypy-zope to realize it is an interface.
 class ISynapseReactor(
-IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
+IReactorTCP,
+IReactorPluggableNameResolver,
+IReactorTime,
+IReactorCore,
+IReactorThreads,
+Interface,
 ):
 """The interfaces necessary for Synapse to function."""

@@ -15,27 +15,35 @@
 import json
 import logging
 import re
-from typing import Pattern
+import typing
+from typing import Any, Callable, Dict, Generator, Pattern
 import attr
 from frozendict import frozendict
 from twisted.internet import defer, task
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IDelayedCall, IReactorTime
+from twisted.internet.task import LoopingCall
+from twisted.python.failure import Failure
 from synapse.logging import context
+if typing.TYPE_CHECKING:
+pass
 logger = logging.getLogger(__name__)
 _WILDCARD_RUN = re.compile(r"([\?\*]+)")
-def _reject_invalid_json(val):
+def _reject_invalid_json(val: Any) -> None:
 """Do not allow Infinity, -Infinity, or NaN values in JSON."""
 raise ValueError("Invalid JSON value: '%s'" % val)
-def _handle_frozendict(obj):
+def _handle_frozendict(obj: Any) -> Dict[Any, Any]:
 """Helper for json_encoder. Makes frozendicts serializable by returning
 the underlying dict
 """
@@ -60,10 +68,10 @@ json_encoder = json.JSONEncoder(
 json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
-def unwrapFirstError(failure):
+def unwrapFirstError(failure: Failure) -> Failure:
 # defer.gatherResults and DeferredLists wrap failures.
 failure.trap(defer.FirstError)
-return failure.value.subFailure
+return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
 @attr.s(slots=True)
@@ -75,25 +83,25 @@ class Clock:
 reactor: The Twisted reactor to use.
 """
-_reactor = attr.ib()
+_reactor: IReactorTime = attr.ib()
-@defer.inlineCallbacks
+@defer.inlineCallbacks # type: ignore[arg-type] # Issue in Twisted's type annotations
-def sleep(self, seconds):
+def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]":
-d = defer.Deferred()
+d: defer.Deferred[float] = defer.Deferred()
 with context.PreserveLoggingContext():
 self._reactor.callLater(seconds, d.callback, seconds)
 res = yield d
 return res
-def time(self):
+def time(self) -> float:
 """Returns the current system time in seconds since epoch."""
 return self._reactor.seconds()
-def time_msec(self):
+def time_msec(self) -> int:
 """Returns the current system time in milliseconds since epoch."""
 return int(self.time() * 1000)
-def looping_call(self, f, msec, *args, **kwargs):
+def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall:
 """Call a function repeatedly.
 Waits `msec` initially before calling `f` for the first time.
@@ -102,8 +110,8 @@ class Clock:
 other than trivial, you probably want to wrap it in run_as_background_process.
 Args:
-f(function): The function to call repeatedly.
+f: The function to call repeatedly.
-msec(float): How long to wait between calls in milliseconds.
+msec: How long to wait between calls in milliseconds.
 *args: Postional arguments to pass to function.
 **kwargs: Key arguments to pass to function.
 """
@@ -113,7 +121,7 @@ class Clock:
 d.addErrback(log_failure, "Looping call died", consumeErrors=False)
 return call
-def call_later(self, delay, callback, *args, **kwargs):
+def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall:
 """Call something later
 Note that the function will be called with no logcontext, so if it is anything
@@ -133,7 +141,7 @@ class Clock:
 with context.PreserveLoggingContext():
 return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
-def cancel_call_later(self, timer, ignore_errs=False):
+def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
 try:
 timer.cancel()
 except Exception:

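`Clock.looping_call` above is a thin wrapper around Twisted's `LoopingCall`, taking milliseconds where Twisted takes seconds and deferring the first run. A hedged, self-contained sketch of the underlying Twisted API it delegates to; the `prune` callback is illustrative:

from twisted.internet import reactor, task

def prune() -> None:
    print("pruning expired entries")

# Roughly what looping_call(prune, 30 * 60 * 1000) does: LoopingCall takes
# seconds and, with now=False, waits one full interval before the first call.
loop = task.LoopingCall(prune)
loop.start(30 * 60, now=False)

# reactor.run()  # uncomment to actually drive the loop
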
@@ -37,6 +37,7 @@ import attr
 from typing_extensions import ContextManager
 from twisted.internet import defer
+from twisted.internet.base import ReactorBase
 from twisted.internet.defer import CancelledError
 from twisted.internet.interfaces import IReactorTime
 from twisted.python import failure
@@ -268,6 +269,7 @@ class Linearizer:
 if not clock:
 from twisted.internet import reactor
+assert isinstance(reactor, ReactorBase)
 clock = Clock(reactor)
 self._clock = clock
 self.max_count = max_count
@@ -411,7 +413,7 @@ class ReadWriteLock:
 # writers and readers have been resolved. The new writer replaces the latest
 # writer.
-def __init__(self):
+def __init__(self) -> None:
 # Latest readers queued
 self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
@@ -503,7 +505,7 @@ def timeout_deferred(
 timed_out = [False]
-def time_it_out():
+def time_it_out() -> None:
 timed_out[0] = True
 try:
@@ -550,19 +552,21 @@ def timeout_deferred(
 return new_d
+# This class can't be generic because it uses slots with attrs.
+# See: https://github.com/python-attrs/attrs/issues/313
 @attr.s(slots=True, frozen=True)
-class DoneAwaitable:
+class DoneAwaitable: # should be: Generic[R]
 """Simple awaitable that returns the provided value."""
-value = attr.ib()
+value = attr.ib(type=Any) # should be: R
 def __await__(self):
 return self
-def __iter__(self):
+def __iter__(self) -> "DoneAwaitable":
 return self
-def __next__(self):
+def __next__(self) -> None:
 raise StopIteration(self.value)

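The `DoneAwaitable` hunk above relies on the awaitable protocol: `await` iterates whatever `__await__` returns, and the value carried by the final `StopIteration` becomes the result. A self-contained sketch mirroring that class; the asyncio driver is only for demonstration (Synapse itself runs on Twisted):

import asyncio
import attr

@attr.s(slots=True, frozen=True)
class DoneAwaitable:
    """Awaitable that completes immediately with the provided value."""
    value = attr.ib()

    def __await__(self):
        return self

    def __iter__(self):
        return self

    def __next__(self):
        # Ending iteration with StopIteration(value) hands the result to `await`.
        raise StopIteration(self.value)

async def main() -> None:
    assert await DoneAwaitable(42) == 42

asyncio.run(main())
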
@@ -122,7 +122,7 @@ class BatchingQueue(Generic[V, R]):
 # First we create a defer and add it and the value to the list of
 # pending items.
-d = defer.Deferred()
+d: defer.Deferred[R] = defer.Deferred()
 self._next_values.setdefault(key, []).append((value, d))
 # If we're not currently processing the key fire off a background

@@ -64,32 +64,32 @@ class CacheMetric:
 evicted_size = attr.ib(default=0)
 memory_usage = attr.ib(default=None)
-def inc_hits(self):
+def inc_hits(self) -> None:
 self.hits += 1
-def inc_misses(self):
+def inc_misses(self) -> None:
 self.misses += 1
-def inc_evictions(self, size=1):
+def inc_evictions(self, size: int = 1) -> None:
 self.evicted_size += size
-def inc_memory_usage(self, memory: int):
+def inc_memory_usage(self, memory: int) -> None:
 if self.memory_usage is None:
 self.memory_usage = 0
 self.memory_usage += memory
-def dec_memory_usage(self, memory: int):
+def dec_memory_usage(self, memory: int) -> None:
 self.memory_usage -= memory
-def clear_memory_usage(self):
+def clear_memory_usage(self) -> None:
 if self.memory_usage is not None:
 self.memory_usage = 0
 def describe(self):
 return []
-def collect(self):
+def collect(self) -> None:
 try:
 if self._cache_type == "response_cache":
 response_cache_size.labels(self._cache_name).set(len(self._cache))

@@ -93,7 +93,7 @@ class DeferredCache(Generic[KT, VT]):
 TreeCache, "MutableMapping[KT, CacheEntry]"
 ] = cache_type()
-def metrics_cb():
+def metrics_cb() -> None:
 cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
 # cache is used for completed results and maps to the result itself, rather than
@@ -113,7 +113,7 @@ class DeferredCache(Generic[KT, VT]):
 def max_entries(self):
 return self.cache.max_size
-def check_thread(self):
+def check_thread(self) -> None:
 expected_thread = self.thread
 if expected_thread is None:
 self.thread = threading.current_thread()
@@ -235,7 +235,7 @@ class DeferredCache(Generic[KT, VT]):
 self._pending_deferred_cache[key] = entry
-def compare_and_pop():
+def compare_and_pop() -> bool:
 """Check if our entry is still the one in _pending_deferred_cache, and
 if so, pop it.
@@ -256,7 +256,7 @@ class DeferredCache(Generic[KT, VT]):
 return False
-def cb(result):
+def cb(result) -> None:
 if compare_and_pop():
 self.cache.set(key, result, entry.callbacks)
 else:
@@ -268,7 +268,7 @@ class DeferredCache(Generic[KT, VT]):
 # not have been. Either way, let's double-check now.
 entry.invalidate()
-def eb(_fail):
+def eb(_fail) -> None:
 compare_and_pop()
 entry.invalidate()
@@ -314,7 +314,7 @@ class DeferredCache(Generic[KT, VT]):
 for entry in iterate_tree_cache_entry(entry):
 entry.invalidate()
-def invalidate_all(self):
+def invalidate_all(self) -> None:
 self.check_thread()
 self.cache.clear()
 for entry in self._pending_deferred_cache.values():
@@ -332,7 +332,7 @@ class CacheEntry:
 self.callbacks = set(callbacks)
 self.invalidated = False
-def invalidate(self):
+def invalidate(self) -> None:
 if not self.invalidated:
 self.invalidated = True
 for callback in self.callbacks:

@@ -27,10 +27,14 @@ logger = logging.getLogger(__name__)
 KT = TypeVar("KT")
 # The type of the dictionary keys.
 DKT = TypeVar("DKT")
+# The type of the dictionary values.
+DV = TypeVar("DV")
+# This class can't be generic because it uses slots with attrs.
+# See: https://github.com/python-attrs/attrs/issues/313
 @attr.s(slots=True)
-class DictionaryEntry:
+class DictionaryEntry: # should be: Generic[DKT, DV].
 """Returned when getting an entry from the cache
 Attributes:
@@ -43,10 +47,10 @@ class DictionaryEntry:
 """
 full = attr.ib(type=bool)
-known_absent = attr.ib()
+known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT]
-value = attr.ib()
+value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV]
-def __len__(self):
+def __len__(self) -> int:
 return len(self.value)
@@ -56,7 +60,7 @@ class _Sentinel(enum.Enum):
 sentinel = object()
-class DictionaryCache(Generic[KT, DKT]):
+class DictionaryCache(Generic[KT, DKT, DV]):
 """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
 fetching a subset of dictionary keys for a particular key.
 """
@@ -87,7 +91,7 @@ class DictionaryCache(Generic[KT, DKT]):
 Args:
 key
-dict_key: If given a set of keys then return only those keys
+dict_keys: If given a set of keys then return only those keys
 that exist in the cache.
 Returns:
@@ -125,7 +129,7 @@ class DictionaryCache(Generic[KT, DKT]):
 self,
 sequence: int,
 key: KT,
-value: Dict[DKT, Any],
+value: Dict[DKT, DV],
 fetched_keys: Optional[Set[DKT]] = None,
 ) -> None:
 """Updates the entry in the cache
@@ -151,15 +155,15 @@ class DictionaryCache(Generic[KT, DKT]):
 self._update_or_insert(key, value, fetched_keys)
 def _update_or_insert(
-self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
+self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
 ) -> None:
 # We pop and reinsert as we need to tell the cache the size may have
 # changed
-entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
+entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
 entry.value.update(value)
 entry.known_absent.update(known_absent)
 self.cache[key] = entry
-def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
+def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
 self.cache[key] = DictionaryEntry(True, known_absent, value)

@@ -35,6 +35,7 @@ from typing import (
 from typing_extensions import Literal
 from twisted.internet import reactor
+from twisted.internet.interfaces import IReactorTime
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -341,7 +342,7 @@ class LruCache(Generic[KT, VT]):
 # Default `clock` to something sensible. Note that we rename it to
 # `real_clock` so that mypy doesn't think its still `Optional`.
 if clock is None:
-real_clock = Clock(reactor)
+real_clock = Clock(cast(IReactorTime, reactor))
 else:
 real_clock = clock
@@ -384,7 +385,7 @@ class LruCache(Generic[KT, VT]):
 lock = threading.Lock()
-def evict():
+def evict() -> None:
 while cache_len() > self.max_size:
 # Get the last node in the list (i.e. the oldest node).
 todelete = list_root.prev_node

@@ -195,7 +195,7 @@ class StreamChangeCache:
 for entity in r:
 del self._entity_to_key[entity]
-def _evict(self):
+def _evict(self) -> None:
 while len(self._cache) > self._max_size:
 k, r = self._cache.popitem(0)
 self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)

@@ -35,17 +35,17 @@ class TreeCache:
 root = {key_1: {key_2: _value}}
 """
-def __init__(self):
+def __init__(self) -> None:
-self.size = 0
+self.size: int = 0
 self.root = TreeCacheNode()
-def __setitem__(self, key, value):
+def __setitem__(self, key, value) -> None:
-return self.set(key, value)
+self.set(key, value)
-def __contains__(self, key):
+def __contains__(self, key) -> bool:
 return self.get(key, SENTINEL) is not SENTINEL
-def set(self, key, value):
+def set(self, key, value) -> None:
 if isinstance(value, TreeCacheNode):
 # this would mean we couldn't tell where our tree ended and the value
 # started.
@@ -73,7 +73,7 @@ class TreeCache:
 return default
 return node.get(key[-1], default)
-def clear(self):
+def clear(self) -> None:
 self.size = 0
 self.root = TreeCacheNode()
@@ -128,7 +128,7 @@ class TreeCache:
 def values(self):
 return iterate_tree_cache_entry(self.root)
-def __len__(self):
+def __len__(self) -> int:
 return self.size

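TreeCache, as the docstring above shows (`root = {key_1: {key_2: _value}}`), stores tuple keys as nested dicts so whole subtrees can be invalidated at once. A hedged, self-contained toy reimplementation to make the shape concrete; it is not Synapse's TreeCache, and like the real one it assumes no key is a prefix of another:

from typing import Any, Dict, Tuple

class TinyTreeCache:
    """Illustrative nested-dict cache keyed by tuples."""
    _SENTINEL = object()

    def __init__(self) -> None:
        self.size = 0
        self.root: Dict[Any, Any] = {}

    def set(self, key: Tuple, value: Any) -> None:
        node = self.root
        for part in key[:-1]:
            node = node.setdefault(part, {})
        if key[-1] not in node:
            self.size += 1
        node[key[-1]] = value

    def get(self, key: Tuple, default: Any = None) -> Any:
        node = self.root
        for part in key[:-1]:
            node = node.get(part, self._SENTINEL)
            if node is self._SENTINEL:
                return default
        return node.get(key[-1], default)

    def __contains__(self, key: Tuple) -> bool:
        return self.get(key, self._SENTINEL) is not self._SENTINEL

    def __len__(self) -> int:
        return self.size

cache = TinyTreeCache()
cache.set(("!room1:test", "@alice:test"), "join")
assert ("!room1:test", "@alice:test") in cache and len(cache) == 1
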
@@ -126,7 +126,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
 signal.signal(signal.SIGTERM, sigterm)
 # Cleanup pid file at exit.
-def exit():
+def exit() -> None:
 logger.warning("Stopping daemon.")
 os.remove(pid_file)
 sys.exit(0)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Any, Callable, Dict, List
 from twisted.internet import defer
@@ -37,11 +38,11 @@ class Distributor:
 model will do for today.
 """
-def __init__(self):
+def __init__(self) -> None:
-self.signals = {}
+self.signals: Dict[str, Signal] = {}
-self.pre_registration = {}
+self.pre_registration: Dict[str, List[Callable]] = {}
-def declare(self, name):
+def declare(self, name: str) -> None:
 if name in self.signals:
 raise KeyError("%r already has a signal named %s" % (self, name))
@@ -52,7 +53,7 @@ class Distributor:
 for observer in self.pre_registration[name]:
 signal.observe(observer)
-def observe(self, name, observer):
+def observe(self, name: str, observer: Callable) -> None:
 if name in self.signals:
 self.signals[name].observe(observer)
 else:
@@ -62,7 +63,7 @@ class Distributor:
 self.pre_registration[name] = []
 self.pre_registration[name].append(observer)
-def fire(self, name, *args, **kwargs):
+def fire(self, name: str, *args, **kwargs) -> None:
 """Dispatches the given signal to the registered observers.
 Runs the observers as a background process. Does not return a deferred.
@@ -83,18 +84,18 @@ class Signal:
 method into all of the observers.
 """
-def __init__(self, name):
+def __init__(self, name: str):
-self.name = name
+self.name: str = name
-self.observers = []
+self.observers: List[Callable] = []
-def observe(self, observer):
+def observe(self, observer: Callable) -> None:
 """Adds a new callable to the observer list which will be invoked by
 the 'fire' method.
 Each observer callable may return a Deferred."""
 self.observers.append(observer)
-def fire(self, *args, **kwargs):
+def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
 """Invokes every callable in the observer list, passing in the args and
 kwargs. Exceptions thrown by observers are logged but ignored. It is
 not an error to fire a signal with no observers.
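
The Distributor/Signal hunk above is a classic observer pattern: observers are registered against a named signal and firing the signal calls each of them, with failures logged rather than propagated. A stripped-down, synchronous sketch of that idea (illustrative only, not Synapse's implementation, which runs observers as background processes and supports Deferreds):

import logging
from typing import Any, Callable, List

logger = logging.getLogger(__name__)

class TinySignal:
    def __init__(self, name: str) -> None:
        self.name = name
        self.observers: List[Callable[..., Any]] = []

    def observe(self, observer: Callable[..., Any]) -> None:
        self.observers.append(observer)

    def fire(self, *args: Any, **kwargs: Any) -> None:
        for observer in self.observers:
            try:
                observer(*args, **kwargs)
            except Exception:
                # One misbehaving observer must not break the others.
                logger.exception("%s signal observer failed", self.name)

sig = TinySignal("user_registered")
sig.observe(lambda user_id: print("new user:", user_id))
sig.fire("@alice:example.com")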
