diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml index f7a4ee7c13..67ccc03f6e 100644 --- a/.github/workflows/twisted_trunk.yml +++ b/.github/workflows/twisted_trunk.yml @@ -5,6 +5,9 @@ on: - cron: 0 8 * * * workflow_dispatch: + # NB: inputs are only present when this workflow is dispatched manually. + # (The default below is the default field value in the form to trigger + # a manual dispatch). Otherwise the inputs will evaluate to null. inputs: twisted_ref: description: Commit, branch or tag to checkout from upstream Twisted. @@ -49,7 +52,7 @@ jobs: extras: "all" - run: | poetry remove twisted - poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref }} + poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref || 'trunk' }} poetry install --no-interaction --extras "all test" - name: Remove warn_unused_ignores from mypy config run: sed '/warn_unused_ignores = True/d' -i mypy.ini diff --git a/CHANGES.md b/CHANGES.md index 95d8227ee0..666cd31ba0 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,8 @@ +# Synapse 1.90.0 (2023-08-15) + +No significant changes since 1.90.0rc1. + + # Synapse 1.90.0rc1 (2023-08-08) ### Features diff --git a/Cargo.lock b/Cargo.lock index 45e0f116e6..79d9cefcf6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,9 +132,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.19" +version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "memchr" diff --git a/changelog.d/15870.feature b/changelog.d/15870.feature new file mode 100644 index 0000000000..527220d637 --- /dev/null +++ b/changelog.d/15870.feature @@ -0,0 +1 @@ +Implements an admin API to lock an user without deactivating them. Based on [MSC3939](https://github.com/matrix-org/matrix-spec-proposals/pull/3939). diff --git a/changelog.d/16010.misc b/changelog.d/16010.misc new file mode 100644 index 0000000000..1e1a148069 --- /dev/null +++ b/changelog.d/16010.misc @@ -0,0 +1 @@ +Update dehydrated devices implementation. diff --git a/changelog.d/16052.bugfix b/changelog.d/16052.bugfix new file mode 100644 index 0000000000..3c7a60f226 --- /dev/null +++ b/changelog.d/16052.bugfix @@ -0,0 +1 @@ +Fix long-standing bug where concurrent requests to change a user's push rules could cause a deadlock. Contributed by Nick @ Beeper (@fizzadar). diff --git a/changelog.d/16061.misc b/changelog.d/16061.misc new file mode 100644 index 0000000000..37928b670f --- /dev/null +++ b/changelog.d/16061.misc @@ -0,0 +1 @@ +Fix database performance of read/write worker locks. diff --git a/changelog.d/16080.bugfix b/changelog.d/16080.bugfix new file mode 100644 index 0000000000..1ad6fb3c52 --- /dev/null +++ b/changelog.d/16080.bugfix @@ -0,0 +1 @@ +Fix a long-standing bu in `/sync` where timeout=0 does not skip caching, resulting in slow calls in cases where there are no new changes. Contributed by @PlasmaIntec. \ No newline at end of file diff --git a/changelog.d/16085.misc b/changelog.d/16085.misc new file mode 100644 index 0000000000..7b7a95edd4 --- /dev/null +++ b/changelog.d/16085.misc @@ -0,0 +1 @@ +Override global statement timeout when creating indexes in Postgres. diff --git a/changelog.d/16089.misc b/changelog.d/16089.misc new file mode 100644 index 0000000000..8c302e6884 --- /dev/null +++ b/changelog.d/16089.misc @@ -0,0 +1 @@ +Fix the type annotation on `run_db_interaction` in the Module API. \ No newline at end of file diff --git a/changelog.d/16091.doc b/changelog.d/16091.doc new file mode 100644 index 0000000000..a043df4efd --- /dev/null +++ b/changelog.d/16091.doc @@ -0,0 +1 @@ +Structured logging docs: add a link to explain the ELK stack diff --git a/changelog.d/16092.misc b/changelog.d/16092.misc new file mode 100644 index 0000000000..b520807771 --- /dev/null +++ b/changelog.d/16092.misc @@ -0,0 +1 @@ +Clean-up the presence code. diff --git a/changelog.d/16094.feature b/changelog.d/16094.feature new file mode 100644 index 0000000000..3be71badb9 --- /dev/null +++ b/changelog.d/16094.feature @@ -0,0 +1 @@ +Allow customising the IdP display name, icon, and brand for SAML and CAS providers (in addition to OIDC provider). diff --git a/changelog.d/16110.misc b/changelog.d/16110.misc new file mode 100644 index 0000000000..68efe86ddc --- /dev/null +++ b/changelog.d/16110.misc @@ -0,0 +1 @@ +Run `pyupgrade` for Python 3.8+. diff --git a/changelog.d/16112.misc b/changelog.d/16112.misc new file mode 100644 index 0000000000..05a58c1348 --- /dev/null +++ b/changelog.d/16112.misc @@ -0,0 +1 @@ +Rename pagination and purge locks and add comments to explain why they exist and how they work. diff --git a/changelog.d/16115.misc b/changelog.d/16115.misc new file mode 100644 index 0000000000..f325d2a31d --- /dev/null +++ b/changelog.d/16115.misc @@ -0,0 +1 @@ +Attempt to fix the twisted trunk job. diff --git a/changelog.d/16117.misc b/changelog.d/16117.misc new file mode 100644 index 0000000000..f33fa6dc17 --- /dev/null +++ b/changelog.d/16117.misc @@ -0,0 +1 @@ +Cache token introspection response from OIDC provider. diff --git a/changelog.d/16123.misc b/changelog.d/16123.misc new file mode 100644 index 0000000000..b7c6b7c2f2 --- /dev/null +++ b/changelog.d/16123.misc @@ -0,0 +1 @@ +Add cache to `get_server_keys_json_for_remote`. diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 895b2a7af1..710fe25699 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -769,7 +769,7 @@ def main(server_url, identity_server_url, username, token, config_path): global CONFIG_JSON CONFIG_JSON = config_path # bit cheeky, but just overwrite the global try: - with open(config_path, "r") as config: + with open(config_path) as config: syn_cmd.config = json.load(config) try: http_client.verbose = "on" == syn_cmd.config["verbose"] diff --git a/debian/changelog b/debian/changelog index ed35abc9ee..ad9a4b3c8c 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.90.0) stable; urgency=medium + + * New Synapse release 1.90.0. + + -- Synapse Packaging team Tue, 15 Aug 2023 11:17:34 +0100 + matrix-synapse-py3 (1.90.0~rc1) stable; urgency=medium * New Synapse release 1.90.0rc1. diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index dc824038b5..400a7515aa 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -861,7 +861,7 @@ def generate_worker_files( # Then a worker config file convert( "/conf/worker.yaml.j2", - "/conf/workers/{name}.yaml".format(name=worker_name), + f"/conf/workers/{worker_name}.yaml", **worker_config, worker_log_config_filepath=log_config_filepath, using_unix_sockets=using_unix_sockets, diff --git a/docker/start.py b/docker/start.py index ebcc599f04..aebc7e4aaa 100755 --- a/docker/start.py +++ b/docker/start.py @@ -82,7 +82,7 @@ def generate_config_from_template( with open(filename) as handle: value = handle.read() else: - log("Generating a random secret for {}".format(secret)) + log(f"Generating a random secret for {secret}") value = codecs.encode(os.urandom(32), "hex").decode() with open(filename, "w") as handle: handle.write(value) diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index ac4f635099..c269ce6af0 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -146,6 +146,7 @@ Body parameters: - `admin` - **bool**, optional, defaults to `false`. Whether the user is a homeserver administrator, granting them access to the Admin API, among other things. - `deactivated` - **bool**, optional. If unspecified, deactivation state will be left unchanged. +- `locked` - **bool**, optional. If unspecified, locked state will be left unchanged. Note: the `password` field must also be set if both of the following are true: - `deactivated` is set to `false` and the user was previously deactivated (you are reactivating this user) diff --git a/docs/structured_logging.md b/docs/structured_logging.md index d43dc9eb6e..002565b223 100644 --- a/docs/structured_logging.md +++ b/docs/structured_logging.md @@ -3,7 +3,7 @@ A structured logging system can be useful when your logs are destined for a machine to parse and process. By maintaining its machine-readable characteristics, it enables more efficient searching and aggregations when consumed by software -such as the "ELK stack". +such as the [ELK stack](https://opensource.com/article/18/9/open-source-log-aggregation-tools). Synapse's structured logging system is configured via the file that Synapse's `log_config` config option points to. The file should include a formatter which diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 2987c9332d..6601bba9f2 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3025,6 +3025,16 @@ enable SAML login. You can either put your entire pysaml config inline using the option, or you can specify a path to a psyaml config file with the sub-option `config_path`. This setting has the following sub-options: +* `idp_name`: A user-facing name for this identity provider, which is used to + offer the user a choice of login mechanisms. +* `idp_icon`: An optional icon for this identity provider, which is presented + by clients and Synapse's own IdP picker page. If given, must be an + MXC URI of the format `mxc:///`. (An easy way to + obtain such an MXC URI is to upload an image to an (unencrypted) room + and then copy the "url" from the source of the event.) +* `idp_brand`: An optional brand for this identity provider, allowing clients + to style the login flow according to the identity provider in question. + See the [spec](https://spec.matrix.org/latest/) for possible options here. * `sp_config`: the configuration for the pysaml2 Service Provider. See pysaml2 docs for format of config. Default values will be used for the `entityid` and `service` settings, so it is not normally necessary to specify them unless you need to @@ -3176,7 +3186,7 @@ Options for each entry include: * `idp_icon`: An optional icon for this identity provider, which is presented by clients and Synapse's own IdP picker page. If given, must be an - MXC URI of the format mxc:///. (An easy way to + MXC URI of the format `mxc:///`. (An easy way to obtain such an MXC URI is to upload an image to an (unencrypted) room and then copy the "url" from the source of the event.) @@ -3391,6 +3401,16 @@ Enable Central Authentication Service (CAS) for registration and login. Has the following sub-options: * `enabled`: Set this to true to enable authorization against a CAS server. Defaults to false. +* `idp_name`: A user-facing name for this identity provider, which is used to + offer the user a choice of login mechanisms. +* `idp_icon`: An optional icon for this identity provider, which is presented + by clients and Synapse's own IdP picker page. If given, must be an + MXC URI of the format `mxc:///`. (An easy way to + obtain such an MXC URI is to upload an image to an (unencrypted) room + and then copy the "url" from the source of the event.) +* `idp_brand`: An optional brand for this identity provider, allowing clients + to style the login flow according to the identity provider in question. + See the [spec](https://spec.matrix.org/latest/) for possible options here. * `server_url`: The URL of the CAS authorization endpoint. * `displayname_attribute`: The attribute of the CAS response to use as the display name. If no name is given here, no displayname will be set. @@ -3631,6 +3651,7 @@ This option has the following sub-options: * `prefer_local_users`: Defines whether to prefer local users in search query results. If set to true, local users are more likely to appear above remote users when searching the user directory. Defaults to false. +* `show_locked_users`: Defines whether to show locked users in search query results. Defaults to false. Example configuration: ```yaml @@ -3638,6 +3659,7 @@ user_directory: enabled: false search_all_users: true prefer_local_users: true + show_locked_users: true ``` --- ### `user_consent` diff --git a/mypy.ini b/mypy.ini index 1038b7d8c7..311a951aa8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -45,6 +45,13 @@ warn_unused_ignores = False disallow_untyped_defs = False disallow_incomplete_defs = False +[mypy-synapse.util.manhole] +# This module imports something from Twisted which has a bad annotation in Twisted trunk, +# but is unannotated in Twisted's latest release. We want to type-ignore the problem +# in the twisted trunk job, even though it has no effect on normal mypy runs. +warn_unused_ignores = False + + ;; Dependencies without annotations ;; Before ignoring a module, check to see if type stubs are available. ;; The `typeshed` project maintains stubs here: diff --git a/poetry.lock b/poetry.lock index 71b47a5805..db1332a04b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -589,13 +589,13 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.31" +version = "3.1.32" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" files = [ - {file = "GitPython-3.1.31-py3-none-any.whl", hash = "sha256:f04893614f6aa713a60cbbe1e6a97403ef633103cdd0ef5eb6efe0deb98dbe8d"}, - {file = "GitPython-3.1.31.tar.gz", hash = "sha256:8ce3bcf69adfdf7c7d503e78fd3b1c492af782d58893b650adb2ac8912ddd573"}, + {file = "GitPython-3.1.32-py3-none-any.whl", hash = "sha256:e3d59b1c2c6ebb9dfa7a184daf3b6dd4914237e7488a1730a6d8f6f5d0b4187f"}, + {file = "GitPython-3.1.32.tar.gz", hash = "sha256:8d9b8cb1e80b9735e8717c9362079d3ce4c6e5ddeebedd0361b228c3a67a62f6"}, ] [package.dependencies] @@ -887,17 +887,17 @@ scripts = ["click (>=6.0)", "twisted (>=16.4.0)"] [[package]] name = "isort" -version = "5.11.5" +version = "5.12.0" description = "A Python utility / library to sort Python imports." optional = false -python-versions = ">=3.7.0" +python-versions = ">=3.8.0" files = [ - {file = "isort-5.11.5-py3-none-any.whl", hash = "sha256:ba1d72fb2595a01c7895a5128f9585a5cc4b6d395f1c8d514989b9a7eb2a8746"}, - {file = "isort-5.11.5.tar.gz", hash = "sha256:6be1f76a507cb2ecf16c7cf14a37e41609ca082330be4e3436a18ef74add55db"}, + {file = "isort-5.12.0-py3-none-any.whl", hash = "sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6"}, + {file = "isort-5.12.0.tar.gz", hash = "sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504"}, ] [package.extras] -colors = ["colorama (>=0.4.3,<0.5.0)"] +colors = ["colorama (>=0.4.3)"] pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"] plugins = ["setuptools"] requirements-deprecated-finder = ["pip-api", "pipreqs"] @@ -2921,13 +2921,13 @@ files = [ [[package]] name = "txredisapi" -version = "1.4.9" +version = "1.4.10" description = "non-blocking redis client for python" optional = true python-versions = "*" files = [ - {file = "txredisapi-1.4.9-py3-none-any.whl", hash = "sha256:72e6ad09cc5fffe3bec2e55e5bfb74407bd357565fc212e6003f7e26ef7d8f78"}, - {file = "txredisapi-1.4.9.tar.gz", hash = "sha256:c9607062d05e4d0b8ef84719eb76a3fe7d5ccd606a2acf024429da51d6e84559"}, + {file = "txredisapi-1.4.10-py3-none-any.whl", hash = "sha256:0a6ea77f27f8cf092f907654f08302a97b48fa35f24e0ad99dfb74115f018161"}, + {file = "txredisapi-1.4.10.tar.gz", hash = "sha256:7609a6af6ff4619a3189c0adfb86aeda789afba69eb59fc1e19ac0199e725395"}, ] [package.dependencies] @@ -2936,13 +2936,13 @@ twisted = "*" [[package]] name = "types-bleach" -version = "6.0.0.3" +version = "6.0.0.4" description = "Typing stubs for bleach" optional = false python-versions = "*" files = [ - {file = "types-bleach-6.0.0.3.tar.gz", hash = "sha256:8ce7896d4f658c562768674ffcf07492c7730e128018f03edd163ff912bfadee"}, - {file = "types_bleach-6.0.0.3-py3-none-any.whl", hash = "sha256:d43eaf30a643ca824e16e2dcdb0c87ef9226237e2fa3ac4732a50cb3f32e145f"}, + {file = "types-bleach-6.0.0.4.tar.gz", hash = "sha256:357b0226f65c4f20ab3b13ca8d78a6b91c78aad256d8ec168d4e90fc3303ebd4"}, + {file = "types_bleach-6.0.0.4-py3-none-any.whl", hash = "sha256:2b8767eb407c286b7f02803678732e522e04db8d56cbc9f1270bee49627eae92"}, ] [[package]] @@ -2991,13 +2991,13 @@ files = [ [[package]] name = "types-pillow" -version = "10.0.0.1" +version = "10.0.0.2" description = "Typing stubs for Pillow" optional = false python-versions = "*" files = [ - {file = "types-Pillow-10.0.0.1.tar.gz", hash = "sha256:834a07a04504f8bf37936679bc6a5802945e7644d0727460c0c4d4307967e2a3"}, - {file = "types_Pillow-10.0.0.1-py3-none-any.whl", hash = "sha256:be576b67418f1cb3b93794cf7946581be1009a33a10085b3c132eb0875a819b4"}, + {file = "types-Pillow-10.0.0.2.tar.gz", hash = "sha256:fe09380ab22d412ced989a067e9ee4af719fa3a47ba1b53b232b46514a871042"}, + {file = "types_Pillow-10.0.0.2-py3-none-any.whl", hash = "sha256:29d51a3ce6ef51fabf728a504d33b4836187ff14256b2e86996d55c91ab214b1"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index ca532e2c7c..86680cb8e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ manifest-path = "rust/Cargo.toml" [tool.poetry] name = "matrix-synapse" -version = "1.90.0rc1" +version = "1.90.0" description = "Homeserver for the Matrix decentralised comms protocol" authors = ["Matrix.org Team and Contributors "] license = "Apache-2.0" diff --git a/scripts-dev/build_debian_packages.py b/scripts-dev/build_debian_packages.py index bb89ba581c..c03e3418c0 100755 --- a/scripts-dev/build_debian_packages.py +++ b/scripts-dev/build_debian_packages.py @@ -47,7 +47,7 @@ can be passed on the commandline for debugging. projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -class Builder(object): +class Builder: def __init__( self, redirect_stdout: bool = False, diff --git a/scripts-dev/check_schema_delta.py b/scripts-dev/check_schema_delta.py index fee4a8bd3d..467be96fdf 100755 --- a/scripts-dev/check_schema_delta.py +++ b/scripts-dev/check_schema_delta.py @@ -43,7 +43,7 @@ def main(force_colors: bool) -> None: diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None) # Get the schema version of the local file to check against current schema on develop - with open("synapse/storage/schema/__init__.py", "r") as file: + with open("synapse/storage/schema/__init__.py") as file: local_schema = file.read() new_locals: Dict[str, Any] = {} exec(local_schema, new_locals) diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 63f0b25ddd..5ad334b4d8 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -247,7 +247,7 @@ def main() -> None: def read_args_from_config(args: argparse.Namespace) -> None: - with open(args.config, "r") as fh: + with open(args.config) as fh: config = yaml.safe_load(fh) if not args.server_name: diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 89ffba8d92..4ac8eaa889 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- # Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/scripts-dev/sign_json.py b/scripts-dev/sign_json.py index bb217799fb..00cbaf68f5 100755 --- a/scripts-dev/sign_json.py +++ b/scripts-dev/sign_json.py @@ -145,7 +145,7 @@ Example usage: def read_args_from_config(args: argparse.Namespace) -> None: - with open(args.config, "r") as fh: + with open(args.config) as fh: config = yaml.safe_load(fh) if not args.server_name: args.server_name = config["server_name"] diff --git a/synapse/__init__.py b/synapse/__init__.py index 6c1801862b..2f9c22a833 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -25,7 +25,11 @@ from synapse.util.rust import check_rust_lib_up_to_date from synapse.util.stringutils import strtobool # Check that we're not running on an unsupported Python version. -if sys.version_info < (3, 8): +# +# Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the +# if-statement completely. +py_version = sys.version_info +if py_version < (3, 8): print("Synapse requires Python 3.8 or above.") sys.exit(1) @@ -78,7 +82,7 @@ try: except ImportError: pass -import synapse.util +import synapse.util # noqa: E402 __version__ = synapse.util.SYNAPSE_VERSION diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 22c84fbd5b..49242800b8 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -123,7 +123,7 @@ BOOLEAN_COLUMNS = { "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], "rooms": ["is_public", "has_auth_chain_index"], - "users": ["shadow_banned", "approved"], + "users": ["shadow_banned", "approved", "locked"], "un_partial_stated_event_stream": ["rejection_status_changed"], "users_who_share_rooms": ["share_private"], "per_user_experimental_features": ["enabled"], @@ -1205,10 +1205,10 @@ class CursesProgress(Progress): self.total_processed = 0 self.total_remaining = 0 - super(CursesProgress, self).__init__() + super().__init__() def update(self, table: str, num_done: int) -> None: - super(CursesProgress, self).update(table, num_done) + super().update(table, num_done) self.total_processed = 0 self.total_remaining = 0 @@ -1304,7 +1304,7 @@ class TerminalProgress(Progress): """Just prints progress to the terminal""" def update(self, table: str, num_done: int) -> None: - super(TerminalProgress, self).update(table, num_done) + super().update(table, num_done) data = self.tables[table] diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py index 0adf94bba6..f97aecf8d5 100644 --- a/synapse/_scripts/update_synapse_database.py +++ b/synapse/_scripts/update_synapse_database.py @@ -38,7 +38,7 @@ class MockHomeserver(HomeServer): DATASTORE_CLASS = DataStore # type: ignore [assignment] def __init__(self, config: HomeServerConfig): - super(MockHomeserver, self).__init__( + super().__init__( hostname=config.server.server_name, config=config, reactor=reactor, diff --git a/synapse/api/auth/__init__.py b/synapse/api/auth/__init__.py index 90cfe39d76..bb3f50f2dd 100644 --- a/synapse/api/auth/__init__.py +++ b/synapse/api/auth/__init__.py @@ -60,6 +60,7 @@ class Auth(Protocol): request: SynapseRequest, allow_guest: bool = False, allow_expired: bool = False, + allow_locked: bool = False, ) -> Requester: """Get a registered user's ID. diff --git a/synapse/api/auth/internal.py b/synapse/api/auth/internal.py index e2ae198b19..6a5fd44ec0 100644 --- a/synapse/api/auth/internal.py +++ b/synapse/api/auth/internal.py @@ -58,6 +58,7 @@ class InternalAuth(BaseAuth): request: SynapseRequest, allow_guest: bool = False, allow_expired: bool = False, + allow_locked: bool = False, ) -> Requester: """Get a registered user's ID. @@ -79,7 +80,7 @@ class InternalAuth(BaseAuth): parent_span = active_span() with start_active_span("get_user_by_req"): requester = await self._wrapped_get_user_by_req( - request, allow_guest, allow_expired + request, allow_guest, allow_expired, allow_locked ) if parent_span: @@ -107,6 +108,7 @@ class InternalAuth(BaseAuth): request: SynapseRequest, allow_guest: bool, allow_expired: bool, + allow_locked: bool, ) -> Requester: """Helper for get_user_by_req @@ -126,6 +128,17 @@ class InternalAuth(BaseAuth): access_token, allow_expired=allow_expired ) + # Deny the request if the user account is locked. + if not allow_locked and await self.store.get_user_locked_status( + requester.user.to_string() + ): + raise AuthError( + 401, + "User account has been locked", + errcode=Codes.USER_LOCKED, + additional_fields={"soft_logout": True}, + ) + # Deny the request if the user account has expired. # This check is only done for regular users, not appservice ones. if not allow_expired: diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py index bd4fc9c0ee..3a516093f5 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py @@ -27,6 +27,7 @@ from twisted.web.http_headers import Headers from synapse.api.auth.base import BaseAuth from synapse.api.errors import ( AuthError, + Codes, HttpResponseException, InvalidClientTokenError, OAuthInsufficientScopeError, @@ -38,6 +39,7 @@ from synapse.logging.context import make_deferred_yieldable from synapse.types import Requester, UserID, create_requester from synapse.util import json_decoder from synapse.util.caches.cached_call import RetryOnExceptionCachedCall +from synapse.util.caches.expiringcache import ExpiringCache if TYPE_CHECKING: from synapse.server import HomeServer @@ -105,6 +107,14 @@ class MSC3861DelegatedAuth(BaseAuth): self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata) + self._clock = hs.get_clock() + self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache( + cache_name="introspection_token_cache", + clock=self._clock, + max_len=10000, + expiry_ms=5 * 60 * 1000, + ) + if isinstance(auth_method, PrivateKeyJWTWithKid): # Use the JWK as the client secret when using the private_key_jwt method assert self._config.jwk, "No JWK provided" @@ -143,6 +153,20 @@ class MSC3861DelegatedAuth(BaseAuth): Returns: The introspection response """ + # check the cache before doing a request + introspection_token = self._token_cache.get(token, None) + + if introspection_token: + # check the expiration field of the token (if it exists) + exp = introspection_token.get("exp", None) + if exp: + time_now = self._clock.time() + expired = time_now > exp + if not expired: + return introspection_token + else: + return introspection_token + metadata = await self._issuer_metadata.get() introspection_endpoint = metadata.get("introspection_endpoint") raw_headers: Dict[str, str] = { @@ -156,7 +180,10 @@ class MSC3861DelegatedAuth(BaseAuth): # Fill the body/headers with credentials uri, raw_headers, body = self._client_auth.prepare( - method="POST", uri=introspection_endpoint, headers=raw_headers, body=body + method="POST", + uri=introspection_endpoint, + headers=raw_headers, + body=body, ) headers = Headers({k: [v] for (k, v) in raw_headers.items()}) @@ -186,7 +213,17 @@ class MSC3861DelegatedAuth(BaseAuth): "The introspection endpoint returned an invalid JSON response." ) - return IntrospectionToken(**resp) + expiration = resp.get("exp", None) + if expiration: + if self._clock.time() > expiration: + raise InvalidClientTokenError("Token is expired.") + + introspection_token = IntrospectionToken(**resp) + + # add token to cache + self._token_cache[token] = introspection_token + + return introspection_token async def is_server_admin(self, requester: Requester) -> bool: return "urn:synapse:admin:*" in requester.scope @@ -196,6 +233,7 @@ class MSC3861DelegatedAuth(BaseAuth): request: SynapseRequest, allow_guest: bool = False, allow_expired: bool = False, + allow_locked: bool = False, ) -> Requester: access_token = self.get_access_token_from_request(request) @@ -205,6 +243,17 @@ class MSC3861DelegatedAuth(BaseAuth): # so that we don't provision the user if they don't have enough permission: requester = await self.get_user_by_access_token(access_token, allow_expired) + # Deny the request if the user account is locked. + if not allow_locked and await self.store.get_user_locked_status( + requester.user.to_string() + ): + raise AuthError( + 401, + "User account has been locked", + errcode=Codes.USER_LOCKED, + additional_fields={"soft_logout": True}, + ) + if not allow_guest and requester.is_guest: raise OAuthInsufficientScopeError([SCOPE_MATRIX_API]) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index dc32553d0c..bf311b636d 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -18,8 +18,7 @@ """Contains constants from the specification.""" import enum - -from typing_extensions import Final +from typing import Final # the max size of a (canonical-json-encoded) event MAX_PDU_SIZE = 65536 diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 3546aaf7c3..7ffd72c42c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -80,6 +80,8 @@ class Codes(str, Enum): WEAK_PASSWORD = "M_WEAK_PASSWORD" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" + # USER_LOCKED = "M_USER_LOCKED" + USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED" # Part of MSC3848 # https://github.com/matrix-org/matrix-spec-proposals/pull/3848 diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 9152c06bd6..c4e63e7411 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -47,6 +47,10 @@ class CasConfig(Config): required_attributes ) + self.idp_name = cas_config.get("idp_name", "CAS") + self.idp_icon = cas_config.get("idp_icon") + self.idp_brand = cas_config.get("idp_brand") + else: self.cas_server_url = None self.cas_service_url = None diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 49ca663dde..c69e24cf26 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -89,8 +89,14 @@ class SAML2Config(Config): "grandfathered_mxid_source_attribute", "uid" ) + # refers to a SAML IdP entity ID self.saml2_idp_entityid = saml2_config.get("idp_entityid", None) + # IdP properties for Matrix clients + self.idp_name = saml2_config.get("idp_name", "SAML") + self.idp_icon = saml2_config.get("idp_icon") + self.idp_brand = saml2_config.get("idp_brand") + # user_mapping_provider may be None if the key is present but has no value ump_dict = saml2_config.get("user_mapping_provider") or {} diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index c9e18b91e9..f60ec2ea66 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -35,3 +35,4 @@ class UserDirectoryConfig(Config): self.user_directory_search_prefer_local_users = user_directory_config.get( "prefer_local_users", False ) + self.show_locked_users = user_directory_config.get("show_locked_users", False) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index a90d99c4d6..f9915e5a3f 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -63,7 +63,7 @@ from synapse.federation.federation_base import ( ) from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction -from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -1245,7 +1245,7 @@ class FederationServer(FederationBase): # while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME` # lock. async with self._worker_lock_handler.acquire_read_write_lock( - DELETE_ROOM_LOCK_NAME, room_id, write=False + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False ): await self._federation_event_handler.on_receive_pdu( origin, event diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 119c7f8384..0e812a6d8b 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -67,6 +67,7 @@ class AdminHandler: "name", "admin", "deactivated", + "locked", "shadow_banned", "creation_ts", "appservice_id", diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index fc467bc7c1..5c71637038 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -76,12 +76,13 @@ class CasHandler: self.idp_id = "cas" # user-facing name of this auth provider - self.idp_name = "CAS" + self.idp_name = hs.config.cas.idp_name - # we do not currently support brands/icons for CAS auth, but this is required by - # the SsoIdentityProvider protocol type. - self.idp_icon = None - self.idp_brand = None + # MXC URI for icon for this auth provider + self.idp_icon = hs.config.cas.idp_icon + + # optional brand identifier for this auth provider + self.idp_brand = hs.config.cas.idp_brand self._sso_handler = hs.get_sso_handler() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index b7bf70a72d..5ae427d52c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -385,6 +385,7 @@ class DeviceHandler(DeviceWorkerHandler): self.federation_sender = hs.get_federation_sender() self._account_data_handler = hs.get_account_data_handler() self._storage_controllers = hs.get_storage_controllers() + self.db_pool = hs.get_datastores().main.db_pool self.device_list_updater = DeviceListUpdater(hs, self) @@ -656,15 +657,17 @@ class DeviceHandler(DeviceWorkerHandler): device_id: Optional[str], device_data: JsonDict, initial_device_display_name: Optional[str] = None, + keys_for_device: Optional[JsonDict] = None, ) -> str: - """Store a dehydrated device for a user. If the user had a previous - dehydrated device, it is removed. + """Store a dehydrated device for a user, optionally storing the keys associated with + it as well. If the user had a previous dehydrated device, it is removed. Args: user_id: the user that we are storing the device for device_id: device id supplied by client device_data: the dehydrated device information initial_device_display_name: The display name to use for the device + keys_for_device: keys for the dehydrated device Returns: device id of the dehydrated device """ @@ -673,11 +676,16 @@ class DeviceHandler(DeviceWorkerHandler): device_id, initial_device_display_name, ) + + time_now = self.clock.time_msec() + old_device_id = await self.store.store_dehydrated_device( - user_id, device_id, device_data + user_id, device_id, device_data, time_now, keys_for_device ) + if old_device_id is not None: await self.delete_devices(user_id, [old_device_id]) + return device_id async def rehydrate_device( diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 15e94a03cb..17ff8821d9 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -367,19 +367,6 @@ class DeviceMessageHandler: errcode=Codes.INVALID_PARAM, ) - # if we have a since token, delete any to-device messages before that token - # (since we now know that the device has received them) - deleted = await self.store.delete_messages_for_device( - user_id, device_id, since_stream_id - ) - logger.debug( - "Deleted %d to-device messages up to %d for user_id %s device_id %s", - deleted, - since_stream_id, - user_id, - device_id, - ) - to_token = self.event_sources.get_current_token().to_device_key messages, stream_id = await self.store.get_messages_for_device( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d485f21e49..a74db1dccf 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -53,7 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler -from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -1034,7 +1034,7 @@ class EventCreationHandler: ) async with self._worker_lock_handler.acquire_read_write_lock( - DELETE_ROOM_LOCK_NAME, room_id, write=False + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False ): return await self._create_and_send_nonmember_event_locked( requester=requester, @@ -1978,7 +1978,7 @@ class EventCreationHandler: for room_id in room_ids: async with self._worker_lock_handler.acquire_read_write_lock( - DELETE_ROOM_LOCK_NAME, room_id, write=False + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False ): dummy_event_sent = await self._send_dummy_event_for_room(room_id) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index da34658470..1be6ebc6d9 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -24,6 +24,7 @@ from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events.utils import SerializeEventConfig from synapse.handlers.room import ShutdownRoomResponse +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin @@ -46,9 +47,10 @@ logger = logging.getLogger(__name__) BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3 -PURGE_HISTORY_LOCK_NAME = "purge_history_lock" - -DELETE_ROOM_LOCK_NAME = "delete_room_lock" +# This is used to avoid purging a room several time at the same moment, +# and also paginating during a purge. Pagination can trigger backfill, +# which would create old events locally, and would potentially clash with the room delete. +PURGE_PAGINATION_LOCK_NAME = "purge_pagination_lock" @attr.s(slots=True, auto_attribs=True) @@ -363,7 +365,7 @@ class PaginationHandler: self._purges_in_progress_by_room.add(room_id) try: async with self._worker_locks.acquire_read_write_lock( - PURGE_HISTORY_LOCK_NAME, room_id, write=True + PURGE_PAGINATION_LOCK_NAME, room_id, write=True ): await self._storage_controllers.purge_events.purge_history( room_id, token, delete_local_events @@ -421,7 +423,10 @@ class PaginationHandler: force: set true to skip checking for joined users. """ async with self._worker_locks.acquire_multi_read_write_lock( - [(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)], + [ + (PURGE_PAGINATION_LOCK_NAME, room_id), + (NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id), + ], write=True, ): # first check that we have no users in this room @@ -483,7 +488,7 @@ class PaginationHandler: room_token = from_token.room_key async with self._worker_locks.acquire_read_write_lock( - PURGE_HISTORY_LOCK_NAME, room_id, write=False + PURGE_PAGINATION_LOCK_NAME, room_id, write=False ): (membership, member_event_id) = (None, None) if not use_admin_priviledge: @@ -761,7 +766,7 @@ class PaginationHandler: self._purges_in_progress_by_room.add(room_id) try: async with self._worker_locks.acquire_read_write_lock( - PURGE_HISTORY_LOCK_NAME, room_id, write=True + PURGE_PAGINATION_LOCK_NAME, room_id, write=True ): self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN self._delete_by_id[ diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index cd7df0525f..e8e9db4b91 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -30,9 +30,9 @@ from types import TracebackType from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, Collection, + ContextManager, Dict, Generator, Iterable, @@ -44,7 +44,6 @@ from typing import ( ) from prometheus_client import Counter -from typing_extensions import ContextManager import synapse.metrics from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState @@ -54,7 +53,10 @@ from synapse.appservice import ApplicationService from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.replication.http.presence import ( ReplicationBumpPresenceActiveTime, ReplicationPresenceSetState, @@ -141,6 +143,8 @@ class BasePresenceHandler(abc.ABC): self.state = hs.get_state_handler() self.is_mine_id = hs.is_mine_id + self._presence_enabled = hs.config.server.use_presence + self._federation = None if hs.should_send_federation(): self._federation = hs.get_federation_sender() @@ -149,6 +153,15 @@ class BasePresenceHandler(abc.ABC): self._busy_presence_enabled = hs.config.experimental.msc3026_enabled + self.VALID_PRESENCE: Tuple[str, ...] = ( + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, + ) + + if self._busy_presence_enabled: + self.VALID_PRESENCE += (PresenceState.BUSY,) + active_presence = self.store.take_presence_startup_info() self.user_to_current_state = {state.user_id: state for state in active_presence} @@ -395,8 +408,6 @@ class WorkerPresenceHandler(BasePresenceHandler): self._presence_writer_instance = hs.config.worker.writers.presence[0] - self._presence_enabled = hs.config.server.use_presence - # Route presence EDUs to the right worker hs.get_federation_registry().register_instances_for_edu( EduTypes.PRESENCE, @@ -421,8 +432,6 @@ class WorkerPresenceHandler(BasePresenceHandler): self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) - self._busy_presence_enabled = hs.config.experimental.msc3026_enabled - hs.get_reactor().addSystemEventTrigger( "before", "shutdown", @@ -490,7 +499,9 @@ class WorkerPresenceHandler(BasePresenceHandler): # what the spec wants: see comment in the BasePresenceHandler version # of this function. await self.set_state( - UserID.from_string(user_id), {"presence": presence_state}, True + UserID.from_string(user_id), + {"presence": presence_state}, + ignore_status_msg=True, ) curr_sync = self._user_to_num_current_syncs.get(user_id, 0) @@ -601,22 +612,13 @@ class WorkerPresenceHandler(BasePresenceHandler): """ presence = state["presence"] - valid_presence = ( - PresenceState.ONLINE, - PresenceState.UNAVAILABLE, - PresenceState.OFFLINE, - PresenceState.BUSY, - ) - - if presence not in valid_presence or ( - presence == PresenceState.BUSY and not self._busy_presence_enabled - ): + if presence not in self.VALID_PRESENCE: raise SynapseError(400, "Invalid presence state") user_id = target_user.to_string() # If presence is disabled, no-op - if not self.hs.config.server.use_presence: + if not self._presence_enabled: return # Proxy request to instance that writes presence @@ -633,7 +635,7 @@ class WorkerPresenceHandler(BasePresenceHandler): with the app. """ # If presence is disabled, no-op - if not self.hs.config.server.use_presence: + if not self._presence_enabled: return # Proxy request to instance that writes presence @@ -649,7 +651,6 @@ class PresenceHandler(BasePresenceHandler): self.hs = hs self.wheel_timer: WheelTimer[str] = WheelTimer() self.notifier = hs.get_notifier() - self._presence_enabled = hs.config.server.use_presence federation_registry = hs.get_federation_registry() @@ -700,8 +701,6 @@ class PresenceHandler(BasePresenceHandler): self._on_shutdown, ) - self._next_serial = 1 - # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. self.user_to_num_current_syncs: Dict[str, int] = {} @@ -723,21 +722,16 @@ class PresenceHandler(BasePresenceHandler): # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. - def run_timeout_handler() -> Awaitable[None]: - return run_as_background_process( - "handle_presence_timeouts", self._handle_timeouts - ) - self.clock.call_later( - 30, self.clock.looping_call, run_timeout_handler, 5000 + 30, self.clock.looping_call, self._handle_timeouts, 5000 ) - def run_persister() -> Awaitable[None]: - return run_as_background_process( - "persist_presence_changes", self._persist_unpersisted_changes - ) - - self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000) + self.clock.call_later( + 60, + self.clock.looping_call, + self._persist_unpersisted_changes, + 60 * 1000, + ) LaterGauge( "synapse_handlers_presence_wheel_timer_size", @@ -783,6 +777,7 @@ class PresenceHandler(BasePresenceHandler): ) logger.info("Finished _on_shutdown") + @wrap_as_background_process("persist_presence_changes") async def _persist_unpersisted_changes(self) -> None: """We periodically persist the unpersisted changes, as otherwise they may stack up and slow down shutdown times. @@ -898,6 +893,7 @@ class PresenceHandler(BasePresenceHandler): states, [destination] ) + @wrap_as_background_process("handle_presence_timeouts") async def _handle_timeouts(self) -> None: """Checks the presence of users that have timed out and updates as appropriate. @@ -955,7 +951,7 @@ class PresenceHandler(BasePresenceHandler): with the app. """ # If presence is disabled, no-op - if not self.hs.config.server.use_presence: + if not self._presence_enabled: return user_id = user.to_string() @@ -990,56 +986,51 @@ class PresenceHandler(BasePresenceHandler): client that is being used by a user. presence_state: The presence state indicated in the sync request """ - # Override if it should affect the user's presence, if presence is - # disabled. - if not self.hs.config.server.use_presence: - affect_presence = False + if not affect_presence or not self._presence_enabled: + return _NullContextManager() - if affect_presence: - curr_sync = self.user_to_num_current_syncs.get(user_id, 0) - self.user_to_num_current_syncs[user_id] = curr_sync + 1 + curr_sync = self.user_to_num_current_syncs.get(user_id, 0) + self.user_to_num_current_syncs[user_id] = curr_sync + 1 + prev_state = await self.current_state_for_user(user_id) + + # If they're busy then they don't stop being busy just by syncing, + # so just update the last sync time. + if prev_state.state != PresenceState.BUSY: + # XXX: We set_state separately here and just update the last_active_ts above + # This keeps the logic as similar as possible between the worker and single + # process modes. Using set_state will actually cause last_active_ts to be + # updated always, which is not what the spec calls for, but synapse has done + # this for... forever, I think. + await self.set_state( + UserID.from_string(user_id), + {"presence": presence_state}, + ignore_status_msg=True, + ) + # Retrieve the new state for the logic below. This should come from the + # in-memory cache. prev_state = await self.current_state_for_user(user_id) - # If they're busy then they don't stop being busy just by syncing, - # so just update the last sync time. - if prev_state.state != PresenceState.BUSY: - # XXX: We set_state separately here and just update the last_active_ts above - # This keeps the logic as similar as possible between the worker and single - # process modes. Using set_state will actually cause last_active_ts to be - # updated always, which is not what the spec calls for, but synapse has done - # this for... forever, I think. - await self.set_state( - UserID.from_string(user_id), {"presence": presence_state}, True - ) - # Retrieve the new state for the logic below. This should come from the - # in-memory cache. - prev_state = await self.current_state_for_user(user_id) - - # To keep the single process behaviour consistent with worker mode, run the - # same logic as `update_external_syncs_row`, even though it looks weird. - if prev_state.state == PresenceState.OFFLINE: - await self._update_states( - [ - prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=self.clock.time_msec(), - last_user_sync_ts=self.clock.time_msec(), - ) - ] - ) - # otherwise, set the new presence state & update the last sync time, - # but don't update last_active_ts as this isn't an indication that - # they've been active (even though it's probably been updated by - # set_state above) - else: - await self._update_states( - [ - prev_state.copy_and_replace( - last_user_sync_ts=self.clock.time_msec() - ) - ] - ) + # To keep the single process behaviour consistent with worker mode, run the + # same logic as `update_external_syncs_row`, even though it looks weird. + if prev_state.state == PresenceState.OFFLINE: + await self._update_states( + [ + prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=self.clock.time_msec(), + last_user_sync_ts=self.clock.time_msec(), + ) + ] + ) + # otherwise, set the new presence state & update the last sync time, + # but don't update last_active_ts as this isn't an indication that + # they've been active (even though it's probably been updated by + # set_state above) + else: + await self._update_states( + [prev_state.copy_and_replace(last_user_sync_ts=self.clock.time_msec())] + ) async def _end() -> None: try: @@ -1061,8 +1052,7 @@ class PresenceHandler(BasePresenceHandler): try: yield finally: - if affect_presence: - run_in_background(_end) + run_in_background(_end) return _user_syncing() @@ -1229,20 +1219,11 @@ class PresenceHandler(BasePresenceHandler): status_msg = state.get("status_msg", None) presence = state["presence"] - valid_presence = ( - PresenceState.ONLINE, - PresenceState.UNAVAILABLE, - PresenceState.OFFLINE, - PresenceState.BUSY, - ) - - if presence not in valid_presence or ( - presence == PresenceState.BUSY and not self._busy_presence_enabled - ): + if presence not in self.VALID_PRESENCE: raise SynapseError(400, "Invalid presence state") # If presence is disabled, no-op - if not self.hs.config.server.use_presence: + if not self._presence_enabled: return user_id = target_user.to_string() diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index bd8277e736..bb409f97b7 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -39,7 +39,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler -from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging import opentracing from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process @@ -629,7 +629,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async with self.member_linearizer.queue(key): async with self._worker_lock_handler.acquire_read_write_lock( - DELETE_ROOM_LOCK_NAME, room_id, write=False + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False ): diff = self.clock.time_msec() - then diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 6083c9f4b5..d00035c332 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -74,12 +74,13 @@ class SamlHandler: self.idp_id = "saml" # user-facing name of this auth provider - self.idp_name = "SAML" + self.idp_name = hs.config.saml2.idp_name - # we do not currently support icons/brands for SAML auth, but this is required by - # the SsoIdentityProvider protocol type. - self.idp_icon = None - self.idp_brand = None + # MXC URI for icon for this auth provider + self.idp_icon = hs.config.saml2.idp_icon + + # optional brand identifier for this auth provider + self.idp_brand = hs.config.saml2.idp_brand # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {} diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 4d29328a74..e9a544e754 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -24,13 +24,14 @@ from typing import ( Iterable, List, Mapping, + NoReturn, Optional, Set, ) from urllib.parse import urlencode import attr -from typing_extensions import NoReturn, Protocol +from typing_extensions import Protocol from twisted.web.iweb import IRequest from twisted.web.server import Request @@ -791,7 +792,7 @@ class SsoHandler: if code != 200: raise Exception( - "GET request to download sso avatar image returned {}".format(code) + f"GET request to download sso avatar image returned {code}" ) # upload name includes hash of the image file's content so that we can diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 7cabf7980a..3dde19fc81 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -14,9 +14,15 @@ # limitations under the License. import logging from collections import Counter -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple - -from typing_extensions import Counter as CounterType +from typing import ( + TYPE_CHECKING, + Any, + Counter as CounterType, + Dict, + Iterable, + Optional, + Tuple, +) from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.metrics import event_processing_positions diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c010405be6..60a9f341b5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -387,16 +387,16 @@ class SyncHandler: from_token=since_token, ) - # if nothing has happened in any of the users' rooms since /sync was called, - # the resultant next_batch will be the same as since_token (since the result - # is generated when wait_for_events is first called, and not regenerated - # when wait_for_events times out). - # - # If that happens, we mustn't cache it, so that when the client comes back - # with the same cache token, we don't immediately return the same empty - # result, causing a tightloop. (#8518) - if result.next_batch == since_token: - cache_context.should_cache = False + # if nothing has happened in any of the users' rooms since /sync was called, + # the resultant next_batch will be the same as since_token (since the result + # is generated when wait_for_events is first called, and not regenerated + # when wait_for_events times out). + # + # If that happens, we mustn't cache it, so that when the client comes back + # with the same cache token, we don't immediately return the same empty + # result, causing a tightloop. (#8518) + if result.next_batch == since_token: + cache_context.should_cache = False if result: if sync_config.filter_collection.lazy_load_members(): @@ -1442,11 +1442,9 @@ class SyncHandler: # Now we have our list of joined room IDs, exclude as configured and freeze joined_room_ids = frozenset( - ( - room_id - for room_id in mutable_joined_room_ids - if room_id not in mutable_rooms_to_exclude - ) + room_id + for room_id in mutable_joined_room_ids + if room_id not in mutable_rooms_to_exclude ) logger.debug( diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 05197edc95..a0f5568000 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -94,6 +94,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.worker.should_update_user_directory self.search_all_users = hs.config.userdirectory.user_directory_search_all_users + self.show_locked_users = hs.config.userdirectory.show_locked_users self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self._hs = hs @@ -144,7 +145,9 @@ class UserDirectoryHandler(StateDeltasHandler): ] } """ - results = await self.store.search_user_dir(user_id, search_term, limit) + results = await self.store.search_user_dir( + user_id, search_term, limit, self.show_locked_users + ) # Remove any spammy users from the results. non_spammy_users = [] diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py index 72df773a86..58efe7116b 100644 --- a/synapse/handlers/worker_lock.py +++ b/synapse/handlers/worker_lock.py @@ -42,7 +42,11 @@ if TYPE_CHECKING: from synapse.server import HomeServer -DELETE_ROOM_LOCK_NAME = "delete_room_lock" +# This lock is used to avoid creating an event while we are purging the room. +# We take a read lock when creating an event, and a write one when purging a room. +# This is because it is fine to create several events concurrently, since referenced events +# will not disappear under our feet as long as we don't delete the room. +NEW_EVENT_DURING_PURGE_LOCK_NAME = "new_event_during_purge_lock" class WorkerLocksHandler: diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index 5a61b21eaf..284fbac524 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -18,10 +18,9 @@ import traceback from collections import deque from ipaddress import IPv4Address, IPv6Address, ip_address from math import floor -from typing import Callable, Optional +from typing import Callable, Deque, Optional import attr -from typing_extensions import Deque from zope.interface import implementer from twisted.application.internet import ClientService diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index acee1dafd3..9ad8e038ae 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -31,7 +31,7 @@ from typing import ( import attr import jinja2 -from typing_extensions import ParamSpec +from typing_extensions import Concatenate, ParamSpec from twisted.internet import defer from twisted.internet.interfaces import IDelayedCall @@ -885,7 +885,7 @@ class ModuleApi: def run_db_interaction( self, desc: str, - func: Callable[P, T], + func: Callable[Concatenate[LoggingTransaction, P], T], *args: P.args, **kwargs: P.kwargs, ) -> "defer.Deferred[T]": diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py index e191450323..32db7cce8d 100644 --- a/synapse/module_api/callbacks/spamchecker_callbacks.py +++ b/synapse/module_api/callbacks/spamchecker_callbacks.py @@ -426,9 +426,7 @@ class SpamCheckerModuleApiCallbacks: generally discouraged as it doesn't support internationalization. """ for callback in self._check_event_for_spam_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(event)) if res is False or res == self.NOT_SPAM: # This spam-checker accepts the event. @@ -481,9 +479,7 @@ class SpamCheckerModuleApiCallbacks: True if the event should be silently dropped """ for callback in self._should_drop_federated_event_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res: Union[bool, str] = await delay_cancellation(callback(event)) if res: return res @@ -505,9 +501,7 @@ class SpamCheckerModuleApiCallbacks: NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise. """ for callback in self._user_may_join_room_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(user_id, room_id, is_invited)) # Normalize return values to `Codes` or `"NOT_SPAM"`. if res is True or res is self.NOT_SPAM: @@ -546,9 +540,7 @@ class SpamCheckerModuleApiCallbacks: NOT_SPAM if the operation is permitted, Codes otherwise. """ for callback in self._user_may_invite_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation( callback(inviter_userid, invitee_userid, room_id) ) @@ -593,9 +585,7 @@ class SpamCheckerModuleApiCallbacks: NOT_SPAM if the operation is permitted, Codes otherwise. """ for callback in self._user_may_send_3pid_invite_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation( callback(inviter_userid, medium, address, room_id) ) @@ -630,9 +620,7 @@ class SpamCheckerModuleApiCallbacks: userid: The ID of the user attempting to create a room """ for callback in self._user_may_create_room_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(userid)) if res is True or res is self.NOT_SPAM: continue @@ -666,9 +654,7 @@ class SpamCheckerModuleApiCallbacks: """ for callback in self._user_may_create_room_alias_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(userid, room_alias)) if res is True or res is self.NOT_SPAM: continue @@ -701,9 +687,7 @@ class SpamCheckerModuleApiCallbacks: room_id: The ID of the room that would be published """ for callback in self._user_may_publish_room_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(userid, room_id)) if res is True or res is self.NOT_SPAM: continue @@ -742,9 +726,7 @@ class SpamCheckerModuleApiCallbacks: True if the user is spammy. """ for callback in self._check_username_for_spam_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): # Make a copy of the user profile object to ensure the spam checker cannot # modify it. res = await delay_cancellation(callback(user_profile.copy())) @@ -776,9 +758,7 @@ class SpamCheckerModuleApiCallbacks: """ for callback in self._check_registration_for_spam_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): behaviour = await delay_cancellation( callback(email_threepid, username, request_info, auth_provider_id) ) @@ -820,9 +800,7 @@ class SpamCheckerModuleApiCallbacks: """ for callback in self._check_media_file_for_spam_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(file_wrapper, file_info)) # Normalize return values to `Codes` or `"NOT_SPAM"`. if res is False or res is self.NOT_SPAM: @@ -869,9 +847,7 @@ class SpamCheckerModuleApiCallbacks: """ for callback in self._check_login_for_spam_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation( callback( user_id, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index a2cabba7b1..38adcbe1d0 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -17,6 +17,7 @@ from typing import ( TYPE_CHECKING, Any, Awaitable, + Deque, Dict, Iterable, Iterator, @@ -29,7 +30,6 @@ from typing import ( ) from prometheus_client import Counter -from typing_extensions import Deque from twisted.internet.protocol import ReconnectingClientFactory diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index e0257daa75..04d9ef25b7 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -280,6 +280,17 @@ class UserRestServletV2(RestServlet): HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean" ) + lock = body.get("locked", False) + if not isinstance(lock, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'locked' parameter is not of type boolean" + ) + + if deactivate and lock: + raise SynapseError( + HTTPStatus.BAD_REQUEST, "An user can't be deactivated and locked" + ) + approved: Optional[bool] = None if "approved" in body and self._msc3866_enabled: approved = body["approved"] @@ -397,6 +408,12 @@ class UserRestServletV2(RestServlet): target_user.to_string() ) + if "locked" in body: + if lock and not user["locked"]: + await self.store.set_user_locked_status(user_id, True) + elif not lock and user["locked"]: + await self.store.set_user_locked_status(user_id, False) + if "user_type" in body: await self.store.set_user_type(target_user, user_type) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 51f17f80da..925f037743 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -29,7 +29,6 @@ from synapse.http.servlet import ( parse_integer, ) from synapse.http.site import SynapseRequest -from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.rest.client.models import AuthenticationData from synapse.rest.models import RequestBodyModel @@ -480,13 +479,6 @@ class DehydratedDeviceV2Servlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = handler - if hs.config.worker.worker_app is None: - # if main process - self.key_uploader = self.e2e_keys_handler.upload_keys_for_user - else: - # then a worker - self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs) - async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -549,18 +541,12 @@ class DehydratedDeviceV2Servlet(RestServlet): "Device key(s) not found, these must be provided.", ) - # TODO: Those two operations, creating a device and storing the - # device's keys should be atomic. device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), submission.device_id, submission.device_data.dict(), submission.initial_device_display_name, - ) - - # TODO: Do we need to do something with the result here? - await self.key_uploader( - user_id=user_id, device_id=submission.device_id, keys=submission.dict() + device_info, ) return 200, {"device_id": device_id} diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 94ad90942f..2e104d4888 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -40,7 +40,9 @@ class LogoutRestServlet(RestServlet): self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_expired=True) + requester = await self.auth.get_user_by_req( + request, allow_expired=True, allow_locked=True + ) if requester.device_id is None: # The access token wasn't associated with a device. @@ -67,7 +69,9 @@ class LogoutAllRestServlet(RestServlet): self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_expired=True) + requester = await self.auth.get_user_by_req( + request, allow_expired=True, allow_locked=True + ) user_id = requester.user.to_string() # first delete all of the user's devices diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 5c9fece3ba..5ed3b83a03 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -32,6 +32,7 @@ from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client._base import client_patterns from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.types import JsonDict +from synapse.util.async_helpers import Linearizer if TYPE_CHECKING: from synapse.server import HomeServer @@ -53,26 +54,32 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None self._push_rules_handler = hs.get_push_rules_handler() + self._push_rule_linearizer = Linearizer(name="push_rules") async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + async with self._push_rule_linearizer.queue(user_id): + return await self.handle_put(request, path, user_id) + + async def handle_put( + self, request: SynapseRequest, path: str, user_id: str + ) -> Tuple[int, JsonDict]: spec = _rule_spec_from_path(path.split("/")) try: priority_class = _priority_class_from_spec(spec) except InvalidRuleException as e: raise SynapseError(400, str(e)) - requester = await self.auth.get_user_by_req(request) - if "/" in spec.rule_id or "\\" in spec.rule_id: raise SynapseError(400, "rule_id may not contain slashes") content = parse_json_value_from_request(request) - user_id = requester.user.to_string() - if spec.attr: try: await self._push_rules_handler.set_rule_attr(user_id, spec, content) @@ -126,11 +133,20 @@ class PushRuleRestServlet(RestServlet): if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") - spec = _rule_spec_from_path(path.split("/")) - requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() + async with self._push_rule_linearizer.queue(user_id): + return await self.handle_delete(request, path, user_id) + + async def handle_delete( + self, + request: SynapseRequest, + path: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + spec = _rule_spec_from_path(path.split("/")) + namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}" try: diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py index 4a5d9e13e7..b1f6b5d1b7 100644 --- a/synapse/rest/client/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/room_upgrade_rest_servlet.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -81,7 +81,7 @@ class RoomUpgradeRestServlet(RestServlet): try: async with self._worker_lock_handler.acquire_read_write_lock( - DELETE_ROOM_LOCK_NAME, room_id, write=False + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False ): new_room_id = await self._room_creation_handler.upgrade_room( requester, room_id, new_version diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 8f3865d412..981fd1f58a 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -14,7 +14,7 @@ import logging import re -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple from signedjson.sign import sign_json @@ -27,6 +27,7 @@ from synapse.http.servlet import ( parse_integer, parse_json_object_from_request, ) +from synapse.storage.keys import FetchKeyResultForRemote from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results @@ -157,14 +158,22 @@ class RemoteKey(RestServlet): ) -> JsonDict: logger.info("Handling query for keys %r", query) - store_queries = [] + server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {} for server_name, key_ids in query.items(): - if not key_ids: - key_ids = (None,) - for key_id in key_ids: - store_queries.append((server_name, key_id, None)) + if key_ids: + results: Mapping[ + str, Optional[FetchKeyResultForRemote] + ] = await self.store.get_server_keys_json_for_remote( + server_name, key_ids + ) + else: + results = await self.store.get_all_server_keys_json_for_remote( + server_name + ) - cached = await self.store.get_server_keys_json_for_remote(store_queries) + server_keys.update( + ((server_name, key_id), res) for key_id, res in results.items() + ) json_results: Set[bytes] = set() @@ -173,23 +182,20 @@ class RemoteKey(RestServlet): # Map server_name->key_id->int. Note that the value of the int is unused. # XXX: why don't we just use a set? cache_misses: Dict[str, Dict[str, int]] = {} - for (server_name, key_id, _), key_results in cached.items(): - results = [(result["ts_added_ms"], result) for result in key_results] - - if key_id is None: + for (server_name, key_id), key_result in server_keys.items(): + if not query[server_name]: # all keys were requested. Just return what we have without worrying # about validity - for _, result in results: - # Cast to bytes since postgresql returns a memoryview. - json_results.add(bytes(result["key_json"])) + if key_result: + json_results.add(key_result.key_json) continue miss = False - if not results: + if key_result is None: miss = True else: - ts_added_ms, most_recent_result = max(results) - ts_valid_until_ms = most_recent_result["ts_valid_until_ms"] + ts_added_ms = key_result.added_ts + ts_valid_until_ms = key_result.valid_until_ts req_key = query.get(server_name, {}).get(key_id, {}) req_valid_until = req_key.get("minimum_valid_until_ts") if req_valid_until is not None: @@ -235,8 +241,8 @@ class RemoteKey(RestServlet): ts_valid_until_ms, time_now_ms, ) - # Cast to bytes since postgresql returns a memoryview. - json_results.add(bytes(most_recent_result["key_json"])) + + json_results.add(key_result.key_json) if miss and query_remote_on_cache_miss: # only bother attempting to fetch keys from servers on our whitelist diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 2d5ddc3e7b..ddca0af1da 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -238,6 +238,7 @@ class BackgroundUpdater: def __init__(self, hs: "HomeServer", database: "DatabasePool"): self._clock = hs.get_clock() self.db_pool = database + self.hs = hs self._database_name = database.name() @@ -758,6 +759,11 @@ class BackgroundUpdater: logger.debug("[SQL] %s", sql) c.execute(sql) + # override the global statement timeout to avoid accidentally squashing + # a long-running index creation process + timeout_sql = "SET SESSION statement_timeout = 0" + c.execute(timeout_sql) + sql = ( "CREATE %(unique)s INDEX CONCURRENTLY %(name)s" " ON %(table)s" @@ -778,6 +784,12 @@ class BackgroundUpdater: logger.debug("[SQL] %s", sql) c.execute(sql) finally: + # mypy ignore - `statement_timeout` is defined on PostgresEngine + # reset the global timeout to the default + default_timeout = self.db_pool.engine.statement_timeout # type: ignore[attr-defined] + undo_timeout_sql = f"SET statement_timeout = {default_timeout}" + conn.cursor().execute(undo_timeout_sql) + conn.set_session(autocommit=False) # type: ignore def create_index_sqlite(conn: Connection) -> None: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 35cd1089d6..abd1d149db 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -45,7 +45,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.opentracing import ( SynapseTags, @@ -357,7 +357,7 @@ class EventsPersistenceStorageController: # it. We might already have taken out the lock, but since this is just a # "read" lock its inherently reentrant. async with self.hs.get_worker_locks_handler().acquire_read_write_lock( - DELETE_ROOM_LOCK_NAME, room_id, write=False + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False ): if isinstance(task, _PersistEventsTask): return await self._persist_event_batch(room_id, task) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d9df437e51..e4162f846b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -28,6 +28,7 @@ from typing import ( cast, ) +from canonicaljson import encode_canonical_json from typing_extensions import Literal from synapse.api.constants import EduTypes @@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) def _store_dehydrated_device_txn( - self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + device_data: str, + time: int, + keys: Optional[JsonDict] = None, ) -> Optional[str]: + # TODO: make keys non-optional once support for msc2697 is dropped + if keys: + device_keys = keys.get("device_keys", None) + if device_keys: + # Type ignore - this function is defined on EndToEndKeyStore which we do + # have access to due to hs.get_datastore() "magic" + self._set_e2e_device_keys_txn( # type: ignore[attr-defined] + txn, user_id, device_id, time, device_keys + ) + + one_time_keys = keys.get("one_time_keys", None) + if one_time_keys: + key_list = [] + for key_id, key_obj in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append( + ( + algorithm, + key_id, + encode_canonical_json(key_obj).decode("ascii"), + ) + ) + self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list) + + fallback_keys = keys.get("fallback_keys", None) + if fallback_keys: + self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys) + old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, table="dehydrated_devices", @@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): keyvalues={"user_id": user_id}, values={"device_id": device_id, "device_data": device_data}, ) + return old_device_id async def store_dehydrated_device( - self, user_id: str, device_id: str, device_data: JsonDict + self, + user_id: str, + device_id: str, + device_data: JsonDict, + time_now: int, + keys: Optional[dict] = None, ) -> Optional[str]: """Store a dehydrated device for a user. @@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): user_id: the user that we are storing the device for device_id: the ID of the dehydrated device device_data: the dehydrated device information + time_now: current time at the request in milliseconds + keys: keys for the dehydrated device + Returns: device id of the user's previous dehydrated device, if any """ + return await self.db_pool.runInteraction( "store_dehydrated_device_txn", self._store_dehydrated_device_txn, user_id, device_id, json_encoder.encode(device_data), + time_now, + keys, ) async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 91ae9c457d..b49dea577c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ - def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None: - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("new_keys", str(new_keys)) - # We are protected from race between lookup and insertion due to - # a unique constraint. If there is a race of two calls to - # `add_e2e_one_time_keys` then they'll conflict and we will only - # insert one set. - self.db_pool.simple_insert_many_txn( - txn, - table="e2e_one_time_keys_json", - keys=( - "user_id", - "device_id", - "algorithm", - "key_id", - "ts_added_ms", - "key_json", - ), - values=[ - (user_id, device_id, algorithm, key_id, time_now, json_bytes) - for algorithm, key_id, json_bytes in new_keys - ], - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - await self.db_pool.runInteraction( - "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys + "add_e2e_one_time_keys_insert", + self._add_e2e_one_time_keys_txn, + user_id, + device_id, + time_now, + new_keys, + ) + + def _add_e2e_one_time_keys_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + time_now: int, + new_keys: Iterable[Tuple[str, str, str]], + ) -> None: + """Insert some new one time keys for a device. Errors if any of the keys already exist. + + Args: + user_id: id of user to get keys for + device_id: id of device to get keys for + time_now: insertion time to record (ms since epoch) + new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note + that the key JSON must be in canonical JSON form + """ + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("new_keys", str(new_keys)) + # We are protected from race between lookup and insertion due to + # a unique constraint. If there is a race of two calls to + # `add_e2e_one_time_keys` then they'll conflict and we will only + # insert one set. + self.db_pool.simple_insert_many_txn( + txn, + table="e2e_one_time_keys_json", + keys=( + "user_id", + "device_id", + "algorithm", + "key_id", + "ts_added_ms", + "key_json", + ), + values=[ + (user_id, device_id, algorithm, key_id, time_now, json_bytes) + for algorithm, key_id, json_bytes in new_keys + ], + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) @cached(max_entries=10000) @@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker device_id: str, fallback_keys: JsonDict, ) -> None: + """Set the user's e2e fallback keys. + + Args: + user_id: the user whose keys are being set + device_id: the device whose keys are being set + fallback_keys: the keys to set. This is a map from key ID (which is + of the form "algorithm:id") to key data. + """ # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad # FIXME: make sure that only one key per algorithm is uploaded @@ -1304,43 +1333,70 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) -> bool: """Stores device keys for a device. Returns whether there was a change or the keys were already in the database. + + Args: + user_id: user_id of the user to store keys for + device_id: device_id of the device to store keys for + time_now: time at the request to store the keys + device_keys: the keys to store """ - def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool: - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("time_now", time_now) - set_tag("device_keys", str(device_keys)) - - old_key_json = self.db_pool.simple_select_one_onecol_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="key_json", - allow_none=True, - ) - - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_key_json = encode_canonical_json(device_keys).decode("utf-8") - - if old_key_json == new_key_json: - log_kv({"Message": "Device key already stored."}) - return False - - self.db_pool.simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"ts_added_ms": time_now, "key_json": new_key_json}, - ) - log_kv({"message": "Device keys stored."}) - return True - return await self.db_pool.runInteraction( - "set_e2e_device_keys", _set_e2e_device_keys_txn + "set_e2e_device_keys", + self._set_e2e_device_keys_txn, + user_id, + device_id, + time_now, + device_keys, ) + def _set_e2e_device_keys_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + time_now: int, + device_keys: JsonDict, + ) -> bool: + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. + + Args: + user_id: user_id of the user to store keys for + device_id: device_id of the device to store keys for + time_now: time at the request to store the keys + device_keys: the keys to store + """ + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("time_now", time_now) + set_tag("device_keys", str(device_keys)) + + old_key_json = self.db_pool.simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="key_json", + allow_none=True, + ) + + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_key_json = encode_canonical_json(device_keys).decode("utf-8") + + if old_key_json == new_key_json: + log_kv({"Message": "Device key already stored."}) + return False + + self.db_pool.simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time_now, "key_json": new_key_json}, + ) + log_kv({"message": "Device keys stored."}) + return True + async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: log_kv( diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index fff417f9e3..047de6283a 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -13,10 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Optional, Tuple, Union, cast from canonicaljson import encode_canonical_json -from typing_extensions import TYPE_CHECKING from synapse.api.errors import Codes, StoreError, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 1666e3c43b..a3b4744855 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,14 +16,13 @@ import itertools import json import logging -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Dict, Iterable, Mapping, Optional, Tuple from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction -from synapse.storage.keys import FetchKeyResult +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -34,7 +33,7 @@ logger = logging.getLogger(__name__) db_binary_type = memoryview -class KeyStore(SQLBaseStore): +class KeyStore(CacheInvalidationWorkerStore): """Persistence for signature verification keys""" @cached() @@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore): # invalidate takes a tuple corresponding to the params of # _get_server_keys_json. _get_server_keys_json only takes one # param, which is itself the 2-tuple (server_name, key_id). - self._get_server_keys_json.invalidate((((server_name, key_id),))) + await self.invalidate_cache_and_stream( + "_get_server_keys_json", ((server_name, key_id),) + ) + await self.invalidate_cache_and_stream( + "get_server_key_json_for_remote", (server_name, key_id) + ) @cached() def _get_server_keys_json( @@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore): return await self.db_pool.runInteraction("get_server_keys_json", _txn) + @cached() + def get_server_key_json_for_remote( + self, + server_name: str, + key_id: str, + ) -> Optional[FetchKeyResultForRemote]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_server_key_json_for_remote", list_name="key_ids" + ) async def get_server_keys_json_for_remote( - self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] - ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: - """Retrieve the key json for a list of server_keys and key ids. - If no keys are found for a given server, key_id and source then - that server, key_id, and source triplet entry will be an empty list. - The JSON is returned as a byte array so that it can be efficiently - used in an HTTP response. + self, server_name: str, key_ids: Iterable[str] + ) -> Dict[str, Optional[FetchKeyResultForRemote]]: + """Fetch the cached keys for the given server/key IDs. - Args: - server_keys: List of (server_name, key_id, source) triplets. - - Returns: - A mapping from (server_name, key_id, source) triplets to a list of dicts + If we have multiple entries for a given key ID, returns the most recent. """ - - def _get_server_keys_json_txn( - txn: LoggingTransaction, - ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: - results = {} - for server_name, key_id, from_server in server_keys: - keyvalues = {"server_name": server_name} - if key_id is not None: - keyvalues["key_id"] = key_id - if from_server is not None: - keyvalues["from_server"] = from_server - rows = self.db_pool.simple_select_list_txn( - txn, - "server_keys_json", - keyvalues=keyvalues, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", - ), - ) - results[(server_name, key_id, from_server)] = rows - return results - - return await self.db_pool.runInteraction( - "get_server_keys_json", _get_server_keys_json_txn + rows = await self.db_pool.simple_select_many_batch( + table="server_keys_json", + column="key_id", + iterable=key_ids, + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", ) + + if not rows: + return {} + + # We sort the rows so that the most recently added entry is picked up. + rows.sort(key=lambda r: r["ts_added_ms"]) + + return { + row["key_id"]: FetchKeyResultForRemote( + # Cast to bytes since postgresql returns a memoryview. + key_json=bytes(row["key_json"]), + valid_until_ts=row["ts_valid_until_ms"], + added_ts=row["ts_added_ms"], + ) + for row in rows + } + + async def get_all_server_keys_json_for_remote( + self, + server_name: str, + ) -> Dict[str, FetchKeyResultForRemote]: + """Fetch the cached keys for the given server. + + If we have multiple entries for a given key ID, returns the most recent. + """ + rows = await self.db_pool.simple_select_list( + table="server_keys_json", + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", + ) + + if not rows: + return {} + + rows.sort(key=lambda r: r["ts_added_ms"]) + + return { + row["key_id"]: FetchKeyResultForRemote( + # Cast to bytes since postgresql returns a memoryview. + key_json=bytes(row["key_json"]), + valid_until_ts=row["ts_valid_until_ms"], + added_ts=row["ts_added_ms"], + ) + for row in rows + } diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 1680bf6168..54d40e7a3a 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -26,7 +26,6 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) -from synapse.storage.engines import PostgresEngine from synapse.util import Clock from synapse.util.stringutils import random_string @@ -96,6 +95,10 @@ class LockStore(SQLBaseStore): self._acquiring_locks: Set[Tuple[str, str]] = set() + self._clock.looping_call( + self._reap_stale_read_write_locks, _LOCK_TIMEOUT_MS / 10.0 + ) + @wrap_as_background_process("LockStore._on_shutdown") async def _on_shutdown(self) -> None: """Called when the server is shutting down""" @@ -216,6 +219,7 @@ class LockStore(SQLBaseStore): lock_name, lock_key, write, + db_autocommit=True, ) except self.database_engine.module.IntegrityError: return None @@ -233,61 +237,22 @@ class LockStore(SQLBaseStore): # `worker_read_write_locks` and seeing if that fails any # constraints. If it doesn't then we have acquired the lock, # otherwise we haven't. - # - # Before that though we clear the table of any stale locks. now = self._clock.time_msec() token = random_string(6) - delete_sql = """ - DELETE FROM worker_read_write_locks - WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?; - """ - - insert_sql = """ - INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts) - VALUES (?, ?, ?, ?, ?, ?) - """ - - if isinstance(self.database_engine, PostgresEngine): - # For Postgres we can send these queries at the same time. - txn.execute( - delete_sql + ";" + insert_sql, - ( - # DELETE args - now - _LOCK_TIMEOUT_MS, - lock_name, - lock_key, - # UPSERT args - lock_name, - lock_key, - write, - self._instance_name, - token, - now, - ), - ) - else: - # For SQLite these need to be two queries. - txn.execute( - delete_sql, - ( - now - _LOCK_TIMEOUT_MS, - lock_name, - lock_key, - ), - ) - txn.execute( - insert_sql, - ( - lock_name, - lock_key, - write, - self._instance_name, - token, - now, - ), - ) + self.db_pool.simple_insert_txn( + txn, + table="worker_read_write_locks", + values={ + "lock_name": lock_name, + "lock_key": lock_key, + "write_lock": write, + "instance_name": self._instance_name, + "token": token, + "last_renewed_ts": now, + }, + ) lock = Lock( self._reactor, @@ -351,6 +316,24 @@ class LockStore(SQLBaseStore): return locks + @wrap_as_background_process("_reap_stale_read_write_locks") + async def _reap_stale_read_write_locks(self) -> None: + delete_sql = """ + DELETE FROM worker_read_write_locks + WHERE last_renewed_ts < ? + """ + + def reap_stale_read_write_locks_txn(txn: LoggingTransaction) -> None: + txn.execute(delete_sql, (self._clock.time_msec() - _LOCK_TIMEOUT_MS,)) + if txn.rowcount: + logger.info("Reaped %d stale locks", txn.rowcount) + + await self.db_pool.runInteraction( + "_reap_stale_read_write_locks", + reap_stale_read_write_locks_txn, + db_autocommit=True, + ) + class Lock: """An async context manager that manages an acquired lock, ensuring it is diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index c582cf0573..d3a01d526f 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -205,7 +205,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): name, password_hash, is_guest, admin, consent_version, consent_ts, consent_server_notice_sent, appservice_id, creation_ts, user_type, deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, - COALESCE(approved, TRUE) AS approved + COALESCE(approved, TRUE) AS approved, + COALESCE(locked, FALSE) AS locked FROM users WHERE name = ? """, @@ -230,10 +231,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # want to make sure we're returning the right type of data. # Note: when adding a column name to this list, be wary of NULLable columns, # since NULL values will be turned into False. - boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"] + boolean_columns = [ + "admin", + "deactivated", + "shadow_banned", + "approved", + "locked", + ] for column in boolean_columns: - if not isinstance(row[column], bool): - row[column] = bool(row[column]) + row[column] = bool(row[column]) return row @@ -1116,6 +1122,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Convert the integer into a boolean. return res == 1 + @cached() + async def get_user_locked_status(self, user_id: str) -> bool: + """Retrieve the value for the `locked` property for the provided user. + + Args: + user_id: The ID of the user to retrieve the status for. + + Returns: + True if the user was locked, false if the user is still active. + """ + + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="locked", + desc="get_user_locked_status", + ) + + # Convert the potential integer into a boolean. + return bool(res) + async def get_threepid_validation_session( self, medium: Optional[str], @@ -2111,6 +2138,33 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) + async def set_user_locked_status(self, user_id: str, locked: bool) -> None: + """Set the `locked` property for the provided user to the provided value. + + Args: + user_id: The ID of the user to set the status for. + locked: The value to set for `locked`. + """ + + await self.db_pool.runInteraction( + "set_user_locked_status", + self.set_user_locked_status_txn, + user_id, + locked, + ) + + def set_user_locked_status_txn( + self, txn: LoggingTransaction, user_id: str, locked: bool + ) -> None: + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"locked": locked}, + ) + self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,)) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + def update_user_approval_status_txn( self, txn: LoggingTransaction, user_id: str, approved: bool ) -> None: diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index f34b7ce8f4..6298f0984d 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -19,6 +19,7 @@ from itertools import chain from typing import ( TYPE_CHECKING, Any, + Counter, Dict, Iterable, List, @@ -28,8 +29,6 @@ from typing import ( cast, ) -from typing_extensions import Counter - from twisted.internet.defer import DeferredLock from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 2a136f2ff6..f0dc31fee6 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -995,7 +995,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) async def search_user_dir( - self, user_id: str, search_term: str, limit: int + self, + user_id: str, + search_term: str, + limit: int, + show_locked_users: bool = False, ) -> SearchResult: """Searches for users in directory @@ -1029,6 +1033,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) """ + if not show_locked_users: + where_clause += " AND (u.locked IS NULL OR u.locked = FALSE)" + # We allow manipulating the ranking algorithm by injecting statements # based on config options. additional_ordering_statements = [] @@ -1060,6 +1067,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SELECT d.user_id AS user_id, display_name, avatar_url FROM matching_users as t INNER JOIN user_directory AS d USING (user_id) + LEFT JOIN users AS u ON t.user_id = u.name WHERE %(where_clause)s ORDER BY @@ -1115,6 +1123,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SELECT d.user_id AS user_id, display_name, avatar_url FROM user_directory_search as t INNER JOIN user_directory AS d USING (user_id) + LEFT JOIN users AS u ON t.user_id = u.name WHERE %(where_clause)s AND value MATCH ? diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 0363cdc038..0b5b3bf03e 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -145,5 +145,5 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM This is not provided by DBAPI2, and so needs engine-specific support. """ - with open(filepath, "rt") as f: + with open(filepath) as f: cls.executescript(cursor, f.read()) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 71584f3f74..e74b2269d2 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -25,3 +25,10 @@ logger = logging.getLogger(__name__) class FetchKeyResult: verify_key: VerifyKey # the key itself valid_until_ts: int # how long we can use this key for + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FetchKeyResultForRemote: + key_json: bytes # the full key JSON + valid_until_ts: int # how long we can use this key for, in milliseconds. + added_ts: int # When we added this key, in milliseconds. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 38b7abd801..31501fd573 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -16,10 +16,18 @@ import logging import os import re from collections import Counter -from typing import Collection, Generator, Iterable, List, Optional, TextIO, Tuple +from typing import ( + Collection, + Counter as CounterType, + Generator, + Iterable, + List, + Optional, + TextIO, + Tuple, +) import attr -from typing_extensions import Counter as CounterType from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction diff --git a/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql b/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql new file mode 100644 index 0000000000..21c7971441 --- /dev/null +++ b/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql @@ -0,0 +1,16 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE users ADD locked BOOLEAN DEFAULT FALSE NOT NULL; diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 39a1ae4ac3..073f682aca 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -21,6 +21,7 @@ from typing import ( Any, ClassVar, Dict, + Final, List, Mapping, Match, @@ -38,7 +39,7 @@ import attr from immutabledict import immutabledict from signedjson.key import decode_verify_key_bytes from signedjson.types import VerifyKey -from typing_extensions import Final, TypedDict +from typing_extensions import TypedDict from unpaddedbase64 import decode_base64 from zope.interface import Interface diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 4041e49e71..943ad54456 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -22,6 +22,7 @@ import logging from contextlib import asynccontextmanager from typing import ( Any, + AsyncContextManager, AsyncIterator, Awaitable, Callable, @@ -42,7 +43,7 @@ from typing import ( ) import attr -from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec +from typing_extensions import Concatenate, Literal, ParamSpec from twisted.internet import defer from twisted.internet.defer import CancelledError diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index 644c341e8c..db6c40a3e1 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -218,7 +218,7 @@ class MacaroonGenerator: # to avoid validating those as guest tokens, we explicitely verify if # the macaroon includes the "guest = true" caveat. is_guest = any( - (caveat.caveat_id == "guest = true" for caveat in macaroon.caveats) + caveat.caveat_id == "guest = true" for caveat in macaroon.caveats ) if not is_guest: diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 48b8195ca1..8cb766860e 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -98,7 +98,9 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory: SynapseManhole, dict(globals, __name__="__console__") ) - factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) + # type-ignore: This is an error in Twisted's annotations. See + # https://github.com/twisted/twisted/issues/11812 and /11813 . + factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) # type: ignore[arg-type] # conch has the wrong type on these dicts (says bytes to bytes, # should be bytes to Keys judging by how it's used). diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 2ad55ac13e..cde4a0780f 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -20,6 +20,7 @@ import typing from typing import ( Any, Callable, + ContextManager, DefaultDict, Dict, Iterator, @@ -33,7 +34,6 @@ from typing import ( from weakref import WeakSet from prometheus_client.core import Counter -from typing_extensions import ContextManager from twisted.internet import defer diff --git a/synapse/visibility.py b/synapse/visibility.py index fc71dc92a4..eac10f6438 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -17,6 +17,7 @@ from enum import Enum, auto from typing import ( Collection, Dict, + Final, FrozenSet, List, Mapping, @@ -27,7 +28,6 @@ from typing import ( ) import attr -from typing_extensions import Final from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.events import EventBase diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index cdb0048122..ce96574915 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -69,6 +69,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) self.store.get_user_by_access_token = simple_async_mock(user_info) self.store.mark_access_token_as_used = simple_async_mock(None) + self.store.get_user_locked_status = simple_async_mock(False) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -293,6 +294,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) self.store.insert_client_ip = simple_async_mock(None) self.store.mark_access_token_as_used = simple_async_mock(None) + self.store.get_user_locked_status = simple_async_mock(False) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -311,6 +313,7 @@ class AuthTestCase(unittest.HomeserverTestCase): token_used=True, ) ) + self.store.get_user_locked_status = simple_async_mock(False) self.store.insert_client_ip = simple_async_mock(None) self.store.mark_access_token_as_used = simple_async_mock(None) request = Mock(args={}) diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 9305b758d7..93af614def 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -26,7 +26,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): def make_homeserver( self, reactor: ThreadedMemoryReactorClock, clock: Clock ) -> HomeServer: - hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock) + hs = super().make_homeserver(reactor, clock) # We don't want our tests to actually report statistics, so check # that it's not enabled diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 7c63b2ea4c..2be341ac7b 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -312,7 +312,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): [("server9", get_key_id(key1))] ) result = self.get_success(d) - self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0) + self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0) def test_verify_json_dedupes_key_requests(self) -> None: """Two requests for the same key should be deduped.""" @@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated - lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_triplet] + SERVER_NAME, [testverifykey_id] ) ) - res_keys = key_json[lookup_triplet] - self.assertEqual(len(res_keys), 1) - res = res_keys[0] - self.assertEqual(res["key_id"], testverifykey_id) - self.assertEqual(res["from_server"], SERVER_NAME) - self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) - self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + res = key_json[testverifykey_id] + self.assertIsNotNone(res) + assert res is not None + self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) # we expect it to be encoded as canonical json *before* it hits the db - self.assertEqual( - bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) - ) + self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response)) # change the server name: the result should be ignored response["server_name"] = "OTHER_SERVER" @@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated - lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_triplet] + SERVER_NAME, [testverifykey_id] ) ) - res_keys = key_json[lookup_triplet] - self.assertEqual(len(res_keys), 1) - res = res_keys[0] - self.assertEqual(res["key_id"], testverifykey_id) - self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) - self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) - self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + res = key_json[testverifykey_id] + self.assertIsNotNone(res) + assert res is not None + self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) - self.assertEqual( - bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) - ) + self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response)) def test_get_multiple_keys_from_perspectives(self) -> None: """Check that we can correctly request multiple keys for the same server""" @@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated - lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_triplet] + SERVER_NAME, [testverifykey_id] ) ) - res_keys = key_json[lookup_triplet] - self.assertEqual(len(res_keys), 1) - res = res_keys[0] - self.assertEqual(res["key_id"], testverifykey_id) - self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) - self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) - self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + res = key_json[testverifykey_id] + self.assertIsNotNone(res) + assert res is not None + self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) - self.assertEqual( - bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) - ) + self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response)) def test_invalid_perspectives_responses(self) -> None: """Check that invalid responses from the perspectives server are rejected""" diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 647ee09279..e1e58fa6e6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -566,15 +566,16 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.assertEqual(len(res["events"]), 1) self.assertEqual(res["events"][0]["content"]["body"], "foo") - # Fetch the message of the dehydrated device again, which should return nothing - # and delete the old messages + # Fetch the message of the dehydrated device again, which should return + # the same message as it has not been deleted res = self.get_success( self.message_handler.get_events_for_dehydrated_device( requester=requester, device_id=stored_dehydrated_device_id, - since_token=res["next_batch"], + since_token=None, limit=10, ) ) self.assertTrue(len(res["next_batch"]) > 1) - self.assertEqual(len(res["events"]), 0) + self.assertEqual(len(res["events"]), 1) + self.assertEqual(res["events"][0]["content"]["body"], "foo") diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index 6309d7b36e..82c26e303f 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -491,6 +491,68 @@ class MSC3861OAuthDelegation(HomeserverTestCase): error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) self.assertEqual(error.value.code, 503) + def test_introspection_token_cache(self) -> None: + access_token = "open_sesame" + self.http_client.request = simple_async_mock( + return_value=FakeResponse.json( + code=200, + payload={"active": "true", "scope": "guest", "jti": access_token}, + ) + ) + + # first call should cache response + # Mpyp ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code + # for regular auth code via the config + self.get_success( + self.auth._introspect_token(access_token) # type: ignore[attr-defined] + ) + introspection_token = self.auth._token_cache.get(access_token) # type: ignore[attr-defined] + self.assertEqual(introspection_token["jti"], access_token) + # there's been one http request + self.http_client.request.assert_called_once() + + # second call should pull from cache, there should still be only one http request + token = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined] + self.http_client.request.assert_called_once() + self.assertEqual(token["jti"], access_token) + + # advance past five minutes and check that cache expired - there should be more than one http call now + self.reactor.advance(360) + token_2 = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined] + self.assertEqual(self.http_client.request.call_count, 2) + self.assertEqual(token_2["jti"], access_token) + + # test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a + # token with a soon-to-expire `exp` field to the cache + self.http_client.request = simple_async_mock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": "true", + "scope": "guest", + "jti": "stale", + "exp": self.clock.time() + 100, + }, + ) + ) + self.get_success( + self.auth._introspect_token("stale") # type: ignore[attr-defined] + ) + introspection_token = self.auth._token_cache.get("stale") # type: ignore[attr-defined] + self.assertEqual(introspection_token["jti"], "stale") + self.assertEqual(self.http_client.request.call_count, 1) + + # advance the reactor past the token expiry but less than the cache expiry + self.reactor.advance(120) + self.assertEqual(self.auth._token_cache.get("stale"), introspection_token) # type: ignore[attr-defined] + + # check that the next call causes another http request (which will fail because the token is technically expired + # but the important thing is we discard the token from the cache and try the network) + self.get_failure( + self.auth._introspect_token("stale"), InvalidClientTokenError # type: ignore[attr-defined] + ) + self.assertEqual(self.http_client.request.call_count, 2) + def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: # We only generate a master key to simplify the test. master_signing_key = generate_signing_key(device_id) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index aed2a4c07a..6a0b5fc0bd 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -514,7 +514,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(response.code, 200) # Send the body - request.write('{ "a": 1 }'.encode("ascii")) + request.write(b'{ "a": 1 }') request.finish() self.reactor.pump((0.1,)) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index b3310abe1b..fe631d7ecb 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -757,7 +757,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): self.assertEqual(channel.json_body["creator"], user_id) # Check room alias. - self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}") + self.assertEqual(room_alias, f"#foo-bar:{self.module_api.server_name}") # Let's try a room with no alias. room_id, room_alias = self.get_success( diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 1527b4a82d..6e78daa830 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -116,7 +116,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(request.method, b"GET") self.assertEqual( request.path, - f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"), + f"/_matrix/media/r0/download/{target}/{media_id}".encode(), ) self.assertEqual( request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 9af9db6e3e..41a959b4d6 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -29,7 +29,16 @@ from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.media.filepath import MediaFilePaths -from synapse.rest.client import devices, login, logout, profile, register, room, sync +from synapse.rest.client import ( + devices, + login, + logout, + profile, + register, + room, + sync, + user_directory, +) from synapse.server import HomeServer from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock @@ -1477,6 +1486,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): login.register_servlets, sync.register_servlets, register.register_servlets, + user_directory.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -2464,6 +2474,105 @@ class UserRestTestCase(unittest.HomeserverTestCase): # This key was removed intentionally. Ensure it is not accidentally re-included. self.assertNotIn("password_hash", channel.json_body) + def test_locked_user(self) -> None: + # User can sync + channel = self.make_request( + "GET", + "/_matrix/client/v3/sync", + access_token=self.other_user_token, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Lock user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"locked": True}, + ) + + # User is not authorized to sync anymore + channel = self.make_request( + "GET", + "/_matrix/client/v3/sync", + access_token=self.other_user_token, + ) + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.USER_LOCKED, channel.json_body["errcode"]) + self.assertTrue(channel.json_body["soft_logout"]) + + @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) + def test_locked_user_not_in_user_dir(self) -> None: + # User is available in the user dir + channel = self.make_request( + "POST", + "/_matrix/client/v3/user_directory/search", + {"search_term": self.other_user}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("results", channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Lock user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"locked": True}, + ) + + # User is not available anymore in the user dir + channel = self.make_request( + "POST", + "/_matrix/client/v3/user_directory/search", + {"search_term": self.other_user}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("results", channel.json_body) + self.assertEqual(0, len(channel.json_body["results"])) + + @override_config( + { + "user_directory": { + "enabled": True, + "search_all_users": True, + "show_locked_users": True, + } + } + ) + def test_locked_user_in_user_dir_with_show_locked_users_option(self) -> None: + # User is available in the user dir + channel = self.make_request( + "POST", + "/_matrix/client/v3/user_directory/search", + {"search_term": self.other_user}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("results", channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Lock user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"locked": True}, + ) + + # User is still available in the user dir + channel = self.make_request( + "POST", + "/_matrix/client/v3/user_directory/search", + {"search_term": self.other_user}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("results", channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) def test_change_name_deactivate_user_user_directory(self) -> None: """ diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index 3cf29c10ea..60099f8c59 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError from synapse.rest import admin, devices, room, sync from synapse.rest.client import account, keys, login, register from synapse.server import HomeServer -from synapse.types import JsonDict, create_requester +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest @@ -282,6 +282,17 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase): "": {":": ""} }, }, + "fallback_keys": { + "alg1:device1": "f4llb4ckk3y", + "signed_:": { + "fallback": "true", + "key": "f4llb4ckk3y", + "signatures": { + "": {":": ""} + }, + }, + }, + "one_time_keys": {"alg1:k1": "0net1m3k3y"}, } channel = self.make_request( "PUT", @@ -312,6 +323,55 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase): } self.assertEqual(device_data, expected_device_data) + # test that the keys are correctly uploaded + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + user: ["device1"], + }, + }, + token, + ) + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body["device_keys"][user][device_id]["keys"], + content["device_keys"]["keys"], + ) + # first claim should return the onetime key we uploaded + res = self.get_success( + self.hs.get_e2e_keys_handler().claim_one_time_keys( + {user: {device_id: {"alg1": 1}}}, + UserID.from_string(user), + timeout=None, + always_include_fallback_keys=False, + ) + ) + self.assertEqual( + res, + { + "failures": {}, + "one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}}, + }, + ) + # second claim should return fallback key + res2 = self.get_success( + self.hs.get_e2e_keys_handler().claim_one_time_keys( + {user: {device_id: {"alg1": 1}}}, + UserID.from_string(user), + timeout=None, + always_include_fallback_keys=False, + ) + ) + self.assertEqual( + res2, + { + "failures": {}, + "one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}}, + }, + ) + # create another device for the user ( new_device_id, @@ -348,10 +408,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) expected_content = {"body": "test_message"} self.assertEqual(channel.json_body["events"][0]["content"], expected_content) + + # fetch messages again and make sure that the message was not deleted + channel = self.make_request( + "POST", + f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events", + content={}, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["events"][0]["content"], expected_content) next_batch_token = channel.json_body.get("next_batch") - # fetch messages again and make sure that the message was deleted and we are returned an - # empty array + # make sure fetching messages with next batch token works - there are no unfetched + # messages so we should receive an empty array content = {"next_batch": next_batch_token} channel = self.make_request( "POST", diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 180b635ea6..4e0a387bd3 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -627,8 +627,8 @@ class RedactionsTestCase(HomeserverTestCase): redact_event = timeline[-1] self.assertEqual(redact_event["type"], EventTypes.Redaction) # The redacts key should be in the content and the redacts keys. - self.assertEquals(redact_event["content"]["redacts"], event_id) - self.assertEquals(redact_event["redacts"], event_id) + self.assertEqual(redact_event["content"]["redacts"], event_id) + self.assertEqual(redact_event["redacts"], event_id) # But it isn't actually part of the event. def get_event(txn: LoggingTransaction) -> JsonDict: @@ -642,10 +642,10 @@ class RedactionsTestCase(HomeserverTestCase): event_json = self.get_success( main_datastore.db_pool.runInteraction("get_event", get_event) ) - self.assertEquals(event_json["type"], EventTypes.Redaction) + self.assertEqual(event_json["type"], EventTypes.Redaction) if expect_content: self.assertNotIn("redacts", event_json) - self.assertEquals(event_json["content"]["redacts"], event_id) + self.assertEqual(event_json["content"]["redacts"], event_id) else: - self.assertEquals(event_json["redacts"], event_id) + self.assertEqual(event_json["redacts"], event_id) self.assertNotIn("redacts", event_json["content"]) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 75439416c1..9bfe913e45 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -129,7 +129,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) return [ev["event_id"] for ev in channel.json_body["chunk"]] def _get_bundled_aggregations(self) -> JsonDict: @@ -142,7 +142,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/v3/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) return channel.json_body["unsigned"].get("m.relations", {}) def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: @@ -1602,7 +1602,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) threads = channel.json_body["chunk"] return [ ( @@ -1634,7 +1634,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ################################################## # Check the test data is configured as expected. # ################################################## - self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) + self.assertEqual(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() self.assertDictContainsSubset( {"count": 3, "current_user_participated": True}, @@ -1655,7 +1655,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(thread_replies.pop()) # The thread should still exist, but the latest event should be updated. - self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) + self.assertEqual(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() self.assertDictContainsSubset( {"count": 2, "current_user_participated": True}, @@ -1674,7 +1674,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(thread_replies.pop(0)) # Nothing should have changed (except the thread count). - self.assertEquals(self._get_related_events(), thread_replies) + self.assertEqual(self._get_related_events(), thread_replies) relations = self._get_bundled_aggregations() self.assertDictContainsSubset( {"count": 1, "current_user_participated": True}, @@ -1691,11 +1691,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): # Redact the last remaining event. # #################################### self._redact(thread_replies.pop(0)) - self.assertEquals(thread_replies, []) + self.assertEqual(thread_replies, []) # The event should no longer be considered a thread. - self.assertEquals(self._get_related_events(), []) - self.assertEquals(self._get_bundled_aggregations(), {}) + self.assertEqual(self._get_related_events(), []) + self.assertEqual(self._get_bundled_aggregations(), {}) self.assertEqual(self._get_threads(), []) def test_redact_parent_edit(self) -> None: @@ -1749,8 +1749,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): # The relations are returned. event_ids = self._get_related_events() relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [related_event_id]) - self.assertEquals( + self.assertEqual(event_ids, [related_event_id]) + self.assertEqual( relations[RelationTypes.REFERENCE], {"chunk": [{"event_id": related_event_id}]}, ) @@ -1772,7 +1772,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): # The unredacted relation should still exist. event_ids = self._get_related_events() relations = self._get_bundled_aggregations() - self.assertEquals(len(event_ids), 1) + self.assertEqual(len(event_ids), 1) self.assertDictContainsSubset( { "count": 1, @@ -1816,7 +1816,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) threads = self._get_threads(channel.json_body) self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)]) @@ -1829,7 +1829,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Tuple of (thread ID, latest event ID) for each thread. threads = self._get_threads(channel.json_body) self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)]) @@ -1850,7 +1850,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_2]) @@ -1864,7 +1864,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_1], channel.json_body) @@ -1899,7 +1899,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual( thread_roots, [thread_3, thread_2, thread_1], channel.json_body @@ -1911,7 +1911,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) @@ -1943,6 +1943,6 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_1], channel.json_body) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 4f6347be15..88e579dc39 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -1362,7 +1362,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): # Ensure the event was persisted with the correct timestamp. res = self.get_success(self.main_store.get_event(event_id)) - self.assertEquals(ts, res.origin_server_ts) + self.assertEqual(ts, res.origin_server_ts) def test_send_state_event_ts(self) -> None: """Test sending a state event with a custom timestamp.""" @@ -1384,7 +1384,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): # Ensure the event was persisted with the correct timestamp. res = self.get_success(self.main_store.get_event(event_id)) - self.assertEquals(ts, res.origin_server_ts) + self.assertEqual(ts, res.origin_server_ts) def test_send_membership_event_ts(self) -> None: """Test sending a membership event with a custom timestamp.""" @@ -1406,7 +1406,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): # Ensure the event was persisted with the correct timestamp. res = self.get_success(self.main_store.get_event(event_id)) - self.assertEquals(ts, res.origin_server_ts) + self.assertEqual(ts, res.origin_server_ts) class RoomJoinRatelimitTestCase(RoomBase): diff --git a/tests/server.py b/tests/server.py index c84a524e8c..481fe34c5c 100644 --- a/tests/server.py +++ b/tests/server.py @@ -26,6 +26,7 @@ from typing import ( Any, Awaitable, Callable, + Deque, Dict, Iterable, List, @@ -41,7 +42,7 @@ from typing import ( from unittest.mock import Mock import attr -from typing_extensions import Deque, ParamSpec +from typing_extensions import ParamSpec from zope.interface import implementer from twisted.internet import address, threads, udp diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 383da83dfb..f541f1d6be 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. + from twisted.internet import defer, reactor from twisted.internet.base import ReactorBase from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer -from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS +from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS, _RENEWAL_INTERVAL_MS from synapse.util import Clock from tests import unittest @@ -380,8 +381,8 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase): self.get_success(lock.__aenter__()) # Wait for ages with the lock, we should not be able to get the lock. - self.reactor.advance(5 * _LOCK_TIMEOUT_MS / 1000) - self.pump() + for _ in range(0, 10): + self.reactor.advance((_RENEWAL_INTERVAL_MS / 1000)) lock2 = self.get_success( self.store.try_acquire_read_write_lock("name", "key", write=True) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 5e1324a169..71302facd1 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -40,7 +40,7 @@ from tests.test_utils import make_awaitable class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: - super(ApplicationServiceStoreTestCase, self).setUp() + super().setUp() self.as_yaml_files: List[str] = [] @@ -71,7 +71,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): except Exception: pass - super(ApplicationServiceStoreTestCase, self).tearDown() + super().tearDown() def _add_appservice( self, as_token: str, id: str, url: str, hs_token: str, sender: str @@ -110,7 +110,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: - super(ApplicationServiceTransactionStoreTestCase, self).setUp() + super().setUp() self.as_yaml_files: List[str] = [] self.hs.config.appservice.app_service_config_files = self.as_yaml_files diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index 27f450e22d..b8823d6993 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -20,7 +20,7 @@ from tests import unittest class DataStoreTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: - super(DataStoreTestCase, self).setUp() + super().setUp() self.store = self.hs.get_datastores().main diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 05ea802008..ba41459d08 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -48,6 +48,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): "creation_ts": 0, "user_type": None, "deactivated": 0, + "locked": 0, "shadow_banned": 0, "approved": 1, }, diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index f183c38477..52ffa91c81 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -318,14 +318,14 @@ class MessageSearchTest(HomeserverTestCase): result = self.get_success( store.search_msgs([self.room_id], query, ["content.body"]) ) - self.assertEquals( + self.assertEqual( result["count"], 1 if expect_to_contain else 0, f"expected '{query}' to match '{self.PHRASE}'" if expect_to_contain else f"'{query}' unexpectedly matched '{self.PHRASE}'", ) - self.assertEquals( + self.assertEqual( len(result["results"]), 1 if expect_to_contain else 0, "results array length should match count", @@ -336,14 +336,14 @@ class MessageSearchTest(HomeserverTestCase): result = self.get_success( store.search_rooms([self.room_id], query, ["content.body"], 10) ) - self.assertEquals( + self.assertEqual( result["count"], 1 if expect_to_contain else 0, f"expected '{query}' to match '{self.PHRASE}'" if expect_to_contain else f"'{query}' unexpectedly matched '{self.PHRASE}'", ) - self.assertEquals( + self.assertEqual( len(result["results"]), 1 if expect_to_contain else 0, "results array length should match count", diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 9ed330f554..a46c29ddf4 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM" class FilterEventsForServerTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: - super(FilterEventsForServerTestCase, self).setUp() + super().setUp() self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self._storage_controllers = self.hs.get_storage_controllers()