Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
commit
de16789d87
|
@ -5,6 +5,9 @@ on:
|
||||||
- cron: 0 8 * * *
|
- cron: 0 8 * * *
|
||||||
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
# NB: inputs are only present when this workflow is dispatched manually.
|
||||||
|
# (The default below is the default field value in the form to trigger
|
||||||
|
# a manual dispatch). Otherwise the inputs will evaluate to null.
|
||||||
inputs:
|
inputs:
|
||||||
twisted_ref:
|
twisted_ref:
|
||||||
description: Commit, branch or tag to checkout from upstream Twisted.
|
description: Commit, branch or tag to checkout from upstream Twisted.
|
||||||
|
@ -49,7 +52,7 @@ jobs:
|
||||||
extras: "all"
|
extras: "all"
|
||||||
- run: |
|
- run: |
|
||||||
poetry remove twisted
|
poetry remove twisted
|
||||||
poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref }}
|
poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref || 'trunk' }}
|
||||||
poetry install --no-interaction --extras "all test"
|
poetry install --no-interaction --extras "all test"
|
||||||
- name: Remove warn_unused_ignores from mypy config
|
- name: Remove warn_unused_ignores from mypy config
|
||||||
run: sed '/warn_unused_ignores = True/d' -i mypy.ini
|
run: sed '/warn_unused_ignores = True/d' -i mypy.ini
|
||||||
|
|
|
@ -1,3 +1,8 @@
|
||||||
|
# Synapse 1.90.0 (2023-08-15)
|
||||||
|
|
||||||
|
No significant changes since 1.90.0rc1.
|
||||||
|
|
||||||
|
|
||||||
# Synapse 1.90.0rc1 (2023-08-08)
|
# Synapse 1.90.0rc1 (2023-08-08)
|
||||||
|
|
||||||
### Features
|
### Features
|
||||||
|
|
|
@ -132,9 +132,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "log"
|
name = "log"
|
||||||
version = "0.4.19"
|
version = "0.4.20"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
|
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "memchr"
|
name = "memchr"
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Implements an admin API to lock an user without deactivating them. Based on [MSC3939](https://github.com/matrix-org/matrix-spec-proposals/pull/3939).
|
|
@ -0,0 +1 @@
|
||||||
|
Update dehydrated devices implementation.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix long-standing bug where concurrent requests to change a user's push rules could cause a deadlock. Contributed by Nick @ Beeper (@fizzadar).
|
|
@ -0,0 +1 @@
|
||||||
|
Fix database performance of read/write worker locks.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix a long-standing bu in `/sync` where timeout=0 does not skip caching, resulting in slow calls in cases where there are no new changes. Contributed by @PlasmaIntec.
|
|
@ -0,0 +1 @@
|
||||||
|
Override global statement timeout when creating indexes in Postgres.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix the type annotation on `run_db_interaction` in the Module API.
|
|
@ -0,0 +1 @@
|
||||||
|
Structured logging docs: add a link to explain the ELK stack
|
|
@ -0,0 +1 @@
|
||||||
|
Clean-up the presence code.
|
|
@ -0,0 +1 @@
|
||||||
|
Allow customising the IdP display name, icon, and brand for SAML and CAS providers (in addition to OIDC provider).
|
|
@ -0,0 +1 @@
|
||||||
|
Run `pyupgrade` for Python 3.8+.
|
|
@ -0,0 +1 @@
|
||||||
|
Rename pagination and purge locks and add comments to explain why they exist and how they work.
|
|
@ -0,0 +1 @@
|
||||||
|
Attempt to fix the twisted trunk job.
|
|
@ -0,0 +1 @@
|
||||||
|
Cache token introspection response from OIDC provider.
|
|
@ -0,0 +1 @@
|
||||||
|
Add cache to `get_server_keys_json_for_remote`.
|
|
@ -769,7 +769,7 @@ def main(server_url, identity_server_url, username, token, config_path):
|
||||||
global CONFIG_JSON
|
global CONFIG_JSON
|
||||||
CONFIG_JSON = config_path # bit cheeky, but just overwrite the global
|
CONFIG_JSON = config_path # bit cheeky, but just overwrite the global
|
||||||
try:
|
try:
|
||||||
with open(config_path, "r") as config:
|
with open(config_path) as config:
|
||||||
syn_cmd.config = json.load(config)
|
syn_cmd.config = json.load(config)
|
||||||
try:
|
try:
|
||||||
http_client.verbose = "on" == syn_cmd.config["verbose"]
|
http_client.verbose = "on" == syn_cmd.config["verbose"]
|
||||||
|
|
|
@ -1,3 +1,9 @@
|
||||||
|
matrix-synapse-py3 (1.90.0) stable; urgency=medium
|
||||||
|
|
||||||
|
* New Synapse release 1.90.0.
|
||||||
|
|
||||||
|
-- Synapse Packaging team <packages@matrix.org> Tue, 15 Aug 2023 11:17:34 +0100
|
||||||
|
|
||||||
matrix-synapse-py3 (1.90.0~rc1) stable; urgency=medium
|
matrix-synapse-py3 (1.90.0~rc1) stable; urgency=medium
|
||||||
|
|
||||||
* New Synapse release 1.90.0rc1.
|
* New Synapse release 1.90.0rc1.
|
||||||
|
|
|
@ -861,7 +861,7 @@ def generate_worker_files(
|
||||||
# Then a worker config file
|
# Then a worker config file
|
||||||
convert(
|
convert(
|
||||||
"/conf/worker.yaml.j2",
|
"/conf/worker.yaml.j2",
|
||||||
"/conf/workers/{name}.yaml".format(name=worker_name),
|
f"/conf/workers/{worker_name}.yaml",
|
||||||
**worker_config,
|
**worker_config,
|
||||||
worker_log_config_filepath=log_config_filepath,
|
worker_log_config_filepath=log_config_filepath,
|
||||||
using_unix_sockets=using_unix_sockets,
|
using_unix_sockets=using_unix_sockets,
|
||||||
|
|
|
@ -82,7 +82,7 @@ def generate_config_from_template(
|
||||||
with open(filename) as handle:
|
with open(filename) as handle:
|
||||||
value = handle.read()
|
value = handle.read()
|
||||||
else:
|
else:
|
||||||
log("Generating a random secret for {}".format(secret))
|
log(f"Generating a random secret for {secret}")
|
||||||
value = codecs.encode(os.urandom(32), "hex").decode()
|
value = codecs.encode(os.urandom(32), "hex").decode()
|
||||||
with open(filename, "w") as handle:
|
with open(filename, "w") as handle:
|
||||||
handle.write(value)
|
handle.write(value)
|
||||||
|
|
|
@ -146,6 +146,7 @@ Body parameters:
|
||||||
- `admin` - **bool**, optional, defaults to `false`. Whether the user is a homeserver administrator,
|
- `admin` - **bool**, optional, defaults to `false`. Whether the user is a homeserver administrator,
|
||||||
granting them access to the Admin API, among other things.
|
granting them access to the Admin API, among other things.
|
||||||
- `deactivated` - **bool**, optional. If unspecified, deactivation state will be left unchanged.
|
- `deactivated` - **bool**, optional. If unspecified, deactivation state will be left unchanged.
|
||||||
|
- `locked` - **bool**, optional. If unspecified, locked state will be left unchanged.
|
||||||
|
|
||||||
Note: the `password` field must also be set if both of the following are true:
|
Note: the `password` field must also be set if both of the following are true:
|
||||||
- `deactivated` is set to `false` and the user was previously deactivated (you are reactivating this user)
|
- `deactivated` is set to `false` and the user was previously deactivated (you are reactivating this user)
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
A structured logging system can be useful when your logs are destined for a
|
A structured logging system can be useful when your logs are destined for a
|
||||||
machine to parse and process. By maintaining its machine-readable characteristics,
|
machine to parse and process. By maintaining its machine-readable characteristics,
|
||||||
it enables more efficient searching and aggregations when consumed by software
|
it enables more efficient searching and aggregations when consumed by software
|
||||||
such as the "ELK stack".
|
such as the [ELK stack](https://opensource.com/article/18/9/open-source-log-aggregation-tools).
|
||||||
|
|
||||||
Synapse's structured logging system is configured via the file that Synapse's
|
Synapse's structured logging system is configured via the file that Synapse's
|
||||||
`log_config` config option points to. The file should include a formatter which
|
`log_config` config option points to. The file should include a formatter which
|
||||||
|
|
|
@ -3025,6 +3025,16 @@ enable SAML login. You can either put your entire pysaml config inline using the
|
||||||
option, or you can specify a path to a psyaml config file with the sub-option `config_path`.
|
option, or you can specify a path to a psyaml config file with the sub-option `config_path`.
|
||||||
This setting has the following sub-options:
|
This setting has the following sub-options:
|
||||||
|
|
||||||
|
* `idp_name`: A user-facing name for this identity provider, which is used to
|
||||||
|
offer the user a choice of login mechanisms.
|
||||||
|
* `idp_icon`: An optional icon for this identity provider, which is presented
|
||||||
|
by clients and Synapse's own IdP picker page. If given, must be an
|
||||||
|
MXC URI of the format `mxc://<server-name>/<media-id>`. (An easy way to
|
||||||
|
obtain such an MXC URI is to upload an image to an (unencrypted) room
|
||||||
|
and then copy the "url" from the source of the event.)
|
||||||
|
* `idp_brand`: An optional brand for this identity provider, allowing clients
|
||||||
|
to style the login flow according to the identity provider in question.
|
||||||
|
See the [spec](https://spec.matrix.org/latest/) for possible options here.
|
||||||
* `sp_config`: the configuration for the pysaml2 Service Provider. See pysaml2 docs for format of config.
|
* `sp_config`: the configuration for the pysaml2 Service Provider. See pysaml2 docs for format of config.
|
||||||
Default values will be used for the `entityid` and `service` settings,
|
Default values will be used for the `entityid` and `service` settings,
|
||||||
so it is not normally necessary to specify them unless you need to
|
so it is not normally necessary to specify them unless you need to
|
||||||
|
@ -3176,7 +3186,7 @@ Options for each entry include:
|
||||||
|
|
||||||
* `idp_icon`: An optional icon for this identity provider, which is presented
|
* `idp_icon`: An optional icon for this identity provider, which is presented
|
||||||
by clients and Synapse's own IdP picker page. If given, must be an
|
by clients and Synapse's own IdP picker page. If given, must be an
|
||||||
MXC URI of the format mxc://<server-name>/<media-id>. (An easy way to
|
MXC URI of the format `mxc://<server-name>/<media-id>`. (An easy way to
|
||||||
obtain such an MXC URI is to upload an image to an (unencrypted) room
|
obtain such an MXC URI is to upload an image to an (unencrypted) room
|
||||||
and then copy the "url" from the source of the event.)
|
and then copy the "url" from the source of the event.)
|
||||||
|
|
||||||
|
@ -3391,6 +3401,16 @@ Enable Central Authentication Service (CAS) for registration and login.
|
||||||
Has the following sub-options:
|
Has the following sub-options:
|
||||||
* `enabled`: Set this to true to enable authorization against a CAS server.
|
* `enabled`: Set this to true to enable authorization against a CAS server.
|
||||||
Defaults to false.
|
Defaults to false.
|
||||||
|
* `idp_name`: A user-facing name for this identity provider, which is used to
|
||||||
|
offer the user a choice of login mechanisms.
|
||||||
|
* `idp_icon`: An optional icon for this identity provider, which is presented
|
||||||
|
by clients and Synapse's own IdP picker page. If given, must be an
|
||||||
|
MXC URI of the format `mxc://<server-name>/<media-id>`. (An easy way to
|
||||||
|
obtain such an MXC URI is to upload an image to an (unencrypted) room
|
||||||
|
and then copy the "url" from the source of the event.)
|
||||||
|
* `idp_brand`: An optional brand for this identity provider, allowing clients
|
||||||
|
to style the login flow according to the identity provider in question.
|
||||||
|
See the [spec](https://spec.matrix.org/latest/) for possible options here.
|
||||||
* `server_url`: The URL of the CAS authorization endpoint.
|
* `server_url`: The URL of the CAS authorization endpoint.
|
||||||
* `displayname_attribute`: The attribute of the CAS response to use as the display name.
|
* `displayname_attribute`: The attribute of the CAS response to use as the display name.
|
||||||
If no name is given here, no displayname will be set.
|
If no name is given here, no displayname will be set.
|
||||||
|
@ -3631,6 +3651,7 @@ This option has the following sub-options:
|
||||||
* `prefer_local_users`: Defines whether to prefer local users in search query results.
|
* `prefer_local_users`: Defines whether to prefer local users in search query results.
|
||||||
If set to true, local users are more likely to appear above remote users when searching the
|
If set to true, local users are more likely to appear above remote users when searching the
|
||||||
user directory. Defaults to false.
|
user directory. Defaults to false.
|
||||||
|
* `show_locked_users`: Defines whether to show locked users in search query results. Defaults to false.
|
||||||
|
|
||||||
Example configuration:
|
Example configuration:
|
||||||
```yaml
|
```yaml
|
||||||
|
@ -3638,6 +3659,7 @@ user_directory:
|
||||||
enabled: false
|
enabled: false
|
||||||
search_all_users: true
|
search_all_users: true
|
||||||
prefer_local_users: true
|
prefer_local_users: true
|
||||||
|
show_locked_users: true
|
||||||
```
|
```
|
||||||
---
|
---
|
||||||
### `user_consent`
|
### `user_consent`
|
||||||
|
|
7
mypy.ini
7
mypy.ini
|
@ -45,6 +45,13 @@ warn_unused_ignores = False
|
||||||
disallow_untyped_defs = False
|
disallow_untyped_defs = False
|
||||||
disallow_incomplete_defs = False
|
disallow_incomplete_defs = False
|
||||||
|
|
||||||
|
[mypy-synapse.util.manhole]
|
||||||
|
# This module imports something from Twisted which has a bad annotation in Twisted trunk,
|
||||||
|
# but is unannotated in Twisted's latest release. We want to type-ignore the problem
|
||||||
|
# in the twisted trunk job, even though it has no effect on normal mypy runs.
|
||||||
|
warn_unused_ignores = False
|
||||||
|
|
||||||
|
|
||||||
;; Dependencies without annotations
|
;; Dependencies without annotations
|
||||||
;; Before ignoring a module, check to see if type stubs are available.
|
;; Before ignoring a module, check to see if type stubs are available.
|
||||||
;; The `typeshed` project maintains stubs here:
|
;; The `typeshed` project maintains stubs here:
|
||||||
|
|
|
@ -589,13 +589,13 @@ smmap = ">=3.0.1,<6"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "gitpython"
|
name = "gitpython"
|
||||||
version = "3.1.31"
|
version = "3.1.32"
|
||||||
description = "GitPython is a Python library used to interact with Git repositories"
|
description = "GitPython is a Python library used to interact with Git repositories"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "GitPython-3.1.31-py3-none-any.whl", hash = "sha256:f04893614f6aa713a60cbbe1e6a97403ef633103cdd0ef5eb6efe0deb98dbe8d"},
|
{file = "GitPython-3.1.32-py3-none-any.whl", hash = "sha256:e3d59b1c2c6ebb9dfa7a184daf3b6dd4914237e7488a1730a6d8f6f5d0b4187f"},
|
||||||
{file = "GitPython-3.1.31.tar.gz", hash = "sha256:8ce3bcf69adfdf7c7d503e78fd3b1c492af782d58893b650adb2ac8912ddd573"},
|
{file = "GitPython-3.1.32.tar.gz", hash = "sha256:8d9b8cb1e80b9735e8717c9362079d3ce4c6e5ddeebedd0361b228c3a67a62f6"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -887,17 +887,17 @@ scripts = ["click (>=6.0)", "twisted (>=16.4.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "isort"
|
name = "isort"
|
||||||
version = "5.11.5"
|
version = "5.12.0"
|
||||||
description = "A Python utility / library to sort Python imports."
|
description = "A Python utility / library to sort Python imports."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7.0"
|
python-versions = ">=3.8.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "isort-5.11.5-py3-none-any.whl", hash = "sha256:ba1d72fb2595a01c7895a5128f9585a5cc4b6d395f1c8d514989b9a7eb2a8746"},
|
{file = "isort-5.12.0-py3-none-any.whl", hash = "sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6"},
|
||||||
{file = "isort-5.11.5.tar.gz", hash = "sha256:6be1f76a507cb2ecf16c7cf14a37e41609ca082330be4e3436a18ef74add55db"},
|
{file = "isort-5.12.0.tar.gz", hash = "sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
colors = ["colorama (>=0.4.3,<0.5.0)"]
|
colors = ["colorama (>=0.4.3)"]
|
||||||
pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"]
|
pipfile-deprecated-finder = ["pip-shims (>=0.5.2)", "pipreqs", "requirementslib"]
|
||||||
plugins = ["setuptools"]
|
plugins = ["setuptools"]
|
||||||
requirements-deprecated-finder = ["pip-api", "pipreqs"]
|
requirements-deprecated-finder = ["pip-api", "pipreqs"]
|
||||||
|
@ -2921,13 +2921,13 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "txredisapi"
|
name = "txredisapi"
|
||||||
version = "1.4.9"
|
version = "1.4.10"
|
||||||
description = "non-blocking redis client for python"
|
description = "non-blocking redis client for python"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
files = [
|
files = [
|
||||||
{file = "txredisapi-1.4.9-py3-none-any.whl", hash = "sha256:72e6ad09cc5fffe3bec2e55e5bfb74407bd357565fc212e6003f7e26ef7d8f78"},
|
{file = "txredisapi-1.4.10-py3-none-any.whl", hash = "sha256:0a6ea77f27f8cf092f907654f08302a97b48fa35f24e0ad99dfb74115f018161"},
|
||||||
{file = "txredisapi-1.4.9.tar.gz", hash = "sha256:c9607062d05e4d0b8ef84719eb76a3fe7d5ccd606a2acf024429da51d6e84559"},
|
{file = "txredisapi-1.4.10.tar.gz", hash = "sha256:7609a6af6ff4619a3189c0adfb86aeda789afba69eb59fc1e19ac0199e725395"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -2936,13 +2936,13 @@ twisted = "*"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "types-bleach"
|
name = "types-bleach"
|
||||||
version = "6.0.0.3"
|
version = "6.0.0.4"
|
||||||
description = "Typing stubs for bleach"
|
description = "Typing stubs for bleach"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
files = [
|
files = [
|
||||||
{file = "types-bleach-6.0.0.3.tar.gz", hash = "sha256:8ce7896d4f658c562768674ffcf07492c7730e128018f03edd163ff912bfadee"},
|
{file = "types-bleach-6.0.0.4.tar.gz", hash = "sha256:357b0226f65c4f20ab3b13ca8d78a6b91c78aad256d8ec168d4e90fc3303ebd4"},
|
||||||
{file = "types_bleach-6.0.0.3-py3-none-any.whl", hash = "sha256:d43eaf30a643ca824e16e2dcdb0c87ef9226237e2fa3ac4732a50cb3f32e145f"},
|
{file = "types_bleach-6.0.0.4-py3-none-any.whl", hash = "sha256:2b8767eb407c286b7f02803678732e522e04db8d56cbc9f1270bee49627eae92"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2991,13 +2991,13 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "types-pillow"
|
name = "types-pillow"
|
||||||
version = "10.0.0.1"
|
version = "10.0.0.2"
|
||||||
description = "Typing stubs for Pillow"
|
description = "Typing stubs for Pillow"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
files = [
|
files = [
|
||||||
{file = "types-Pillow-10.0.0.1.tar.gz", hash = "sha256:834a07a04504f8bf37936679bc6a5802945e7644d0727460c0c4d4307967e2a3"},
|
{file = "types-Pillow-10.0.0.2.tar.gz", hash = "sha256:fe09380ab22d412ced989a067e9ee4af719fa3a47ba1b53b232b46514a871042"},
|
||||||
{file = "types_Pillow-10.0.0.1-py3-none-any.whl", hash = "sha256:be576b67418f1cb3b93794cf7946581be1009a33a10085b3c132eb0875a819b4"},
|
{file = "types_Pillow-10.0.0.2-py3-none-any.whl", hash = "sha256:29d51a3ce6ef51fabf728a504d33b4836187ff14256b2e86996d55c91ab214b1"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
@ -89,7 +89,7 @@ manifest-path = "rust/Cargo.toml"
|
||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "matrix-synapse"
|
name = "matrix-synapse"
|
||||||
version = "1.90.0rc1"
|
version = "1.90.0"
|
||||||
description = "Homeserver for the Matrix decentralised comms protocol"
|
description = "Homeserver for the Matrix decentralised comms protocol"
|
||||||
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
|
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
|
@ -47,7 +47,7 @@ can be passed on the commandline for debugging.
|
||||||
projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
|
||||||
|
|
||||||
class Builder(object):
|
class Builder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redirect_stdout: bool = False,
|
redirect_stdout: bool = False,
|
||||||
|
|
|
@ -43,7 +43,7 @@ def main(force_colors: bool) -> None:
|
||||||
diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None)
|
diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None)
|
||||||
|
|
||||||
# Get the schema version of the local file to check against current schema on develop
|
# Get the schema version of the local file to check against current schema on develop
|
||||||
with open("synapse/storage/schema/__init__.py", "r") as file:
|
with open("synapse/storage/schema/__init__.py") as file:
|
||||||
local_schema = file.read()
|
local_schema = file.read()
|
||||||
new_locals: Dict[str, Any] = {}
|
new_locals: Dict[str, Any] = {}
|
||||||
exec(local_schema, new_locals)
|
exec(local_schema, new_locals)
|
||||||
|
|
|
@ -247,7 +247,7 @@ def main() -> None:
|
||||||
|
|
||||||
|
|
||||||
def read_args_from_config(args: argparse.Namespace) -> None:
|
def read_args_from_config(args: argparse.Namespace) -> None:
|
||||||
with open(args.config, "r") as fh:
|
with open(args.config) as fh:
|
||||||
config = yaml.safe_load(fh)
|
config = yaml.safe_load(fh)
|
||||||
|
|
||||||
if not args.server_name:
|
if not args.server_name:
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -145,7 +145,7 @@ Example usage:
|
||||||
|
|
||||||
|
|
||||||
def read_args_from_config(args: argparse.Namespace) -> None:
|
def read_args_from_config(args: argparse.Namespace) -> None:
|
||||||
with open(args.config, "r") as fh:
|
with open(args.config) as fh:
|
||||||
config = yaml.safe_load(fh)
|
config = yaml.safe_load(fh)
|
||||||
if not args.server_name:
|
if not args.server_name:
|
||||||
args.server_name = config["server_name"]
|
args.server_name = config["server_name"]
|
||||||
|
|
|
@ -25,7 +25,11 @@ from synapse.util.rust import check_rust_lib_up_to_date
|
||||||
from synapse.util.stringutils import strtobool
|
from synapse.util.stringutils import strtobool
|
||||||
|
|
||||||
# Check that we're not running on an unsupported Python version.
|
# Check that we're not running on an unsupported Python version.
|
||||||
if sys.version_info < (3, 8):
|
#
|
||||||
|
# Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the
|
||||||
|
# if-statement completely.
|
||||||
|
py_version = sys.version_info
|
||||||
|
if py_version < (3, 8):
|
||||||
print("Synapse requires Python 3.8 or above.")
|
print("Synapse requires Python 3.8 or above.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
@ -78,7 +82,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
import synapse.util
|
import synapse.util # noqa: E402
|
||||||
|
|
||||||
__version__ = synapse.util.SYNAPSE_VERSION
|
__version__ = synapse.util.SYNAPSE_VERSION
|
||||||
|
|
||||||
|
|
|
@ -123,7 +123,7 @@ BOOLEAN_COLUMNS = {
|
||||||
"redactions": ["have_censored"],
|
"redactions": ["have_censored"],
|
||||||
"room_stats_state": ["is_federatable"],
|
"room_stats_state": ["is_federatable"],
|
||||||
"rooms": ["is_public", "has_auth_chain_index"],
|
"rooms": ["is_public", "has_auth_chain_index"],
|
||||||
"users": ["shadow_banned", "approved"],
|
"users": ["shadow_banned", "approved", "locked"],
|
||||||
"un_partial_stated_event_stream": ["rejection_status_changed"],
|
"un_partial_stated_event_stream": ["rejection_status_changed"],
|
||||||
"users_who_share_rooms": ["share_private"],
|
"users_who_share_rooms": ["share_private"],
|
||||||
"per_user_experimental_features": ["enabled"],
|
"per_user_experimental_features": ["enabled"],
|
||||||
|
@ -1205,10 +1205,10 @@ class CursesProgress(Progress):
|
||||||
self.total_processed = 0
|
self.total_processed = 0
|
||||||
self.total_remaining = 0
|
self.total_remaining = 0
|
||||||
|
|
||||||
super(CursesProgress, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
def update(self, table: str, num_done: int) -> None:
|
def update(self, table: str, num_done: int) -> None:
|
||||||
super(CursesProgress, self).update(table, num_done)
|
super().update(table, num_done)
|
||||||
|
|
||||||
self.total_processed = 0
|
self.total_processed = 0
|
||||||
self.total_remaining = 0
|
self.total_remaining = 0
|
||||||
|
@ -1304,7 +1304,7 @@ class TerminalProgress(Progress):
|
||||||
"""Just prints progress to the terminal"""
|
"""Just prints progress to the terminal"""
|
||||||
|
|
||||||
def update(self, table: str, num_done: int) -> None:
|
def update(self, table: str, num_done: int) -> None:
|
||||||
super(TerminalProgress, self).update(table, num_done)
|
super().update(table, num_done)
|
||||||
|
|
||||||
data = self.tables[table]
|
data = self.tables[table]
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ class MockHomeserver(HomeServer):
|
||||||
DATASTORE_CLASS = DataStore # type: ignore [assignment]
|
DATASTORE_CLASS = DataStore # type: ignore [assignment]
|
||||||
|
|
||||||
def __init__(self, config: HomeServerConfig):
|
def __init__(self, config: HomeServerConfig):
|
||||||
super(MockHomeserver, self).__init__(
|
super().__init__(
|
||||||
hostname=config.server.server_name,
|
hostname=config.server.server_name,
|
||||||
config=config,
|
config=config,
|
||||||
reactor=reactor,
|
reactor=reactor,
|
||||||
|
|
|
@ -60,6 +60,7 @@ class Auth(Protocol):
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
allow_guest: bool = False,
|
allow_guest: bool = False,
|
||||||
allow_expired: bool = False,
|
allow_expired: bool = False,
|
||||||
|
allow_locked: bool = False,
|
||||||
) -> Requester:
|
) -> Requester:
|
||||||
"""Get a registered user's ID.
|
"""Get a registered user's ID.
|
||||||
|
|
||||||
|
|
|
@ -58,6 +58,7 @@ class InternalAuth(BaseAuth):
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
allow_guest: bool = False,
|
allow_guest: bool = False,
|
||||||
allow_expired: bool = False,
|
allow_expired: bool = False,
|
||||||
|
allow_locked: bool = False,
|
||||||
) -> Requester:
|
) -> Requester:
|
||||||
"""Get a registered user's ID.
|
"""Get a registered user's ID.
|
||||||
|
|
||||||
|
@ -79,7 +80,7 @@ class InternalAuth(BaseAuth):
|
||||||
parent_span = active_span()
|
parent_span = active_span()
|
||||||
with start_active_span("get_user_by_req"):
|
with start_active_span("get_user_by_req"):
|
||||||
requester = await self._wrapped_get_user_by_req(
|
requester = await self._wrapped_get_user_by_req(
|
||||||
request, allow_guest, allow_expired
|
request, allow_guest, allow_expired, allow_locked
|
||||||
)
|
)
|
||||||
|
|
||||||
if parent_span:
|
if parent_span:
|
||||||
|
@ -107,6 +108,7 @@ class InternalAuth(BaseAuth):
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
allow_guest: bool,
|
allow_guest: bool,
|
||||||
allow_expired: bool,
|
allow_expired: bool,
|
||||||
|
allow_locked: bool,
|
||||||
) -> Requester:
|
) -> Requester:
|
||||||
"""Helper for get_user_by_req
|
"""Helper for get_user_by_req
|
||||||
|
|
||||||
|
@ -126,6 +128,17 @@ class InternalAuth(BaseAuth):
|
||||||
access_token, allow_expired=allow_expired
|
access_token, allow_expired=allow_expired
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Deny the request if the user account is locked.
|
||||||
|
if not allow_locked and await self.store.get_user_locked_status(
|
||||||
|
requester.user.to_string()
|
||||||
|
):
|
||||||
|
raise AuthError(
|
||||||
|
401,
|
||||||
|
"User account has been locked",
|
||||||
|
errcode=Codes.USER_LOCKED,
|
||||||
|
additional_fields={"soft_logout": True},
|
||||||
|
)
|
||||||
|
|
||||||
# Deny the request if the user account has expired.
|
# Deny the request if the user account has expired.
|
||||||
# This check is only done for regular users, not appservice ones.
|
# This check is only done for regular users, not appservice ones.
|
||||||
if not allow_expired:
|
if not allow_expired:
|
||||||
|
|
|
@ -27,6 +27,7 @@ from twisted.web.http_headers import Headers
|
||||||
from synapse.api.auth.base import BaseAuth
|
from synapse.api.auth.base import BaseAuth
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
|
Codes,
|
||||||
HttpResponseException,
|
HttpResponseException,
|
||||||
InvalidClientTokenError,
|
InvalidClientTokenError,
|
||||||
OAuthInsufficientScopeError,
|
OAuthInsufficientScopeError,
|
||||||
|
@ -38,6 +39,7 @@ from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.types import Requester, UserID, create_requester
|
from synapse.types import Requester, UserID, create_requester
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||||
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -105,6 +107,14 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||||
|
|
||||||
self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
|
self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
|
||||||
|
|
||||||
|
self._clock = hs.get_clock()
|
||||||
|
self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache(
|
||||||
|
cache_name="introspection_token_cache",
|
||||||
|
clock=self._clock,
|
||||||
|
max_len=10000,
|
||||||
|
expiry_ms=5 * 60 * 1000,
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(auth_method, PrivateKeyJWTWithKid):
|
if isinstance(auth_method, PrivateKeyJWTWithKid):
|
||||||
# Use the JWK as the client secret when using the private_key_jwt method
|
# Use the JWK as the client secret when using the private_key_jwt method
|
||||||
assert self._config.jwk, "No JWK provided"
|
assert self._config.jwk, "No JWK provided"
|
||||||
|
@ -143,6 +153,20 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||||
Returns:
|
Returns:
|
||||||
The introspection response
|
The introspection response
|
||||||
"""
|
"""
|
||||||
|
# check the cache before doing a request
|
||||||
|
introspection_token = self._token_cache.get(token, None)
|
||||||
|
|
||||||
|
if introspection_token:
|
||||||
|
# check the expiration field of the token (if it exists)
|
||||||
|
exp = introspection_token.get("exp", None)
|
||||||
|
if exp:
|
||||||
|
time_now = self._clock.time()
|
||||||
|
expired = time_now > exp
|
||||||
|
if not expired:
|
||||||
|
return introspection_token
|
||||||
|
else:
|
||||||
|
return introspection_token
|
||||||
|
|
||||||
metadata = await self._issuer_metadata.get()
|
metadata = await self._issuer_metadata.get()
|
||||||
introspection_endpoint = metadata.get("introspection_endpoint")
|
introspection_endpoint = metadata.get("introspection_endpoint")
|
||||||
raw_headers: Dict[str, str] = {
|
raw_headers: Dict[str, str] = {
|
||||||
|
@ -156,7 +180,10 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||||
|
|
||||||
# Fill the body/headers with credentials
|
# Fill the body/headers with credentials
|
||||||
uri, raw_headers, body = self._client_auth.prepare(
|
uri, raw_headers, body = self._client_auth.prepare(
|
||||||
method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
|
method="POST",
|
||||||
|
uri=introspection_endpoint,
|
||||||
|
headers=raw_headers,
|
||||||
|
body=body,
|
||||||
)
|
)
|
||||||
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
|
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
|
||||||
|
|
||||||
|
@ -186,7 +213,17 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||||
"The introspection endpoint returned an invalid JSON response."
|
"The introspection endpoint returned an invalid JSON response."
|
||||||
)
|
)
|
||||||
|
|
||||||
return IntrospectionToken(**resp)
|
expiration = resp.get("exp", None)
|
||||||
|
if expiration:
|
||||||
|
if self._clock.time() > expiration:
|
||||||
|
raise InvalidClientTokenError("Token is expired.")
|
||||||
|
|
||||||
|
introspection_token = IntrospectionToken(**resp)
|
||||||
|
|
||||||
|
# add token to cache
|
||||||
|
self._token_cache[token] = introspection_token
|
||||||
|
|
||||||
|
return introspection_token
|
||||||
|
|
||||||
async def is_server_admin(self, requester: Requester) -> bool:
|
async def is_server_admin(self, requester: Requester) -> bool:
|
||||||
return "urn:synapse:admin:*" in requester.scope
|
return "urn:synapse:admin:*" in requester.scope
|
||||||
|
@ -196,6 +233,7 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
allow_guest: bool = False,
|
allow_guest: bool = False,
|
||||||
allow_expired: bool = False,
|
allow_expired: bool = False,
|
||||||
|
allow_locked: bool = False,
|
||||||
) -> Requester:
|
) -> Requester:
|
||||||
access_token = self.get_access_token_from_request(request)
|
access_token = self.get_access_token_from_request(request)
|
||||||
|
|
||||||
|
@ -205,6 +243,17 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||||
# so that we don't provision the user if they don't have enough permission:
|
# so that we don't provision the user if they don't have enough permission:
|
||||||
requester = await self.get_user_by_access_token(access_token, allow_expired)
|
requester = await self.get_user_by_access_token(access_token, allow_expired)
|
||||||
|
|
||||||
|
# Deny the request if the user account is locked.
|
||||||
|
if not allow_locked and await self.store.get_user_locked_status(
|
||||||
|
requester.user.to_string()
|
||||||
|
):
|
||||||
|
raise AuthError(
|
||||||
|
401,
|
||||||
|
"User account has been locked",
|
||||||
|
errcode=Codes.USER_LOCKED,
|
||||||
|
additional_fields={"soft_logout": True},
|
||||||
|
)
|
||||||
|
|
||||||
if not allow_guest and requester.is_guest:
|
if not allow_guest and requester.is_guest:
|
||||||
raise OAuthInsufficientScopeError([SCOPE_MATRIX_API])
|
raise OAuthInsufficientScopeError([SCOPE_MATRIX_API])
|
||||||
|
|
||||||
|
|
|
@ -18,8 +18,7 @@
|
||||||
"""Contains constants from the specification."""
|
"""Contains constants from the specification."""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
from typing import Final
|
||||||
from typing_extensions import Final
|
|
||||||
|
|
||||||
# the max size of a (canonical-json-encoded) event
|
# the max size of a (canonical-json-encoded) event
|
||||||
MAX_PDU_SIZE = 65536
|
MAX_PDU_SIZE = 65536
|
||||||
|
|
|
@ -80,6 +80,8 @@ class Codes(str, Enum):
|
||||||
WEAK_PASSWORD = "M_WEAK_PASSWORD"
|
WEAK_PASSWORD = "M_WEAK_PASSWORD"
|
||||||
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
||||||
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
||||||
|
# USER_LOCKED = "M_USER_LOCKED"
|
||||||
|
USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
|
||||||
|
|
||||||
# Part of MSC3848
|
# Part of MSC3848
|
||||||
# https://github.com/matrix-org/matrix-spec-proposals/pull/3848
|
# https://github.com/matrix-org/matrix-spec-proposals/pull/3848
|
||||||
|
|
|
@ -47,6 +47,10 @@ class CasConfig(Config):
|
||||||
required_attributes
|
required_attributes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.idp_name = cas_config.get("idp_name", "CAS")
|
||||||
|
self.idp_icon = cas_config.get("idp_icon")
|
||||||
|
self.idp_brand = cas_config.get("idp_brand")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.cas_server_url = None
|
self.cas_server_url = None
|
||||||
self.cas_service_url = None
|
self.cas_service_url = None
|
||||||
|
|
|
@ -89,8 +89,14 @@ class SAML2Config(Config):
|
||||||
"grandfathered_mxid_source_attribute", "uid"
|
"grandfathered_mxid_source_attribute", "uid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# refers to a SAML IdP entity ID
|
||||||
self.saml2_idp_entityid = saml2_config.get("idp_entityid", None)
|
self.saml2_idp_entityid = saml2_config.get("idp_entityid", None)
|
||||||
|
|
||||||
|
# IdP properties for Matrix clients
|
||||||
|
self.idp_name = saml2_config.get("idp_name", "SAML")
|
||||||
|
self.idp_icon = saml2_config.get("idp_icon")
|
||||||
|
self.idp_brand = saml2_config.get("idp_brand")
|
||||||
|
|
||||||
# user_mapping_provider may be None if the key is present but has no value
|
# user_mapping_provider may be None if the key is present but has no value
|
||||||
ump_dict = saml2_config.get("user_mapping_provider") or {}
|
ump_dict = saml2_config.get("user_mapping_provider") or {}
|
||||||
|
|
||||||
|
|
|
@ -35,3 +35,4 @@ class UserDirectoryConfig(Config):
|
||||||
self.user_directory_search_prefer_local_users = user_directory_config.get(
|
self.user_directory_search_prefer_local_users = user_directory_config.get(
|
||||||
"prefer_local_users", False
|
"prefer_local_users", False
|
||||||
)
|
)
|
||||||
|
self.show_locked_users = user_directory_config.get("show_locked_users", False)
|
||||||
|
|
|
@ -63,7 +63,7 @@ from synapse.federation.federation_base import (
|
||||||
)
|
)
|
||||||
from synapse.federation.persistence import TransactionActions
|
from synapse.federation.persistence import TransactionActions
|
||||||
from synapse.federation.units import Edu, Transaction
|
from synapse.federation.units import Edu, Transaction
|
||||||
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
|
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
|
||||||
from synapse.http.servlet import assert_params_in_dict
|
from synapse.http.servlet import assert_params_in_dict
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
|
@ -1245,7 +1245,7 @@ class FederationServer(FederationBase):
|
||||||
# while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
|
# while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
|
||||||
# lock.
|
# lock.
|
||||||
async with self._worker_lock_handler.acquire_read_write_lock(
|
async with self._worker_lock_handler.acquire_read_write_lock(
|
||||||
DELETE_ROOM_LOCK_NAME, room_id, write=False
|
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
|
||||||
):
|
):
|
||||||
await self._federation_event_handler.on_receive_pdu(
|
await self._federation_event_handler.on_receive_pdu(
|
||||||
origin, event
|
origin, event
|
||||||
|
|
|
@ -67,6 +67,7 @@ class AdminHandler:
|
||||||
"name",
|
"name",
|
||||||
"admin",
|
"admin",
|
||||||
"deactivated",
|
"deactivated",
|
||||||
|
"locked",
|
||||||
"shadow_banned",
|
"shadow_banned",
|
||||||
"creation_ts",
|
"creation_ts",
|
||||||
"appservice_id",
|
"appservice_id",
|
||||||
|
|
|
@ -76,12 +76,13 @@ class CasHandler:
|
||||||
self.idp_id = "cas"
|
self.idp_id = "cas"
|
||||||
|
|
||||||
# user-facing name of this auth provider
|
# user-facing name of this auth provider
|
||||||
self.idp_name = "CAS"
|
self.idp_name = hs.config.cas.idp_name
|
||||||
|
|
||||||
# we do not currently support brands/icons for CAS auth, but this is required by
|
# MXC URI for icon for this auth provider
|
||||||
# the SsoIdentityProvider protocol type.
|
self.idp_icon = hs.config.cas.idp_icon
|
||||||
self.idp_icon = None
|
|
||||||
self.idp_brand = None
|
# optional brand identifier for this auth provider
|
||||||
|
self.idp_brand = hs.config.cas.idp_brand
|
||||||
|
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
|
||||||
|
|
|
@ -385,6 +385,7 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
self.federation_sender = hs.get_federation_sender()
|
self.federation_sender = hs.get_federation_sender()
|
||||||
self._account_data_handler = hs.get_account_data_handler()
|
self._account_data_handler = hs.get_account_data_handler()
|
||||||
self._storage_controllers = hs.get_storage_controllers()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
self.db_pool = hs.get_datastores().main.db_pool
|
||||||
|
|
||||||
self.device_list_updater = DeviceListUpdater(hs, self)
|
self.device_list_updater = DeviceListUpdater(hs, self)
|
||||||
|
|
||||||
|
@ -656,15 +657,17 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
device_id: Optional[str],
|
device_id: Optional[str],
|
||||||
device_data: JsonDict,
|
device_data: JsonDict,
|
||||||
initial_device_display_name: Optional[str] = None,
|
initial_device_display_name: Optional[str] = None,
|
||||||
|
keys_for_device: Optional[JsonDict] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Store a dehydrated device for a user. If the user had a previous
|
"""Store a dehydrated device for a user, optionally storing the keys associated with
|
||||||
dehydrated device, it is removed.
|
it as well. If the user had a previous dehydrated device, it is removed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: the user that we are storing the device for
|
user_id: the user that we are storing the device for
|
||||||
device_id: device id supplied by client
|
device_id: device id supplied by client
|
||||||
device_data: the dehydrated device information
|
device_data: the dehydrated device information
|
||||||
initial_device_display_name: The display name to use for the device
|
initial_device_display_name: The display name to use for the device
|
||||||
|
keys_for_device: keys for the dehydrated device
|
||||||
Returns:
|
Returns:
|
||||||
device id of the dehydrated device
|
device id of the dehydrated device
|
||||||
"""
|
"""
|
||||||
|
@ -673,11 +676,16 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
device_id,
|
device_id,
|
||||||
initial_device_display_name,
|
initial_device_display_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
old_device_id = await self.store.store_dehydrated_device(
|
old_device_id = await self.store.store_dehydrated_device(
|
||||||
user_id, device_id, device_data
|
user_id, device_id, device_data, time_now, keys_for_device
|
||||||
)
|
)
|
||||||
|
|
||||||
if old_device_id is not None:
|
if old_device_id is not None:
|
||||||
await self.delete_devices(user_id, [old_device_id])
|
await self.delete_devices(user_id, [old_device_id])
|
||||||
|
|
||||||
return device_id
|
return device_id
|
||||||
|
|
||||||
async def rehydrate_device(
|
async def rehydrate_device(
|
||||||
|
|
|
@ -367,19 +367,6 @@ class DeviceMessageHandler:
|
||||||
errcode=Codes.INVALID_PARAM,
|
errcode=Codes.INVALID_PARAM,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if we have a since token, delete any to-device messages before that token
|
|
||||||
# (since we now know that the device has received them)
|
|
||||||
deleted = await self.store.delete_messages_for_device(
|
|
||||||
user_id, device_id, since_stream_id
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Deleted %d to-device messages up to %d for user_id %s device_id %s",
|
|
||||||
deleted,
|
|
||||||
since_stream_id,
|
|
||||||
user_id,
|
|
||||||
device_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
to_token = self.event_sources.get_current_token().to_device_key
|
to_token = self.event_sources.get_current_token().to_device_key
|
||||||
|
|
||||||
messages, stream_id = await self.store.get_messages_for_device(
|
messages, stream_id = await self.store.get_messages_for_device(
|
||||||
|
|
|
@ -53,7 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
|
||||||
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
|
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.handlers.directory import DirectoryHandler
|
from synapse.handlers.directory import DirectoryHandler
|
||||||
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
|
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
|
||||||
from synapse.logging import opentracing
|
from synapse.logging import opentracing
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
@ -1034,7 +1034,7 @@ class EventCreationHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._worker_lock_handler.acquire_read_write_lock(
|
async with self._worker_lock_handler.acquire_read_write_lock(
|
||||||
DELETE_ROOM_LOCK_NAME, room_id, write=False
|
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
|
||||||
):
|
):
|
||||||
return await self._create_and_send_nonmember_event_locked(
|
return await self._create_and_send_nonmember_event_locked(
|
||||||
requester=requester,
|
requester=requester,
|
||||||
|
@ -1978,7 +1978,7 @@ class EventCreationHandler:
|
||||||
|
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
async with self._worker_lock_handler.acquire_read_write_lock(
|
async with self._worker_lock_handler.acquire_read_write_lock(
|
||||||
DELETE_ROOM_LOCK_NAME, room_id, write=False
|
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
|
||||||
):
|
):
|
||||||
dummy_event_sent = await self._send_dummy_event_for_room(room_id)
|
dummy_event_sent = await self._send_dummy_event_for_room(room_id)
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ from synapse.api.errors import SynapseError
|
||||||
from synapse.api.filtering import Filter
|
from synapse.api.filtering import Filter
|
||||||
from synapse.events.utils import SerializeEventConfig
|
from synapse.events.utils import SerializeEventConfig
|
||||||
from synapse.handlers.room import ShutdownRoomResponse
|
from synapse.handlers.room import ShutdownRoomResponse
|
||||||
|
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
|
||||||
from synapse.logging.opentracing import trace
|
from synapse.logging.opentracing import trace
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.rest.admin._base import assert_user_is_admin
|
from synapse.rest.admin._base import assert_user_is_admin
|
||||||
|
@ -46,9 +47,10 @@ logger = logging.getLogger(__name__)
|
||||||
BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
|
BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
|
||||||
|
|
||||||
|
|
||||||
PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
|
# This is used to avoid purging a room several time at the same moment,
|
||||||
|
# and also paginating during a purge. Pagination can trigger backfill,
|
||||||
DELETE_ROOM_LOCK_NAME = "delete_room_lock"
|
# which would create old events locally, and would potentially clash with the room delete.
|
||||||
|
PURGE_PAGINATION_LOCK_NAME = "purge_pagination_lock"
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, auto_attribs=True)
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
|
@ -363,7 +365,7 @@ class PaginationHandler:
|
||||||
self._purges_in_progress_by_room.add(room_id)
|
self._purges_in_progress_by_room.add(room_id)
|
||||||
try:
|
try:
|
||||||
async with self._worker_locks.acquire_read_write_lock(
|
async with self._worker_locks.acquire_read_write_lock(
|
||||||
PURGE_HISTORY_LOCK_NAME, room_id, write=True
|
PURGE_PAGINATION_LOCK_NAME, room_id, write=True
|
||||||
):
|
):
|
||||||
await self._storage_controllers.purge_events.purge_history(
|
await self._storage_controllers.purge_events.purge_history(
|
||||||
room_id, token, delete_local_events
|
room_id, token, delete_local_events
|
||||||
|
@ -421,7 +423,10 @@ class PaginationHandler:
|
||||||
force: set true to skip checking for joined users.
|
force: set true to skip checking for joined users.
|
||||||
"""
|
"""
|
||||||
async with self._worker_locks.acquire_multi_read_write_lock(
|
async with self._worker_locks.acquire_multi_read_write_lock(
|
||||||
[(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
|
[
|
||||||
|
(PURGE_PAGINATION_LOCK_NAME, room_id),
|
||||||
|
(NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id),
|
||||||
|
],
|
||||||
write=True,
|
write=True,
|
||||||
):
|
):
|
||||||
# first check that we have no users in this room
|
# first check that we have no users in this room
|
||||||
|
@ -483,7 +488,7 @@ class PaginationHandler:
|
||||||
room_token = from_token.room_key
|
room_token = from_token.room_key
|
||||||
|
|
||||||
async with self._worker_locks.acquire_read_write_lock(
|
async with self._worker_locks.acquire_read_write_lock(
|
||||||
PURGE_HISTORY_LOCK_NAME, room_id, write=False
|
PURGE_PAGINATION_LOCK_NAME, room_id, write=False
|
||||||
):
|
):
|
||||||
(membership, member_event_id) = (None, None)
|
(membership, member_event_id) = (None, None)
|
||||||
if not use_admin_priviledge:
|
if not use_admin_priviledge:
|
||||||
|
@ -761,7 +766,7 @@ class PaginationHandler:
|
||||||
self._purges_in_progress_by_room.add(room_id)
|
self._purges_in_progress_by_room.add(room_id)
|
||||||
try:
|
try:
|
||||||
async with self._worker_locks.acquire_read_write_lock(
|
async with self._worker_locks.acquire_read_write_lock(
|
||||||
PURGE_HISTORY_LOCK_NAME, room_id, write=True
|
PURGE_PAGINATION_LOCK_NAME, room_id, write=True
|
||||||
):
|
):
|
||||||
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
|
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
|
||||||
self._delete_by_id[
|
self._delete_by_id[
|
||||||
|
|
|
@ -30,9 +30,9 @@ from types import TracebackType
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
|
ContextManager,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
@ -44,7 +44,6 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
from typing_extensions import ContextManager
|
|
||||||
|
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
|
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
|
||||||
|
@ -54,7 +53,10 @@ from synapse.appservice import ApplicationService
|
||||||
from synapse.events.presence_router import PresenceRouter
|
from synapse.events.presence_router import PresenceRouter
|
||||||
from synapse.logging.context import run_in_background
|
from synapse.logging.context import run_in_background
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import (
|
||||||
|
run_as_background_process,
|
||||||
|
wrap_as_background_process,
|
||||||
|
)
|
||||||
from synapse.replication.http.presence import (
|
from synapse.replication.http.presence import (
|
||||||
ReplicationBumpPresenceActiveTime,
|
ReplicationBumpPresenceActiveTime,
|
||||||
ReplicationPresenceSetState,
|
ReplicationPresenceSetState,
|
||||||
|
@ -141,6 +143,8 @@ class BasePresenceHandler(abc.ABC):
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
|
||||||
|
self._presence_enabled = hs.config.server.use_presence
|
||||||
|
|
||||||
self._federation = None
|
self._federation = None
|
||||||
if hs.should_send_federation():
|
if hs.should_send_federation():
|
||||||
self._federation = hs.get_federation_sender()
|
self._federation = hs.get_federation_sender()
|
||||||
|
@ -149,6 +153,15 @@ class BasePresenceHandler(abc.ABC):
|
||||||
|
|
||||||
self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
|
self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
|
||||||
|
|
||||||
|
self.VALID_PRESENCE: Tuple[str, ...] = (
|
||||||
|
PresenceState.ONLINE,
|
||||||
|
PresenceState.UNAVAILABLE,
|
||||||
|
PresenceState.OFFLINE,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._busy_presence_enabled:
|
||||||
|
self.VALID_PRESENCE += (PresenceState.BUSY,)
|
||||||
|
|
||||||
active_presence = self.store.take_presence_startup_info()
|
active_presence = self.store.take_presence_startup_info()
|
||||||
self.user_to_current_state = {state.user_id: state for state in active_presence}
|
self.user_to_current_state = {state.user_id: state for state in active_presence}
|
||||||
|
|
||||||
|
@ -395,8 +408,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
||||||
|
|
||||||
self._presence_writer_instance = hs.config.worker.writers.presence[0]
|
self._presence_writer_instance = hs.config.worker.writers.presence[0]
|
||||||
|
|
||||||
self._presence_enabled = hs.config.server.use_presence
|
|
||||||
|
|
||||||
# Route presence EDUs to the right worker
|
# Route presence EDUs to the right worker
|
||||||
hs.get_federation_registry().register_instances_for_edu(
|
hs.get_federation_registry().register_instances_for_edu(
|
||||||
EduTypes.PRESENCE,
|
EduTypes.PRESENCE,
|
||||||
|
@ -421,8 +432,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
||||||
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
|
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
|
||||||
)
|
)
|
||||||
|
|
||||||
self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
|
|
||||||
|
|
||||||
hs.get_reactor().addSystemEventTrigger(
|
hs.get_reactor().addSystemEventTrigger(
|
||||||
"before",
|
"before",
|
||||||
"shutdown",
|
"shutdown",
|
||||||
|
@ -490,7 +499,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
||||||
# what the spec wants: see comment in the BasePresenceHandler version
|
# what the spec wants: see comment in the BasePresenceHandler version
|
||||||
# of this function.
|
# of this function.
|
||||||
await self.set_state(
|
await self.set_state(
|
||||||
UserID.from_string(user_id), {"presence": presence_state}, True
|
UserID.from_string(user_id),
|
||||||
|
{"presence": presence_state},
|
||||||
|
ignore_status_msg=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
|
curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
|
||||||
|
@ -601,22 +612,13 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
||||||
"""
|
"""
|
||||||
presence = state["presence"]
|
presence = state["presence"]
|
||||||
|
|
||||||
valid_presence = (
|
if presence not in self.VALID_PRESENCE:
|
||||||
PresenceState.ONLINE,
|
|
||||||
PresenceState.UNAVAILABLE,
|
|
||||||
PresenceState.OFFLINE,
|
|
||||||
PresenceState.BUSY,
|
|
||||||
)
|
|
||||||
|
|
||||||
if presence not in valid_presence or (
|
|
||||||
presence == PresenceState.BUSY and not self._busy_presence_enabled
|
|
||||||
):
|
|
||||||
raise SynapseError(400, "Invalid presence state")
|
raise SynapseError(400, "Invalid presence state")
|
||||||
|
|
||||||
user_id = target_user.to_string()
|
user_id = target_user.to_string()
|
||||||
|
|
||||||
# If presence is disabled, no-op
|
# If presence is disabled, no-op
|
||||||
if not self.hs.config.server.use_presence:
|
if not self._presence_enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Proxy request to instance that writes presence
|
# Proxy request to instance that writes presence
|
||||||
|
@ -633,7 +635,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
||||||
with the app.
|
with the app.
|
||||||
"""
|
"""
|
||||||
# If presence is disabled, no-op
|
# If presence is disabled, no-op
|
||||||
if not self.hs.config.server.use_presence:
|
if not self._presence_enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Proxy request to instance that writes presence
|
# Proxy request to instance that writes presence
|
||||||
|
@ -649,7 +651,6 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.wheel_timer: WheelTimer[str] = WheelTimer()
|
self.wheel_timer: WheelTimer[str] = WheelTimer()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self._presence_enabled = hs.config.server.use_presence
|
|
||||||
|
|
||||||
federation_registry = hs.get_federation_registry()
|
federation_registry = hs.get_federation_registry()
|
||||||
|
|
||||||
|
@ -700,8 +701,6 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
self._on_shutdown,
|
self._on_shutdown,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._next_serial = 1
|
|
||||||
|
|
||||||
# Keeps track of the number of *ongoing* syncs on this process. While
|
# Keeps track of the number of *ongoing* syncs on this process. While
|
||||||
# this is non zero a user will never go offline.
|
# this is non zero a user will never go offline.
|
||||||
self.user_to_num_current_syncs: Dict[str, int] = {}
|
self.user_to_num_current_syncs: Dict[str, int] = {}
|
||||||
|
@ -723,21 +722,16 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
# Start a LoopingCall in 30s that fires every 5s.
|
# Start a LoopingCall in 30s that fires every 5s.
|
||||||
# The initial delay is to allow disconnected clients a chance to
|
# The initial delay is to allow disconnected clients a chance to
|
||||||
# reconnect before we treat them as offline.
|
# reconnect before we treat them as offline.
|
||||||
def run_timeout_handler() -> Awaitable[None]:
|
|
||||||
return run_as_background_process(
|
|
||||||
"handle_presence_timeouts", self._handle_timeouts
|
|
||||||
)
|
|
||||||
|
|
||||||
self.clock.call_later(
|
self.clock.call_later(
|
||||||
30, self.clock.looping_call, run_timeout_handler, 5000
|
30, self.clock.looping_call, self._handle_timeouts, 5000
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_persister() -> Awaitable[None]:
|
self.clock.call_later(
|
||||||
return run_as_background_process(
|
60,
|
||||||
"persist_presence_changes", self._persist_unpersisted_changes
|
self.clock.looping_call,
|
||||||
)
|
self._persist_unpersisted_changes,
|
||||||
|
60 * 1000,
|
||||||
self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
|
)
|
||||||
|
|
||||||
LaterGauge(
|
LaterGauge(
|
||||||
"synapse_handlers_presence_wheel_timer_size",
|
"synapse_handlers_presence_wheel_timer_size",
|
||||||
|
@ -783,6 +777,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
)
|
)
|
||||||
logger.info("Finished _on_shutdown")
|
logger.info("Finished _on_shutdown")
|
||||||
|
|
||||||
|
@wrap_as_background_process("persist_presence_changes")
|
||||||
async def _persist_unpersisted_changes(self) -> None:
|
async def _persist_unpersisted_changes(self) -> None:
|
||||||
"""We periodically persist the unpersisted changes, as otherwise they
|
"""We periodically persist the unpersisted changes, as otherwise they
|
||||||
may stack up and slow down shutdown times.
|
may stack up and slow down shutdown times.
|
||||||
|
@ -898,6 +893,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
states, [destination]
|
states, [destination]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@wrap_as_background_process("handle_presence_timeouts")
|
||||||
async def _handle_timeouts(self) -> None:
|
async def _handle_timeouts(self) -> None:
|
||||||
"""Checks the presence of users that have timed out and updates as
|
"""Checks the presence of users that have timed out and updates as
|
||||||
appropriate.
|
appropriate.
|
||||||
|
@ -955,7 +951,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
with the app.
|
with the app.
|
||||||
"""
|
"""
|
||||||
# If presence is disabled, no-op
|
# If presence is disabled, no-op
|
||||||
if not self.hs.config.server.use_presence:
|
if not self._presence_enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
@ -990,56 +986,51 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
client that is being used by a user.
|
client that is being used by a user.
|
||||||
presence_state: The presence state indicated in the sync request
|
presence_state: The presence state indicated in the sync request
|
||||||
"""
|
"""
|
||||||
# Override if it should affect the user's presence, if presence is
|
if not affect_presence or not self._presence_enabled:
|
||||||
# disabled.
|
return _NullContextManager()
|
||||||
if not self.hs.config.server.use_presence:
|
|
||||||
affect_presence = False
|
|
||||||
|
|
||||||
if affect_presence:
|
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
|
||||||
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
|
self.user_to_num_current_syncs[user_id] = curr_sync + 1
|
||||||
self.user_to_num_current_syncs[user_id] = curr_sync + 1
|
|
||||||
|
|
||||||
|
prev_state = await self.current_state_for_user(user_id)
|
||||||
|
|
||||||
|
# If they're busy then they don't stop being busy just by syncing,
|
||||||
|
# so just update the last sync time.
|
||||||
|
if prev_state.state != PresenceState.BUSY:
|
||||||
|
# XXX: We set_state separately here and just update the last_active_ts above
|
||||||
|
# This keeps the logic as similar as possible between the worker and single
|
||||||
|
# process modes. Using set_state will actually cause last_active_ts to be
|
||||||
|
# updated always, which is not what the spec calls for, but synapse has done
|
||||||
|
# this for... forever, I think.
|
||||||
|
await self.set_state(
|
||||||
|
UserID.from_string(user_id),
|
||||||
|
{"presence": presence_state},
|
||||||
|
ignore_status_msg=True,
|
||||||
|
)
|
||||||
|
# Retrieve the new state for the logic below. This should come from the
|
||||||
|
# in-memory cache.
|
||||||
prev_state = await self.current_state_for_user(user_id)
|
prev_state = await self.current_state_for_user(user_id)
|
||||||
|
|
||||||
# If they're busy then they don't stop being busy just by syncing,
|
# To keep the single process behaviour consistent with worker mode, run the
|
||||||
# so just update the last sync time.
|
# same logic as `update_external_syncs_row`, even though it looks weird.
|
||||||
if prev_state.state != PresenceState.BUSY:
|
if prev_state.state == PresenceState.OFFLINE:
|
||||||
# XXX: We set_state separately here and just update the last_active_ts above
|
await self._update_states(
|
||||||
# This keeps the logic as similar as possible between the worker and single
|
[
|
||||||
# process modes. Using set_state will actually cause last_active_ts to be
|
prev_state.copy_and_replace(
|
||||||
# updated always, which is not what the spec calls for, but synapse has done
|
state=PresenceState.ONLINE,
|
||||||
# this for... forever, I think.
|
last_active_ts=self.clock.time_msec(),
|
||||||
await self.set_state(
|
last_user_sync_ts=self.clock.time_msec(),
|
||||||
UserID.from_string(user_id), {"presence": presence_state}, True
|
)
|
||||||
)
|
]
|
||||||
# Retrieve the new state for the logic below. This should come from the
|
)
|
||||||
# in-memory cache.
|
# otherwise, set the new presence state & update the last sync time,
|
||||||
prev_state = await self.current_state_for_user(user_id)
|
# but don't update last_active_ts as this isn't an indication that
|
||||||
|
# they've been active (even though it's probably been updated by
|
||||||
# To keep the single process behaviour consistent with worker mode, run the
|
# set_state above)
|
||||||
# same logic as `update_external_syncs_row`, even though it looks weird.
|
else:
|
||||||
if prev_state.state == PresenceState.OFFLINE:
|
await self._update_states(
|
||||||
await self._update_states(
|
[prev_state.copy_and_replace(last_user_sync_ts=self.clock.time_msec())]
|
||||||
[
|
)
|
||||||
prev_state.copy_and_replace(
|
|
||||||
state=PresenceState.ONLINE,
|
|
||||||
last_active_ts=self.clock.time_msec(),
|
|
||||||
last_user_sync_ts=self.clock.time_msec(),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# otherwise, set the new presence state & update the last sync time,
|
|
||||||
# but don't update last_active_ts as this isn't an indication that
|
|
||||||
# they've been active (even though it's probably been updated by
|
|
||||||
# set_state above)
|
|
||||||
else:
|
|
||||||
await self._update_states(
|
|
||||||
[
|
|
||||||
prev_state.copy_and_replace(
|
|
||||||
last_user_sync_ts=self.clock.time_msec()
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _end() -> None:
|
async def _end() -> None:
|
||||||
try:
|
try:
|
||||||
|
@ -1061,8 +1052,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
if affect_presence:
|
run_in_background(_end)
|
||||||
run_in_background(_end)
|
|
||||||
|
|
||||||
return _user_syncing()
|
return _user_syncing()
|
||||||
|
|
||||||
|
@ -1229,20 +1219,11 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
status_msg = state.get("status_msg", None)
|
status_msg = state.get("status_msg", None)
|
||||||
presence = state["presence"]
|
presence = state["presence"]
|
||||||
|
|
||||||
valid_presence = (
|
if presence not in self.VALID_PRESENCE:
|
||||||
PresenceState.ONLINE,
|
|
||||||
PresenceState.UNAVAILABLE,
|
|
||||||
PresenceState.OFFLINE,
|
|
||||||
PresenceState.BUSY,
|
|
||||||
)
|
|
||||||
|
|
||||||
if presence not in valid_presence or (
|
|
||||||
presence == PresenceState.BUSY and not self._busy_presence_enabled
|
|
||||||
):
|
|
||||||
raise SynapseError(400, "Invalid presence state")
|
raise SynapseError(400, "Invalid presence state")
|
||||||
|
|
||||||
# If presence is disabled, no-op
|
# If presence is disabled, no-op
|
||||||
if not self.hs.config.server.use_presence:
|
if not self._presence_enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
user_id = target_user.to_string()
|
user_id = target_user.to_string()
|
||||||
|
|
|
@ -39,7 +39,7 @@ from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
|
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
|
||||||
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
|
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
|
||||||
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
|
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
|
||||||
from synapse.logging import opentracing
|
from synapse.logging import opentracing
|
||||||
from synapse.metrics import event_processing_positions
|
from synapse.metrics import event_processing_positions
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
@ -629,7 +629,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
async with self.member_linearizer.queue(key):
|
async with self.member_linearizer.queue(key):
|
||||||
async with self._worker_lock_handler.acquire_read_write_lock(
|
async with self._worker_lock_handler.acquire_read_write_lock(
|
||||||
DELETE_ROOM_LOCK_NAME, room_id, write=False
|
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
|
||||||
):
|
):
|
||||||
diff = self.clock.time_msec() - then
|
diff = self.clock.time_msec() - then
|
||||||
|
|
||||||
|
|
|
@ -74,12 +74,13 @@ class SamlHandler:
|
||||||
self.idp_id = "saml"
|
self.idp_id = "saml"
|
||||||
|
|
||||||
# user-facing name of this auth provider
|
# user-facing name of this auth provider
|
||||||
self.idp_name = "SAML"
|
self.idp_name = hs.config.saml2.idp_name
|
||||||
|
|
||||||
# we do not currently support icons/brands for SAML auth, but this is required by
|
# MXC URI for icon for this auth provider
|
||||||
# the SsoIdentityProvider protocol type.
|
self.idp_icon = hs.config.saml2.idp_icon
|
||||||
self.idp_icon = None
|
|
||||||
self.idp_brand = None
|
# optional brand identifier for this auth provider
|
||||||
|
self.idp_brand = hs.config.saml2.idp_brand
|
||||||
|
|
||||||
# a map from saml session id to Saml2SessionData object
|
# a map from saml session id to Saml2SessionData object
|
||||||
self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
|
self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
|
||||||
|
|
|
@ -24,13 +24,14 @@ from typing import (
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
|
NoReturn,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
)
|
)
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import NoReturn, Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
from twisted.web.iweb import IRequest
|
from twisted.web.iweb import IRequest
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
@ -791,7 +792,7 @@ class SsoHandler:
|
||||||
|
|
||||||
if code != 200:
|
if code != 200:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"GET request to download sso avatar image returned {}".format(code)
|
f"GET request to download sso avatar image returned {code}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# upload name includes hash of the image file's content so that we can
|
# upload name includes hash of the image file's content so that we can
|
||||||
|
|
|
@ -14,9 +14,15 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
from typing_extensions import Counter as CounterType
|
Any,
|
||||||
|
Counter as CounterType,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||||
from synapse.metrics import event_processing_positions
|
from synapse.metrics import event_processing_positions
|
||||||
|
|
|
@ -387,16 +387,16 @@ class SyncHandler:
|
||||||
from_token=since_token,
|
from_token=since_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if nothing has happened in any of the users' rooms since /sync was called,
|
# if nothing has happened in any of the users' rooms since /sync was called,
|
||||||
# the resultant next_batch will be the same as since_token (since the result
|
# the resultant next_batch will be the same as since_token (since the result
|
||||||
# is generated when wait_for_events is first called, and not regenerated
|
# is generated when wait_for_events is first called, and not regenerated
|
||||||
# when wait_for_events times out).
|
# when wait_for_events times out).
|
||||||
#
|
#
|
||||||
# If that happens, we mustn't cache it, so that when the client comes back
|
# If that happens, we mustn't cache it, so that when the client comes back
|
||||||
# with the same cache token, we don't immediately return the same empty
|
# with the same cache token, we don't immediately return the same empty
|
||||||
# result, causing a tightloop. (#8518)
|
# result, causing a tightloop. (#8518)
|
||||||
if result.next_batch == since_token:
|
if result.next_batch == since_token:
|
||||||
cache_context.should_cache = False
|
cache_context.should_cache = False
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
if sync_config.filter_collection.lazy_load_members():
|
if sync_config.filter_collection.lazy_load_members():
|
||||||
|
@ -1442,11 +1442,9 @@ class SyncHandler:
|
||||||
|
|
||||||
# Now we have our list of joined room IDs, exclude as configured and freeze
|
# Now we have our list of joined room IDs, exclude as configured and freeze
|
||||||
joined_room_ids = frozenset(
|
joined_room_ids = frozenset(
|
||||||
(
|
room_id
|
||||||
room_id
|
for room_id in mutable_joined_room_ids
|
||||||
for room_id in mutable_joined_room_ids
|
if room_id not in mutable_rooms_to_exclude
|
||||||
if room_id not in mutable_rooms_to_exclude
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
|
@ -94,6 +94,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.update_user_directory = hs.config.worker.should_update_user_directory
|
self.update_user_directory = hs.config.worker.should_update_user_directory
|
||||||
self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
|
self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
|
||||||
|
self.show_locked_users = hs.config.userdirectory.show_locked_users
|
||||||
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
|
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
|
||||||
self._hs = hs
|
self._hs = hs
|
||||||
|
|
||||||
|
@ -144,7 +145,9 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
results = await self.store.search_user_dir(user_id, search_term, limit)
|
results = await self.store.search_user_dir(
|
||||||
|
user_id, search_term, limit, self.show_locked_users
|
||||||
|
)
|
||||||
|
|
||||||
# Remove any spammy users from the results.
|
# Remove any spammy users from the results.
|
||||||
non_spammy_users = []
|
non_spammy_users = []
|
||||||
|
|
|
@ -42,7 +42,11 @@ if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
|
||||||
DELETE_ROOM_LOCK_NAME = "delete_room_lock"
|
# This lock is used to avoid creating an event while we are purging the room.
|
||||||
|
# We take a read lock when creating an event, and a write one when purging a room.
|
||||||
|
# This is because it is fine to create several events concurrently, since referenced events
|
||||||
|
# will not disappear under our feet as long as we don't delete the room.
|
||||||
|
NEW_EVENT_DURING_PURGE_LOCK_NAME = "new_event_during_purge_lock"
|
||||||
|
|
||||||
|
|
||||||
class WorkerLocksHandler:
|
class WorkerLocksHandler:
|
||||||
|
|
|
@ -18,10 +18,9 @@ import traceback
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from ipaddress import IPv4Address, IPv6Address, ip_address
|
from ipaddress import IPv4Address, IPv6Address, ip_address
|
||||||
from math import floor
|
from math import floor
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Deque, Optional
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import Deque
|
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
from twisted.application.internet import ClientService
|
from twisted.application.internet import ClientService
|
||||||
|
|
|
@ -31,7 +31,7 @@ from typing import (
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import jinja2
|
import jinja2
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import Concatenate, ParamSpec
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.interfaces import IDelayedCall
|
from twisted.internet.interfaces import IDelayedCall
|
||||||
|
@ -885,7 +885,7 @@ class ModuleApi:
|
||||||
def run_db_interaction(
|
def run_db_interaction(
|
||||||
self,
|
self,
|
||||||
desc: str,
|
desc: str,
|
||||||
func: Callable[P, T],
|
func: Callable[Concatenate[LoggingTransaction, P], T],
|
||||||
*args: P.args,
|
*args: P.args,
|
||||||
**kwargs: P.kwargs,
|
**kwargs: P.kwargs,
|
||||||
) -> "defer.Deferred[T]":
|
) -> "defer.Deferred[T]":
|
||||||
|
|
|
@ -426,9 +426,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
generally discouraged as it doesn't support internationalization.
|
generally discouraged as it doesn't support internationalization.
|
||||||
"""
|
"""
|
||||||
for callback in self._check_event_for_spam_callbacks:
|
for callback in self._check_event_for_spam_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
res = await delay_cancellation(callback(event))
|
res = await delay_cancellation(callback(event))
|
||||||
if res is False or res == self.NOT_SPAM:
|
if res is False or res == self.NOT_SPAM:
|
||||||
# This spam-checker accepts the event.
|
# This spam-checker accepts the event.
|
||||||
|
@ -481,9 +479,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
True if the event should be silently dropped
|
True if the event should be silently dropped
|
||||||
"""
|
"""
|
||||||
for callback in self._should_drop_federated_event_callbacks:
|
for callback in self._should_drop_federated_event_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
res: Union[bool, str] = await delay_cancellation(callback(event))
|
res: Union[bool, str] = await delay_cancellation(callback(event))
|
||||||
if res:
|
if res:
|
||||||
return res
|
return res
|
||||||
|
@ -505,9 +501,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise.
|
NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise.
|
||||||
"""
|
"""
|
||||||
for callback in self._user_may_join_room_callbacks:
|
for callback in self._user_may_join_room_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
res = await delay_cancellation(callback(user_id, room_id, is_invited))
|
res = await delay_cancellation(callback(user_id, room_id, is_invited))
|
||||||
# Normalize return values to `Codes` or `"NOT_SPAM"`.
|
# Normalize return values to `Codes` or `"NOT_SPAM"`.
|
||||||
if res is True or res is self.NOT_SPAM:
|
if res is True or res is self.NOT_SPAM:
|
||||||
|
@ -546,9 +540,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
NOT_SPAM if the operation is permitted, Codes otherwise.
|
NOT_SPAM if the operation is permitted, Codes otherwise.
|
||||||
"""
|
"""
|
||||||
for callback in self._user_may_invite_callbacks:
|
for callback in self._user_may_invite_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
res = await delay_cancellation(
|
res = await delay_cancellation(
|
||||||
callback(inviter_userid, invitee_userid, room_id)
|
callback(inviter_userid, invitee_userid, room_id)
|
||||||
)
|
)
|
||||||
|
@ -593,9 +585,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
NOT_SPAM if the operation is permitted, Codes otherwise.
|
NOT_SPAM if the operation is permitted, Codes otherwise.
|
||||||
"""
|
"""
|
||||||
for callback in self._user_may_send_3pid_invite_callbacks:
|
for callback in self._user_may_send_3pid_invite_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
res = await delay_cancellation(
|
res = await delay_cancellation(
|
||||||
callback(inviter_userid, medium, address, room_id)
|
callback(inviter_userid, medium, address, room_id)
|
||||||
)
|
)
|
||||||
|
@ -630,9 +620,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
userid: The ID of the user attempting to create a room
|
userid: The ID of the user attempting to create a room
|
||||||
"""
|
"""
|
||||||
for callback in self._user_may_create_room_callbacks:
|
for callback in self._user_may_create_room_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
res = await delay_cancellation(callback(userid))
|
res = await delay_cancellation(callback(userid))
|
||||||
if res is True or res is self.NOT_SPAM:
|
if res is True or res is self.NOT_SPAM:
|
||||||
continue
|
continue
|
||||||
|
@ -666,9 +654,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
for callback in self._user_may_create_room_alias_callbacks:
|
for callback in self._user_may_create_room_alias_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
res = await delay_cancellation(callback(userid, room_alias))
|
res = await delay_cancellation(callback(userid, room_alias))
|
||||||
if res is True or res is self.NOT_SPAM:
|
if res is True or res is self.NOT_SPAM:
|
||||||
continue
|
continue
|
||||||
|
@ -701,9 +687,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
room_id: The ID of the room that would be published
|
room_id: The ID of the room that would be published
|
||||||
"""
|
"""
|
||||||
for callback in self._user_may_publish_room_callbacks:
|
for callback in self._user_may_publish_room_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
res = await delay_cancellation(callback(userid, room_id))
|
res = await delay_cancellation(callback(userid, room_id))
|
||||||
if res is True or res is self.NOT_SPAM:
|
if res is True or res is self.NOT_SPAM:
|
||||||
continue
|
continue
|
||||||
|
@ -742,9 +726,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
True if the user is spammy.
|
True if the user is spammy.
|
||||||
"""
|
"""
|
||||||
for callback in self._check_username_for_spam_callbacks:
|
for callback in self._check_username_for_spam_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
# Make a copy of the user profile object to ensure the spam checker cannot
|
# Make a copy of the user profile object to ensure the spam checker cannot
|
||||||
# modify it.
|
# modify it.
|
||||||
res = await delay_cancellation(callback(user_profile.copy()))
|
res = await delay_cancellation(callback(user_profile.copy()))
|
||||||
|
@ -776,9 +758,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for callback in self._check_registration_for_spam_callbacks:
|
for callback in self._check_registration_for_spam_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
behaviour = await delay_cancellation(
|
behaviour = await delay_cancellation(
|
||||||
callback(email_threepid, username, request_info, auth_provider_id)
|
callback(email_threepid, username, request_info, auth_provider_id)
|
||||||
)
|
)
|
||||||
|
@ -820,9 +800,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for callback in self._check_media_file_for_spam_callbacks:
|
for callback in self._check_media_file_for_spam_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
res = await delay_cancellation(callback(file_wrapper, file_info))
|
res = await delay_cancellation(callback(file_wrapper, file_info))
|
||||||
# Normalize return values to `Codes` or `"NOT_SPAM"`.
|
# Normalize return values to `Codes` or `"NOT_SPAM"`.
|
||||||
if res is False or res is self.NOT_SPAM:
|
if res is False or res is self.NOT_SPAM:
|
||||||
|
@ -869,9 +847,7 @@ class SpamCheckerModuleApiCallbacks:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for callback in self._check_login_for_spam_callbacks:
|
for callback in self._check_login_for_spam_callbacks:
|
||||||
with Measure(
|
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
|
||||||
):
|
|
||||||
res = await delay_cancellation(
|
res = await delay_cancellation(
|
||||||
callback(
|
callback(
|
||||||
user_id,
|
user_id,
|
||||||
|
|
|
@ -17,6 +17,7 @@ from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
|
Deque,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
Iterator,
|
Iterator,
|
||||||
|
@ -29,7 +30,6 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
from typing_extensions import Deque
|
|
||||||
|
|
||||||
from twisted.internet.protocol import ReconnectingClientFactory
|
from twisted.internet.protocol import ReconnectingClientFactory
|
||||||
|
|
||||||
|
|
|
@ -280,6 +280,17 @@ class UserRestServletV2(RestServlet):
|
||||||
HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
|
HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
lock = body.get("locked", False)
|
||||||
|
if not isinstance(lock, bool):
|
||||||
|
raise SynapseError(
|
||||||
|
HTTPStatus.BAD_REQUEST, "'locked' parameter is not of type boolean"
|
||||||
|
)
|
||||||
|
|
||||||
|
if deactivate and lock:
|
||||||
|
raise SynapseError(
|
||||||
|
HTTPStatus.BAD_REQUEST, "An user can't be deactivated and locked"
|
||||||
|
)
|
||||||
|
|
||||||
approved: Optional[bool] = None
|
approved: Optional[bool] = None
|
||||||
if "approved" in body and self._msc3866_enabled:
|
if "approved" in body and self._msc3866_enabled:
|
||||||
approved = body["approved"]
|
approved = body["approved"]
|
||||||
|
@ -397,6 +408,12 @@ class UserRestServletV2(RestServlet):
|
||||||
target_user.to_string()
|
target_user.to_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "locked" in body:
|
||||||
|
if lock and not user["locked"]:
|
||||||
|
await self.store.set_user_locked_status(user_id, True)
|
||||||
|
elif not lock and user["locked"]:
|
||||||
|
await self.store.set_user_locked_status(user_id, False)
|
||||||
|
|
||||||
if "user_type" in body:
|
if "user_type" in body:
|
||||||
await self.store.set_user_type(target_user, user_type)
|
await self.store.set_user_type(target_user, user_type)
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,6 @@ from synapse.http.servlet import (
|
||||||
parse_integer,
|
parse_integer,
|
||||||
)
|
)
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
|
|
||||||
from synapse.rest.client._base import client_patterns, interactive_auth_handler
|
from synapse.rest.client._base import client_patterns, interactive_auth_handler
|
||||||
from synapse.rest.client.models import AuthenticationData
|
from synapse.rest.client.models import AuthenticationData
|
||||||
from synapse.rest.models import RequestBodyModel
|
from synapse.rest.models import RequestBodyModel
|
||||||
|
@ -480,13 +479,6 @@ class DehydratedDeviceV2Servlet(RestServlet):
|
||||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||||
self.device_handler = handler
|
self.device_handler = handler
|
||||||
|
|
||||||
if hs.config.worker.worker_app is None:
|
|
||||||
# if main process
|
|
||||||
self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
|
|
||||||
else:
|
|
||||||
# then a worker
|
|
||||||
self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
|
|
||||||
|
|
||||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
@ -549,18 +541,12 @@ class DehydratedDeviceV2Servlet(RestServlet):
|
||||||
"Device key(s) not found, these must be provided.",
|
"Device key(s) not found, these must be provided.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Those two operations, creating a device and storing the
|
|
||||||
# device's keys should be atomic.
|
|
||||||
device_id = await self.device_handler.store_dehydrated_device(
|
device_id = await self.device_handler.store_dehydrated_device(
|
||||||
requester.user.to_string(),
|
requester.user.to_string(),
|
||||||
submission.device_id,
|
submission.device_id,
|
||||||
submission.device_data.dict(),
|
submission.device_data.dict(),
|
||||||
submission.initial_device_display_name,
|
submission.initial_device_display_name,
|
||||||
)
|
device_info,
|
||||||
|
|
||||||
# TODO: Do we need to do something with the result here?
|
|
||||||
await self.key_uploader(
|
|
||||||
user_id=user_id, device_id=submission.device_id, keys=submission.dict()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, {"device_id": device_id}
|
return 200, {"device_id": device_id}
|
||||||
|
|
|
@ -40,7 +40,9 @@ class LogoutRestServlet(RestServlet):
|
||||||
self._device_handler = handler
|
self._device_handler = handler
|
||||||
|
|
||||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_expired=True)
|
requester = await self.auth.get_user_by_req(
|
||||||
|
request, allow_expired=True, allow_locked=True
|
||||||
|
)
|
||||||
|
|
||||||
if requester.device_id is None:
|
if requester.device_id is None:
|
||||||
# The access token wasn't associated with a device.
|
# The access token wasn't associated with a device.
|
||||||
|
@ -67,7 +69,9 @@ class LogoutAllRestServlet(RestServlet):
|
||||||
self._device_handler = handler
|
self._device_handler = handler
|
||||||
|
|
||||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_expired=True)
|
requester = await self.auth.get_user_by_req(
|
||||||
|
request, allow_expired=True, allow_locked=True
|
||||||
|
)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
# first delete all of the user's devices
|
# first delete all of the user's devices
|
||||||
|
|
|
@ -32,6 +32,7 @@ from synapse.push.rulekinds import PRIORITY_CLASS_MAP
|
||||||
from synapse.rest.client._base import client_patterns
|
from synapse.rest.client._base import client_patterns
|
||||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util.async_helpers import Linearizer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -53,26 +54,32 @@ class PushRuleRestServlet(RestServlet):
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self._is_worker = hs.config.worker.worker_app is not None
|
self._is_worker = hs.config.worker.worker_app is not None
|
||||||
self._push_rules_handler = hs.get_push_rules_handler()
|
self._push_rules_handler = hs.get_push_rules_handler()
|
||||||
|
self._push_rule_linearizer = Linearizer(name="push_rules")
|
||||||
|
|
||||||
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
|
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
|
||||||
if self._is_worker:
|
if self._is_worker:
|
||||||
raise Exception("Cannot handle PUT /push_rules on worker")
|
raise Exception("Cannot handle PUT /push_rules on worker")
|
||||||
|
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
async with self._push_rule_linearizer.queue(user_id):
|
||||||
|
return await self.handle_put(request, path, user_id)
|
||||||
|
|
||||||
|
async def handle_put(
|
||||||
|
self, request: SynapseRequest, path: str, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
spec = _rule_spec_from_path(path.split("/"))
|
spec = _rule_spec_from_path(path.split("/"))
|
||||||
try:
|
try:
|
||||||
priority_class = _priority_class_from_spec(spec)
|
priority_class = _priority_class_from_spec(spec)
|
||||||
except InvalidRuleException as e:
|
except InvalidRuleException as e:
|
||||||
raise SynapseError(400, str(e))
|
raise SynapseError(400, str(e))
|
||||||
|
|
||||||
requester = await self.auth.get_user_by_req(request)
|
|
||||||
|
|
||||||
if "/" in spec.rule_id or "\\" in spec.rule_id:
|
if "/" in spec.rule_id or "\\" in spec.rule_id:
|
||||||
raise SynapseError(400, "rule_id may not contain slashes")
|
raise SynapseError(400, "rule_id may not contain slashes")
|
||||||
|
|
||||||
content = parse_json_value_from_request(request)
|
content = parse_json_value_from_request(request)
|
||||||
|
|
||||||
user_id = requester.user.to_string()
|
|
||||||
|
|
||||||
if spec.attr:
|
if spec.attr:
|
||||||
try:
|
try:
|
||||||
await self._push_rules_handler.set_rule_attr(user_id, spec, content)
|
await self._push_rules_handler.set_rule_attr(user_id, spec, content)
|
||||||
|
@ -126,11 +133,20 @@ class PushRuleRestServlet(RestServlet):
|
||||||
if self._is_worker:
|
if self._is_worker:
|
||||||
raise Exception("Cannot handle DELETE /push_rules on worker")
|
raise Exception("Cannot handle DELETE /push_rules on worker")
|
||||||
|
|
||||||
spec = _rule_spec_from_path(path.split("/"))
|
|
||||||
|
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
async with self._push_rule_linearizer.queue(user_id):
|
||||||
|
return await self.handle_delete(request, path, user_id)
|
||||||
|
|
||||||
|
async def handle_delete(
|
||||||
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
path: str,
|
||||||
|
user_id: str,
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
spec = _rule_spec_from_path(path.split("/"))
|
||||||
|
|
||||||
namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
|
namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import Codes, ShadowBanError, SynapseError
|
from synapse.api.errors import Codes, ShadowBanError, SynapseError
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
|
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet,
|
RestServlet,
|
||||||
|
@ -81,7 +81,7 @@ class RoomUpgradeRestServlet(RestServlet):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._worker_lock_handler.acquire_read_write_lock(
|
async with self._worker_lock_handler.acquire_read_write_lock(
|
||||||
DELETE_ROOM_LOCK_NAME, room_id, write=False
|
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
|
||||||
):
|
):
|
||||||
new_room_id = await self._room_creation_handler.upgrade_room(
|
new_room_id = await self._room_creation_handler.upgrade_room(
|
||||||
requester, room_id, new_version
|
requester, room_id, new_version
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
|
||||||
|
|
||||||
from signedjson.sign import sign_json
|
from signedjson.sign import sign_json
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ from synapse.http.servlet import (
|
||||||
parse_integer,
|
parse_integer,
|
||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
)
|
)
|
||||||
|
from synapse.storage.keys import FetchKeyResultForRemote
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.async_helpers import yieldable_gather_results
|
from synapse.util.async_helpers import yieldable_gather_results
|
||||||
|
@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
logger.info("Handling query for keys %r", query)
|
logger.info("Handling query for keys %r", query)
|
||||||
|
|
||||||
store_queries = []
|
server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
|
||||||
for server_name, key_ids in query.items():
|
for server_name, key_ids in query.items():
|
||||||
if not key_ids:
|
if key_ids:
|
||||||
key_ids = (None,)
|
results: Mapping[
|
||||||
for key_id in key_ids:
|
str, Optional[FetchKeyResultForRemote]
|
||||||
store_queries.append((server_name, key_id, None))
|
] = await self.store.get_server_keys_json_for_remote(
|
||||||
|
server_name, key_ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results = await self.store.get_all_server_keys_json_for_remote(
|
||||||
|
server_name
|
||||||
|
)
|
||||||
|
|
||||||
cached = await self.store.get_server_keys_json_for_remote(store_queries)
|
server_keys.update(
|
||||||
|
((server_name, key_id), res) for key_id, res in results.items()
|
||||||
|
)
|
||||||
|
|
||||||
json_results: Set[bytes] = set()
|
json_results: Set[bytes] = set()
|
||||||
|
|
||||||
|
@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
|
||||||
# Map server_name->key_id->int. Note that the value of the int is unused.
|
# Map server_name->key_id->int. Note that the value of the int is unused.
|
||||||
# XXX: why don't we just use a set?
|
# XXX: why don't we just use a set?
|
||||||
cache_misses: Dict[str, Dict[str, int]] = {}
|
cache_misses: Dict[str, Dict[str, int]] = {}
|
||||||
for (server_name, key_id, _), key_results in cached.items():
|
for (server_name, key_id), key_result in server_keys.items():
|
||||||
results = [(result["ts_added_ms"], result) for result in key_results]
|
if not query[server_name]:
|
||||||
|
|
||||||
if key_id is None:
|
|
||||||
# all keys were requested. Just return what we have without worrying
|
# all keys were requested. Just return what we have without worrying
|
||||||
# about validity
|
# about validity
|
||||||
for _, result in results:
|
if key_result:
|
||||||
# Cast to bytes since postgresql returns a memoryview.
|
json_results.add(key_result.key_json)
|
||||||
json_results.add(bytes(result["key_json"]))
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
miss = False
|
miss = False
|
||||||
if not results:
|
if key_result is None:
|
||||||
miss = True
|
miss = True
|
||||||
else:
|
else:
|
||||||
ts_added_ms, most_recent_result = max(results)
|
ts_added_ms = key_result.added_ts
|
||||||
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
|
ts_valid_until_ms = key_result.valid_until_ts
|
||||||
req_key = query.get(server_name, {}).get(key_id, {})
|
req_key = query.get(server_name, {}).get(key_id, {})
|
||||||
req_valid_until = req_key.get("minimum_valid_until_ts")
|
req_valid_until = req_key.get("minimum_valid_until_ts")
|
||||||
if req_valid_until is not None:
|
if req_valid_until is not None:
|
||||||
|
@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
|
||||||
ts_valid_until_ms,
|
ts_valid_until_ms,
|
||||||
time_now_ms,
|
time_now_ms,
|
||||||
)
|
)
|
||||||
# Cast to bytes since postgresql returns a memoryview.
|
|
||||||
json_results.add(bytes(most_recent_result["key_json"]))
|
json_results.add(key_result.key_json)
|
||||||
|
|
||||||
if miss and query_remote_on_cache_miss:
|
if miss and query_remote_on_cache_miss:
|
||||||
# only bother attempting to fetch keys from servers on our whitelist
|
# only bother attempting to fetch keys from servers on our whitelist
|
||||||
|
|
|
@ -238,6 +238,7 @@ class BackgroundUpdater:
|
||||||
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
|
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.db_pool = database
|
self.db_pool = database
|
||||||
|
self.hs = hs
|
||||||
|
|
||||||
self._database_name = database.name()
|
self._database_name = database.name()
|
||||||
|
|
||||||
|
@ -758,6 +759,11 @@ class BackgroundUpdater:
|
||||||
logger.debug("[SQL] %s", sql)
|
logger.debug("[SQL] %s", sql)
|
||||||
c.execute(sql)
|
c.execute(sql)
|
||||||
|
|
||||||
|
# override the global statement timeout to avoid accidentally squashing
|
||||||
|
# a long-running index creation process
|
||||||
|
timeout_sql = "SET SESSION statement_timeout = 0"
|
||||||
|
c.execute(timeout_sql)
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
|
"CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
|
||||||
" ON %(table)s"
|
" ON %(table)s"
|
||||||
|
@ -778,6 +784,12 @@ class BackgroundUpdater:
|
||||||
logger.debug("[SQL] %s", sql)
|
logger.debug("[SQL] %s", sql)
|
||||||
c.execute(sql)
|
c.execute(sql)
|
||||||
finally:
|
finally:
|
||||||
|
# mypy ignore - `statement_timeout` is defined on PostgresEngine
|
||||||
|
# reset the global timeout to the default
|
||||||
|
default_timeout = self.db_pool.engine.statement_timeout # type: ignore[attr-defined]
|
||||||
|
undo_timeout_sql = f"SET statement_timeout = {default_timeout}"
|
||||||
|
conn.cursor().execute(undo_timeout_sql)
|
||||||
|
|
||||||
conn.set_session(autocommit=False) # type: ignore
|
conn.set_session(autocommit=False) # type: ignore
|
||||||
|
|
||||||
def create_index_sqlite(conn: Connection) -> None:
|
def create_index_sqlite(conn: Connection) -> None:
|
||||||
|
|
|
@ -45,7 +45,7 @@ from twisted.internet import defer
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
|
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
|
||||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||||
from synapse.logging.opentracing import (
|
from synapse.logging.opentracing import (
|
||||||
SynapseTags,
|
SynapseTags,
|
||||||
|
@ -357,7 +357,7 @@ class EventsPersistenceStorageController:
|
||||||
# it. We might already have taken out the lock, but since this is just a
|
# it. We might already have taken out the lock, but since this is just a
|
||||||
# "read" lock its inherently reentrant.
|
# "read" lock its inherently reentrant.
|
||||||
async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
|
async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
|
||||||
DELETE_ROOM_LOCK_NAME, room_id, write=False
|
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
|
||||||
):
|
):
|
||||||
if isinstance(task, _PersistEventsTask):
|
if isinstance(task, _PersistEventsTask):
|
||||||
return await self._persist_event_batch(room_id, task)
|
return await self._persist_event_batch(room_id, task)
|
||||||
|
|
|
@ -28,6 +28,7 @@ from typing import (
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from canonicaljson import encode_canonical_json
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from synapse.api.constants import EduTypes
|
from synapse.api.constants import EduTypes
|
||||||
|
@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _store_dehydrated_device_txn(
|
def _store_dehydrated_device_txn(
|
||||||
self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
device_data: str,
|
||||||
|
time: int,
|
||||||
|
keys: Optional[JsonDict] = None,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
|
# TODO: make keys non-optional once support for msc2697 is dropped
|
||||||
|
if keys:
|
||||||
|
device_keys = keys.get("device_keys", None)
|
||||||
|
if device_keys:
|
||||||
|
# Type ignore - this function is defined on EndToEndKeyStore which we do
|
||||||
|
# have access to due to hs.get_datastore() "magic"
|
||||||
|
self._set_e2e_device_keys_txn( # type: ignore[attr-defined]
|
||||||
|
txn, user_id, device_id, time, device_keys
|
||||||
|
)
|
||||||
|
|
||||||
|
one_time_keys = keys.get("one_time_keys", None)
|
||||||
|
if one_time_keys:
|
||||||
|
key_list = []
|
||||||
|
for key_id, key_obj in one_time_keys.items():
|
||||||
|
algorithm, key_id = key_id.split(":")
|
||||||
|
key_list.append(
|
||||||
|
(
|
||||||
|
algorithm,
|
||||||
|
key_id,
|
||||||
|
encode_canonical_json(key_obj).decode("ascii"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
|
||||||
|
|
||||||
|
fallback_keys = keys.get("fallback_keys", None)
|
||||||
|
if fallback_keys:
|
||||||
|
self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
|
||||||
|
|
||||||
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
||||||
txn,
|
txn,
|
||||||
table="dehydrated_devices",
|
table="dehydrated_devices",
|
||||||
|
@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
values={"device_id": device_id, "device_data": device_data},
|
values={"device_id": device_id, "device_data": device_data},
|
||||||
)
|
)
|
||||||
|
|
||||||
return old_device_id
|
return old_device_id
|
||||||
|
|
||||||
async def store_dehydrated_device(
|
async def store_dehydrated_device(
|
||||||
self, user_id: str, device_id: str, device_data: JsonDict
|
self,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
device_data: JsonDict,
|
||||||
|
time_now: int,
|
||||||
|
keys: Optional[dict] = None,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Store a dehydrated device for a user.
|
"""Store a dehydrated device for a user.
|
||||||
|
|
||||||
|
@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
user_id: the user that we are storing the device for
|
user_id: the user that we are storing the device for
|
||||||
device_id: the ID of the dehydrated device
|
device_id: the ID of the dehydrated device
|
||||||
device_data: the dehydrated device information
|
device_data: the dehydrated device information
|
||||||
|
time_now: current time at the request in milliseconds
|
||||||
|
keys: keys for the dehydrated device
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
device id of the user's previous dehydrated device, if any
|
device id of the user's previous dehydrated device, if any
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"store_dehydrated_device_txn",
|
"store_dehydrated_device_txn",
|
||||||
self._store_dehydrated_device_txn,
|
self._store_dehydrated_device_txn,
|
||||||
user_id,
|
user_id,
|
||||||
device_id,
|
device_id,
|
||||||
json_encoder.encode(device_data),
|
json_encoder.encode(device_data),
|
||||||
|
time_now,
|
||||||
|
keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
|
async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
|
||||||
|
|
|
@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
|
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
|
|
||||||
set_tag("user_id", user_id)
|
|
||||||
set_tag("device_id", device_id)
|
|
||||||
set_tag("new_keys", str(new_keys))
|
|
||||||
# We are protected from race between lookup and insertion due to
|
|
||||||
# a unique constraint. If there is a race of two calls to
|
|
||||||
# `add_e2e_one_time_keys` then they'll conflict and we will only
|
|
||||||
# insert one set.
|
|
||||||
self.db_pool.simple_insert_many_txn(
|
|
||||||
txn,
|
|
||||||
table="e2e_one_time_keys_json",
|
|
||||||
keys=(
|
|
||||||
"user_id",
|
|
||||||
"device_id",
|
|
||||||
"algorithm",
|
|
||||||
"key_id",
|
|
||||||
"ts_added_ms",
|
|
||||||
"key_json",
|
|
||||||
),
|
|
||||||
values=[
|
|
||||||
(user_id, device_id, algorithm, key_id, time_now, json_bytes)
|
|
||||||
for algorithm, key_id, json_bytes in new_keys
|
|
||||||
],
|
|
||||||
)
|
|
||||||
self._invalidate_cache_and_stream(
|
|
||||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
|
"add_e2e_one_time_keys_insert",
|
||||||
|
self._add_e2e_one_time_keys_txn,
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
time_now,
|
||||||
|
new_keys,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_e2e_one_time_keys_txn(
|
||||||
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
time_now: int,
|
||||||
|
new_keys: Iterable[Tuple[str, str, str]],
|
||||||
|
) -> None:
|
||||||
|
"""Insert some new one time keys for a device. Errors if any of the keys already exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: id of user to get keys for
|
||||||
|
device_id: id of device to get keys for
|
||||||
|
time_now: insertion time to record (ms since epoch)
|
||||||
|
new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note
|
||||||
|
that the key JSON must be in canonical JSON form
|
||||||
|
"""
|
||||||
|
set_tag("user_id", user_id)
|
||||||
|
set_tag("device_id", device_id)
|
||||||
|
set_tag("new_keys", str(new_keys))
|
||||||
|
# We are protected from race between lookup and insertion due to
|
||||||
|
# a unique constraint. If there is a race of two calls to
|
||||||
|
# `add_e2e_one_time_keys` then they'll conflict and we will only
|
||||||
|
# insert one set.
|
||||||
|
self.db_pool.simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="e2e_one_time_keys_json",
|
||||||
|
keys=(
|
||||||
|
"user_id",
|
||||||
|
"device_id",
|
||||||
|
"algorithm",
|
||||||
|
"key_id",
|
||||||
|
"ts_added_ms",
|
||||||
|
"key_json",
|
||||||
|
),
|
||||||
|
values=[
|
||||||
|
(user_id, device_id, algorithm, key_id, time_now, json_bytes)
|
||||||
|
for algorithm, key_id, json_bytes in new_keys
|
||||||
|
],
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(max_entries=10000)
|
@cached(max_entries=10000)
|
||||||
|
@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
device_id: str,
|
device_id: str,
|
||||||
fallback_keys: JsonDict,
|
fallback_keys: JsonDict,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Set the user's e2e fallback keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: the user whose keys are being set
|
||||||
|
device_id: the device whose keys are being set
|
||||||
|
fallback_keys: the keys to set. This is a map from key ID (which is
|
||||||
|
of the form "algorithm:id") to key data.
|
||||||
|
"""
|
||||||
# fallback_keys will usually only have one item in it, so using a for
|
# fallback_keys will usually only have one item in it, so using a for
|
||||||
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
|
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
|
||||||
# FIXME: make sure that only one key per algorithm is uploaded
|
# FIXME: make sure that only one key per algorithm is uploaded
|
||||||
|
@ -1304,43 +1333,70 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Stores device keys for a device. Returns whether there was a change
|
"""Stores device keys for a device. Returns whether there was a change
|
||||||
or the keys were already in the database.
|
or the keys were already in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: user_id of the user to store keys for
|
||||||
|
device_id: device_id of the device to store keys for
|
||||||
|
time_now: time at the request to store the keys
|
||||||
|
device_keys: the keys to store
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
|
|
||||||
set_tag("user_id", user_id)
|
|
||||||
set_tag("device_id", device_id)
|
|
||||||
set_tag("time_now", time_now)
|
|
||||||
set_tag("device_keys", str(device_keys))
|
|
||||||
|
|
||||||
old_key_json = self.db_pool.simple_select_one_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="e2e_device_keys_json",
|
|
||||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
|
||||||
retcol="key_json",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# In py3 we need old_key_json to match new_key_json type. The DB
|
|
||||||
# returns unicode while encode_canonical_json returns bytes.
|
|
||||||
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
|
|
||||||
|
|
||||||
if old_key_json == new_key_json:
|
|
||||||
log_kv({"Message": "Device key already stored."})
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.db_pool.simple_upsert_txn(
|
|
||||||
txn,
|
|
||||||
table="e2e_device_keys_json",
|
|
||||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
|
||||||
values={"ts_added_ms": time_now, "key_json": new_key_json},
|
|
||||||
)
|
|
||||||
log_kv({"message": "Device keys stored."})
|
|
||||||
return True
|
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"set_e2e_device_keys", _set_e2e_device_keys_txn
|
"set_e2e_device_keys",
|
||||||
|
self._set_e2e_device_keys_txn,
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
time_now,
|
||||||
|
device_keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _set_e2e_device_keys_txn(
|
||||||
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
time_now: int,
|
||||||
|
device_keys: JsonDict,
|
||||||
|
) -> bool:
|
||||||
|
"""Stores device keys for a device. Returns whether there was a change
|
||||||
|
or the keys were already in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: user_id of the user to store keys for
|
||||||
|
device_id: device_id of the device to store keys for
|
||||||
|
time_now: time at the request to store the keys
|
||||||
|
device_keys: the keys to store
|
||||||
|
"""
|
||||||
|
set_tag("user_id", user_id)
|
||||||
|
set_tag("device_id", device_id)
|
||||||
|
set_tag("time_now", time_now)
|
||||||
|
set_tag("device_keys", str(device_keys))
|
||||||
|
|
||||||
|
old_key_json = self.db_pool.simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="e2e_device_keys_json",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
retcol="key_json",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# In py3 we need old_key_json to match new_key_json type. The DB
|
||||||
|
# returns unicode while encode_canonical_json returns bytes.
|
||||||
|
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
|
||||||
|
|
||||||
|
if old_key_json == new_key_json:
|
||||||
|
log_kv({"Message": "Device key already stored."})
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.db_pool.simple_upsert_txn(
|
||||||
|
txn,
|
||||||
|
table="e2e_device_keys_json",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
values={"ts_added_ms": time_now, "key_json": new_key_json},
|
||||||
|
)
|
||||||
|
log_kv({"message": "Device keys stored."})
|
||||||
|
return True
|
||||||
|
|
||||||
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
|
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
|
||||||
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
|
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
|
||||||
log_kv(
|
log_kv(
|
||||||
|
|
|
@ -13,10 +13,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union, cast
|
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from typing_extensions import TYPE_CHECKING
|
|
||||||
|
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
|
|
|
@ -16,14 +16,13 @@
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
from typing import Dict, Iterable, Mapping, Optional, Tuple
|
||||||
|
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.database import LoggingTransaction
|
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
|
||||||
from synapse.storage.keys import FetchKeyResult
|
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
|
@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
|
||||||
db_binary_type = memoryview
|
db_binary_type = memoryview
|
||||||
|
|
||||||
|
|
||||||
class KeyStore(SQLBaseStore):
|
class KeyStore(CacheInvalidationWorkerStore):
|
||||||
"""Persistence for signature verification keys"""
|
"""Persistence for signature verification keys"""
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
|
@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
|
||||||
# invalidate takes a tuple corresponding to the params of
|
# invalidate takes a tuple corresponding to the params of
|
||||||
# _get_server_keys_json. _get_server_keys_json only takes one
|
# _get_server_keys_json. _get_server_keys_json only takes one
|
||||||
# param, which is itself the 2-tuple (server_name, key_id).
|
# param, which is itself the 2-tuple (server_name, key_id).
|
||||||
self._get_server_keys_json.invalidate((((server_name, key_id),)))
|
await self.invalidate_cache_and_stream(
|
||||||
|
"_get_server_keys_json", ((server_name, key_id),)
|
||||||
|
)
|
||||||
|
await self.invalidate_cache_and_stream(
|
||||||
|
"get_server_key_json_for_remote", (server_name, key_id)
|
||||||
|
)
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def _get_server_keys_json(
|
def _get_server_keys_json(
|
||||||
|
@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
|
||||||
|
|
||||||
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
|
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def get_server_key_json_for_remote(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
key_id: str,
|
||||||
|
) -> Optional[FetchKeyResultForRemote]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@cachedList(
|
||||||
|
cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
|
||||||
|
)
|
||||||
async def get_server_keys_json_for_remote(
|
async def get_server_keys_json_for_remote(
|
||||||
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
|
self, server_name: str, key_ids: Iterable[str]
|
||||||
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
|
) -> Dict[str, Optional[FetchKeyResultForRemote]]:
|
||||||
"""Retrieve the key json for a list of server_keys and key ids.
|
"""Fetch the cached keys for the given server/key IDs.
|
||||||
If no keys are found for a given server, key_id and source then
|
|
||||||
that server, key_id, and source triplet entry will be an empty list.
|
|
||||||
The JSON is returned as a byte array so that it can be efficiently
|
|
||||||
used in an HTTP response.
|
|
||||||
|
|
||||||
Args:
|
If we have multiple entries for a given key ID, returns the most recent.
|
||||||
server_keys: List of (server_name, key_id, source) triplets.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A mapping from (server_name, key_id, source) triplets to a list of dicts
|
|
||||||
"""
|
"""
|
||||||
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
def _get_server_keys_json_txn(
|
table="server_keys_json",
|
||||||
txn: LoggingTransaction,
|
column="key_id",
|
||||||
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
|
iterable=key_ids,
|
||||||
results = {}
|
keyvalues={"server_name": server_name},
|
||||||
for server_name, key_id, from_server in server_keys:
|
retcols=(
|
||||||
keyvalues = {"server_name": server_name}
|
"key_id",
|
||||||
if key_id is not None:
|
"from_server",
|
||||||
keyvalues["key_id"] = key_id
|
"ts_added_ms",
|
||||||
if from_server is not None:
|
"ts_valid_until_ms",
|
||||||
keyvalues["from_server"] = from_server
|
"key_json",
|
||||||
rows = self.db_pool.simple_select_list_txn(
|
),
|
||||||
txn,
|
desc="get_server_keys_json_for_remote",
|
||||||
"server_keys_json",
|
|
||||||
keyvalues=keyvalues,
|
|
||||||
retcols=(
|
|
||||||
"key_id",
|
|
||||||
"from_server",
|
|
||||||
"ts_added_ms",
|
|
||||||
"ts_valid_until_ms",
|
|
||||||
"key_json",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
results[(server_name, key_id, from_server)] = rows
|
|
||||||
return results
|
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
|
||||||
"get_server_keys_json", _get_server_keys_json_txn
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# We sort the rows so that the most recently added entry is picked up.
|
||||||
|
rows.sort(key=lambda r: r["ts_added_ms"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
row["key_id"]: FetchKeyResultForRemote(
|
||||||
|
# Cast to bytes since postgresql returns a memoryview.
|
||||||
|
key_json=bytes(row["key_json"]),
|
||||||
|
valid_until_ts=row["ts_valid_until_ms"],
|
||||||
|
added_ts=row["ts_added_ms"],
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_all_server_keys_json_for_remote(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
) -> Dict[str, FetchKeyResultForRemote]:
|
||||||
|
"""Fetch the cached keys for the given server.
|
||||||
|
|
||||||
|
If we have multiple entries for a given key ID, returns the most recent.
|
||||||
|
"""
|
||||||
|
rows = await self.db_pool.simple_select_list(
|
||||||
|
table="server_keys_json",
|
||||||
|
keyvalues={"server_name": server_name},
|
||||||
|
retcols=(
|
||||||
|
"key_id",
|
||||||
|
"from_server",
|
||||||
|
"ts_added_ms",
|
||||||
|
"ts_valid_until_ms",
|
||||||
|
"key_json",
|
||||||
|
),
|
||||||
|
desc="get_server_keys_json_for_remote",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
rows.sort(key=lambda r: r["ts_added_ms"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
row["key_id"]: FetchKeyResultForRemote(
|
||||||
|
# Cast to bytes since postgresql returns a memoryview.
|
||||||
|
key_json=bytes(row["key_json"]),
|
||||||
|
valid_until_ts=row["ts_valid_until_ms"],
|
||||||
|
added_ts=row["ts_added_ms"],
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
}
|
||||||
|
|
|
@ -26,7 +26,6 @@ from synapse.storage.database import (
|
||||||
LoggingDatabaseConnection,
|
LoggingDatabaseConnection,
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import PostgresEngine
|
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
|
@ -96,6 +95,10 @@ class LockStore(SQLBaseStore):
|
||||||
|
|
||||||
self._acquiring_locks: Set[Tuple[str, str]] = set()
|
self._acquiring_locks: Set[Tuple[str, str]] = set()
|
||||||
|
|
||||||
|
self._clock.looping_call(
|
||||||
|
self._reap_stale_read_write_locks, _LOCK_TIMEOUT_MS / 10.0
|
||||||
|
)
|
||||||
|
|
||||||
@wrap_as_background_process("LockStore._on_shutdown")
|
@wrap_as_background_process("LockStore._on_shutdown")
|
||||||
async def _on_shutdown(self) -> None:
|
async def _on_shutdown(self) -> None:
|
||||||
"""Called when the server is shutting down"""
|
"""Called when the server is shutting down"""
|
||||||
|
@ -216,6 +219,7 @@ class LockStore(SQLBaseStore):
|
||||||
lock_name,
|
lock_name,
|
||||||
lock_key,
|
lock_key,
|
||||||
write,
|
write,
|
||||||
|
db_autocommit=True,
|
||||||
)
|
)
|
||||||
except self.database_engine.module.IntegrityError:
|
except self.database_engine.module.IntegrityError:
|
||||||
return None
|
return None
|
||||||
|
@ -233,61 +237,22 @@ class LockStore(SQLBaseStore):
|
||||||
# `worker_read_write_locks` and seeing if that fails any
|
# `worker_read_write_locks` and seeing if that fails any
|
||||||
# constraints. If it doesn't then we have acquired the lock,
|
# constraints. If it doesn't then we have acquired the lock,
|
||||||
# otherwise we haven't.
|
# otherwise we haven't.
|
||||||
#
|
|
||||||
# Before that though we clear the table of any stale locks.
|
|
||||||
|
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
token = random_string(6)
|
token = random_string(6)
|
||||||
|
|
||||||
delete_sql = """
|
self.db_pool.simple_insert_txn(
|
||||||
DELETE FROM worker_read_write_locks
|
txn,
|
||||||
WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
|
table="worker_read_write_locks",
|
||||||
"""
|
values={
|
||||||
|
"lock_name": lock_name,
|
||||||
insert_sql = """
|
"lock_key": lock_key,
|
||||||
INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
|
"write_lock": write,
|
||||||
VALUES (?, ?, ?, ?, ?, ?)
|
"instance_name": self._instance_name,
|
||||||
"""
|
"token": token,
|
||||||
|
"last_renewed_ts": now,
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
},
|
||||||
# For Postgres we can send these queries at the same time.
|
)
|
||||||
txn.execute(
|
|
||||||
delete_sql + ";" + insert_sql,
|
|
||||||
(
|
|
||||||
# DELETE args
|
|
||||||
now - _LOCK_TIMEOUT_MS,
|
|
||||||
lock_name,
|
|
||||||
lock_key,
|
|
||||||
# UPSERT args
|
|
||||||
lock_name,
|
|
||||||
lock_key,
|
|
||||||
write,
|
|
||||||
self._instance_name,
|
|
||||||
token,
|
|
||||||
now,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# For SQLite these need to be two queries.
|
|
||||||
txn.execute(
|
|
||||||
delete_sql,
|
|
||||||
(
|
|
||||||
now - _LOCK_TIMEOUT_MS,
|
|
||||||
lock_name,
|
|
||||||
lock_key,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
txn.execute(
|
|
||||||
insert_sql,
|
|
||||||
(
|
|
||||||
lock_name,
|
|
||||||
lock_key,
|
|
||||||
write,
|
|
||||||
self._instance_name,
|
|
||||||
token,
|
|
||||||
now,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
lock = Lock(
|
lock = Lock(
|
||||||
self._reactor,
|
self._reactor,
|
||||||
|
@ -351,6 +316,24 @@ class LockStore(SQLBaseStore):
|
||||||
|
|
||||||
return locks
|
return locks
|
||||||
|
|
||||||
|
@wrap_as_background_process("_reap_stale_read_write_locks")
|
||||||
|
async def _reap_stale_read_write_locks(self) -> None:
|
||||||
|
delete_sql = """
|
||||||
|
DELETE FROM worker_read_write_locks
|
||||||
|
WHERE last_renewed_ts < ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
def reap_stale_read_write_locks_txn(txn: LoggingTransaction) -> None:
|
||||||
|
txn.execute(delete_sql, (self._clock.time_msec() - _LOCK_TIMEOUT_MS,))
|
||||||
|
if txn.rowcount:
|
||||||
|
logger.info("Reaped %d stale locks", txn.rowcount)
|
||||||
|
|
||||||
|
await self.db_pool.runInteraction(
|
||||||
|
"_reap_stale_read_write_locks",
|
||||||
|
reap_stale_read_write_locks_txn,
|
||||||
|
db_autocommit=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Lock:
|
class Lock:
|
||||||
"""An async context manager that manages an acquired lock, ensuring it is
|
"""An async context manager that manages an acquired lock, ensuring it is
|
||||||
|
|
|
@ -205,7 +205,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
name, password_hash, is_guest, admin, consent_version, consent_ts,
|
name, password_hash, is_guest, admin, consent_version, consent_ts,
|
||||||
consent_server_notice_sent, appservice_id, creation_ts, user_type,
|
consent_server_notice_sent, appservice_id, creation_ts, user_type,
|
||||||
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
|
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
|
||||||
COALESCE(approved, TRUE) AS approved
|
COALESCE(approved, TRUE) AS approved,
|
||||||
|
COALESCE(locked, FALSE) AS locked
|
||||||
FROM users
|
FROM users
|
||||||
WHERE name = ?
|
WHERE name = ?
|
||||||
""",
|
""",
|
||||||
|
@ -230,10 +231,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
# want to make sure we're returning the right type of data.
|
# want to make sure we're returning the right type of data.
|
||||||
# Note: when adding a column name to this list, be wary of NULLable columns,
|
# Note: when adding a column name to this list, be wary of NULLable columns,
|
||||||
# since NULL values will be turned into False.
|
# since NULL values will be turned into False.
|
||||||
boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
|
boolean_columns = [
|
||||||
|
"admin",
|
||||||
|
"deactivated",
|
||||||
|
"shadow_banned",
|
||||||
|
"approved",
|
||||||
|
"locked",
|
||||||
|
]
|
||||||
for column in boolean_columns:
|
for column in boolean_columns:
|
||||||
if not isinstance(row[column], bool):
|
row[column] = bool(row[column])
|
||||||
row[column] = bool(row[column])
|
|
||||||
|
|
||||||
return row
|
return row
|
||||||
|
|
||||||
|
@ -1116,6 +1122,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
# Convert the integer into a boolean.
|
# Convert the integer into a boolean.
|
||||||
return res == 1
|
return res == 1
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
async def get_user_locked_status(self, user_id: str) -> bool:
|
||||||
|
"""Retrieve the value for the `locked` property for the provided user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The ID of the user to retrieve the status for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the user was locked, false if the user is still active.
|
||||||
|
"""
|
||||||
|
|
||||||
|
res = await self.db_pool.simple_select_one_onecol(
|
||||||
|
table="users",
|
||||||
|
keyvalues={"name": user_id},
|
||||||
|
retcol="locked",
|
||||||
|
desc="get_user_locked_status",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert the potential integer into a boolean.
|
||||||
|
return bool(res)
|
||||||
|
|
||||||
async def get_threepid_validation_session(
|
async def get_threepid_validation_session(
|
||||||
self,
|
self,
|
||||||
medium: Optional[str],
|
medium: Optional[str],
|
||||||
|
@ -2111,6 +2138,33 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||||
|
|
||||||
|
async def set_user_locked_status(self, user_id: str, locked: bool) -> None:
|
||||||
|
"""Set the `locked` property for the provided user to the provided value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The ID of the user to set the status for.
|
||||||
|
locked: The value to set for `locked`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
await self.db_pool.runInteraction(
|
||||||
|
"set_user_locked_status",
|
||||||
|
self.set_user_locked_status_txn,
|
||||||
|
user_id,
|
||||||
|
locked,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_user_locked_status_txn(
|
||||||
|
self, txn: LoggingTransaction, user_id: str, locked: bool
|
||||||
|
) -> None:
|
||||||
|
self.db_pool.simple_update_one_txn(
|
||||||
|
txn=txn,
|
||||||
|
table="users",
|
||||||
|
keyvalues={"name": user_id},
|
||||||
|
updatevalues={"locked": locked},
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,))
|
||||||
|
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||||
|
|
||||||
def update_user_approval_status_txn(
|
def update_user_approval_status_txn(
|
||||||
self, txn: LoggingTransaction, user_id: str, approved: bool
|
self, txn: LoggingTransaction, user_id: str, approved: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -19,6 +19,7 @@ from itertools import chain
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
Counter,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
@ -28,8 +29,6 @@ from typing import (
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import Counter
|
|
||||||
|
|
||||||
from twisted.internet.defer import DeferredLock
|
from twisted.internet.defer import DeferredLock
|
||||||
|
|
||||||
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
|
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
|
||||||
|
|
|
@ -995,7 +995,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def search_user_dir(
|
async def search_user_dir(
|
||||||
self, user_id: str, search_term: str, limit: int
|
self,
|
||||||
|
user_id: str,
|
||||||
|
search_term: str,
|
||||||
|
limit: int,
|
||||||
|
show_locked_users: bool = False,
|
||||||
) -> SearchResult:
|
) -> SearchResult:
|
||||||
"""Searches for users in directory
|
"""Searches for users in directory
|
||||||
|
|
||||||
|
@ -1029,6 +1033,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not show_locked_users:
|
||||||
|
where_clause += " AND (u.locked IS NULL OR u.locked = FALSE)"
|
||||||
|
|
||||||
# We allow manipulating the ranking algorithm by injecting statements
|
# We allow manipulating the ranking algorithm by injecting statements
|
||||||
# based on config options.
|
# based on config options.
|
||||||
additional_ordering_statements = []
|
additional_ordering_statements = []
|
||||||
|
@ -1060,6 +1067,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
SELECT d.user_id AS user_id, display_name, avatar_url
|
SELECT d.user_id AS user_id, display_name, avatar_url
|
||||||
FROM matching_users as t
|
FROM matching_users as t
|
||||||
INNER JOIN user_directory AS d USING (user_id)
|
INNER JOIN user_directory AS d USING (user_id)
|
||||||
|
LEFT JOIN users AS u ON t.user_id = u.name
|
||||||
WHERE
|
WHERE
|
||||||
%(where_clause)s
|
%(where_clause)s
|
||||||
ORDER BY
|
ORDER BY
|
||||||
|
@ -1115,6 +1123,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
SELECT d.user_id AS user_id, display_name, avatar_url
|
SELECT d.user_id AS user_id, display_name, avatar_url
|
||||||
FROM user_directory_search as t
|
FROM user_directory_search as t
|
||||||
INNER JOIN user_directory AS d USING (user_id)
|
INNER JOIN user_directory AS d USING (user_id)
|
||||||
|
LEFT JOIN users AS u ON t.user_id = u.name
|
||||||
WHERE
|
WHERE
|
||||||
%(where_clause)s
|
%(where_clause)s
|
||||||
AND value MATCH ?
|
AND value MATCH ?
|
||||||
|
|
|
@ -145,5 +145,5 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
|
||||||
|
|
||||||
This is not provided by DBAPI2, and so needs engine-specific support.
|
This is not provided by DBAPI2, and so needs engine-specific support.
|
||||||
"""
|
"""
|
||||||
with open(filepath, "rt") as f:
|
with open(filepath) as f:
|
||||||
cls.executescript(cursor, f.read())
|
cls.executescript(cursor, f.read())
|
||||||
|
|
|
@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
|
||||||
class FetchKeyResult:
|
class FetchKeyResult:
|
||||||
verify_key: VerifyKey # the key itself
|
verify_key: VerifyKey # the key itself
|
||||||
valid_until_ts: int # how long we can use this key for
|
valid_until_ts: int # how long we can use this key for
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class FetchKeyResultForRemote:
|
||||||
|
key_json: bytes # the full key JSON
|
||||||
|
valid_until_ts: int # how long we can use this key for, in milliseconds.
|
||||||
|
added_ts: int # When we added this key, in milliseconds.
|
||||||
|
|
|
@ -16,10 +16,18 @@ import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import Collection, Generator, Iterable, List, Optional, TextIO, Tuple
|
from typing import (
|
||||||
|
Collection,
|
||||||
|
Counter as CounterType,
|
||||||
|
Generator,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
TextIO,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import Counter as CounterType
|
|
||||||
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction
|
from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
/* Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
ALTER TABLE users ADD locked BOOLEAN DEFAULT FALSE NOT NULL;
|
|
@ -21,6 +21,7 @@ from typing import (
|
||||||
Any,
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
|
Final,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Match,
|
Match,
|
||||||
|
@ -38,7 +39,7 @@ import attr
|
||||||
from immutabledict import immutabledict
|
from immutabledict import immutabledict
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
from signedjson.types import VerifyKey
|
from signedjson.types import VerifyKey
|
||||||
from typing_extensions import Final, TypedDict
|
from typing_extensions import TypedDict
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
from zope.interface import Interface
|
from zope.interface import Interface
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
AsyncContextManager,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
|
@ -42,7 +43,7 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec
|
from typing_extensions import Concatenate, Literal, ParamSpec
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import CancelledError
|
from twisted.internet.defer import CancelledError
|
||||||
|
|
|
@ -218,7 +218,7 @@ class MacaroonGenerator:
|
||||||
# to avoid validating those as guest tokens, we explicitely verify if
|
# to avoid validating those as guest tokens, we explicitely verify if
|
||||||
# the macaroon includes the "guest = true" caveat.
|
# the macaroon includes the "guest = true" caveat.
|
||||||
is_guest = any(
|
is_guest = any(
|
||||||
(caveat.caveat_id == "guest = true" for caveat in macaroon.caveats)
|
caveat.caveat_id == "guest = true" for caveat in macaroon.caveats
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_guest:
|
if not is_guest:
|
||||||
|
|
|
@ -98,7 +98,9 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory:
|
||||||
SynapseManhole, dict(globals, __name__="__console__")
|
SynapseManhole, dict(globals, __name__="__console__")
|
||||||
)
|
)
|
||||||
|
|
||||||
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
|
# type-ignore: This is an error in Twisted's annotations. See
|
||||||
|
# https://github.com/twisted/twisted/issues/11812 and /11813 .
|
||||||
|
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) # type: ignore[arg-type]
|
||||||
|
|
||||||
# conch has the wrong type on these dicts (says bytes to bytes,
|
# conch has the wrong type on these dicts (says bytes to bytes,
|
||||||
# should be bytes to Keys judging by how it's used).
|
# should be bytes to Keys judging by how it's used).
|
||||||
|
|
|
@ -20,6 +20,7 @@ import typing
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
|
ContextManager,
|
||||||
DefaultDict,
|
DefaultDict,
|
||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
Iterator,
|
||||||
|
@ -33,7 +34,6 @@ from typing import (
|
||||||
from weakref import WeakSet
|
from weakref import WeakSet
|
||||||
|
|
||||||
from prometheus_client.core import Counter
|
from prometheus_client.core import Counter
|
||||||
from typing_extensions import ContextManager
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ from enum import Enum, auto
|
||||||
from typing import (
|
from typing import (
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
|
Final,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
|
@ -27,7 +28,6 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import Final
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
|
|
@ -69,6 +69,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(user_info)
|
self.store.get_user_by_access_token = simple_async_mock(user_info)
|
||||||
self.store.mark_access_token_as_used = simple_async_mock(None)
|
self.store.mark_access_token_as_used = simple_async_mock(None)
|
||||||
|
self.store.get_user_locked_status = simple_async_mock(False)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
|
@ -293,6 +294,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.store.insert_client_ip = simple_async_mock(None)
|
self.store.insert_client_ip = simple_async_mock(None)
|
||||||
self.store.mark_access_token_as_used = simple_async_mock(None)
|
self.store.mark_access_token_as_used = simple_async_mock(None)
|
||||||
|
self.store.get_user_locked_status = simple_async_mock(False)
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
|
@ -311,6 +313,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
token_used=True,
|
token_used=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.store.get_user_locked_status = simple_async_mock(False)
|
||||||
self.store.insert_client_ip = simple_async_mock(None)
|
self.store.insert_client_ip = simple_async_mock(None)
|
||||||
self.store.mark_access_token_as_used = simple_async_mock(None)
|
self.store.mark_access_token_as_used = simple_async_mock(None)
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
|
|
|
@ -26,7 +26,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
|
||||||
def make_homeserver(
|
def make_homeserver(
|
||||||
self, reactor: ThreadedMemoryReactorClock, clock: Clock
|
self, reactor: ThreadedMemoryReactorClock, clock: Clock
|
||||||
) -> HomeServer:
|
) -> HomeServer:
|
||||||
hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock)
|
hs = super().make_homeserver(reactor, clock)
|
||||||
|
|
||||||
# We don't want our tests to actually report statistics, so check
|
# We don't want our tests to actually report statistics, so check
|
||||||
# that it's not enabled
|
# that it's not enabled
|
||||||
|
|
|
@ -312,7 +312,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
[("server9", get_key_id(key1))]
|
[("server9", get_key_id(key1))]
|
||||||
)
|
)
|
||||||
result = self.get_success(d)
|
result = self.get_success(d)
|
||||||
self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0)
|
self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)
|
||||||
|
|
||||||
def test_verify_json_dedupes_key_requests(self) -> None:
|
def test_verify_json_dedupes_key_requests(self) -> None:
|
||||||
"""Two requests for the same key should be deduped."""
|
"""Two requests for the same key should be deduped."""
|
||||||
|
@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(k.verify_key.version, "ver1")
|
self.assertEqual(k.verify_key.version, "ver1")
|
||||||
|
|
||||||
# check that the perspectives store is correctly updated
|
# check that the perspectives store is correctly updated
|
||||||
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
|
|
||||||
key_json = self.get_success(
|
key_json = self.get_success(
|
||||||
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
||||||
[lookup_triplet]
|
SERVER_NAME, [testverifykey_id]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
res_keys = key_json[lookup_triplet]
|
res = key_json[testverifykey_id]
|
||||||
self.assertEqual(len(res_keys), 1)
|
self.assertIsNotNone(res)
|
||||||
res = res_keys[0]
|
assert res is not None
|
||||||
self.assertEqual(res["key_id"], testverifykey_id)
|
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
|
||||||
self.assertEqual(res["from_server"], SERVER_NAME)
|
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
|
||||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
|
||||||
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
|
|
||||||
|
|
||||||
# we expect it to be encoded as canonical json *before* it hits the db
|
# we expect it to be encoded as canonical json *before* it hits the db
|
||||||
self.assertEqual(
|
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
|
||||||
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
|
|
||||||
)
|
|
||||||
|
|
||||||
# change the server name: the result should be ignored
|
# change the server name: the result should be ignored
|
||||||
response["server_name"] = "OTHER_SERVER"
|
response["server_name"] = "OTHER_SERVER"
|
||||||
|
@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(k.verify_key.version, "ver1")
|
self.assertEqual(k.verify_key.version, "ver1")
|
||||||
|
|
||||||
# check that the perspectives store is correctly updated
|
# check that the perspectives store is correctly updated
|
||||||
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
|
|
||||||
key_json = self.get_success(
|
key_json = self.get_success(
|
||||||
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
||||||
[lookup_triplet]
|
SERVER_NAME, [testverifykey_id]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
res_keys = key_json[lookup_triplet]
|
res = key_json[testverifykey_id]
|
||||||
self.assertEqual(len(res_keys), 1)
|
self.assertIsNotNone(res)
|
||||||
res = res_keys[0]
|
assert res is not None
|
||||||
self.assertEqual(res["key_id"], testverifykey_id)
|
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
|
||||||
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
|
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
|
||||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
|
||||||
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
|
||||||
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_get_multiple_keys_from_perspectives(self) -> None:
|
def test_get_multiple_keys_from_perspectives(self) -> None:
|
||||||
"""Check that we can correctly request multiple keys for the same server"""
|
"""Check that we can correctly request multiple keys for the same server"""
|
||||||
|
@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(k.verify_key.version, "ver1")
|
self.assertEqual(k.verify_key.version, "ver1")
|
||||||
|
|
||||||
# check that the perspectives store is correctly updated
|
# check that the perspectives store is correctly updated
|
||||||
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
|
|
||||||
key_json = self.get_success(
|
key_json = self.get_success(
|
||||||
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
||||||
[lookup_triplet]
|
SERVER_NAME, [testverifykey_id]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
res_keys = key_json[lookup_triplet]
|
res = key_json[testverifykey_id]
|
||||||
self.assertEqual(len(res_keys), 1)
|
self.assertIsNotNone(res)
|
||||||
res = res_keys[0]
|
assert res is not None
|
||||||
self.assertEqual(res["key_id"], testverifykey_id)
|
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
|
||||||
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
|
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
|
||||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
|
||||||
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
|
||||||
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_invalid_perspectives_responses(self) -> None:
|
def test_invalid_perspectives_responses(self) -> None:
|
||||||
"""Check that invalid responses from the perspectives server are rejected"""
|
"""Check that invalid responses from the perspectives server are rejected"""
|
||||||
|
|
|
@ -566,15 +566,16 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(res["events"]), 1)
|
self.assertEqual(len(res["events"]), 1)
|
||||||
self.assertEqual(res["events"][0]["content"]["body"], "foo")
|
self.assertEqual(res["events"][0]["content"]["body"], "foo")
|
||||||
|
|
||||||
# Fetch the message of the dehydrated device again, which should return nothing
|
# Fetch the message of the dehydrated device again, which should return
|
||||||
# and delete the old messages
|
# the same message as it has not been deleted
|
||||||
res = self.get_success(
|
res = self.get_success(
|
||||||
self.message_handler.get_events_for_dehydrated_device(
|
self.message_handler.get_events_for_dehydrated_device(
|
||||||
requester=requester,
|
requester=requester,
|
||||||
device_id=stored_dehydrated_device_id,
|
device_id=stored_dehydrated_device_id,
|
||||||
since_token=res["next_batch"],
|
since_token=None,
|
||||||
limit=10,
|
limit=10,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertTrue(len(res["next_batch"]) > 1)
|
self.assertTrue(len(res["next_batch"]) > 1)
|
||||||
self.assertEqual(len(res["events"]), 0)
|
self.assertEqual(len(res["events"]), 1)
|
||||||
|
self.assertEqual(res["events"][0]["content"]["body"], "foo")
|
||||||
|
|
|
@ -491,6 +491,68 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||||
self.assertEqual(error.value.code, 503)
|
self.assertEqual(error.value.code, 503)
|
||||||
|
|
||||||
|
def test_introspection_token_cache(self) -> None:
|
||||||
|
access_token = "open_sesame"
|
||||||
|
self.http_client.request = simple_async_mock(
|
||||||
|
return_value=FakeResponse.json(
|
||||||
|
code=200,
|
||||||
|
payload={"active": "true", "scope": "guest", "jti": access_token},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# first call should cache response
|
||||||
|
# Mpyp ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code
|
||||||
|
# for regular auth code via the config
|
||||||
|
self.get_success(
|
||||||
|
self.auth._introspect_token(access_token) # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
introspection_token = self.auth._token_cache.get(access_token) # type: ignore[attr-defined]
|
||||||
|
self.assertEqual(introspection_token["jti"], access_token)
|
||||||
|
# there's been one http request
|
||||||
|
self.http_client.request.assert_called_once()
|
||||||
|
|
||||||
|
# second call should pull from cache, there should still be only one http request
|
||||||
|
token = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
|
||||||
|
self.http_client.request.assert_called_once()
|
||||||
|
self.assertEqual(token["jti"], access_token)
|
||||||
|
|
||||||
|
# advance past five minutes and check that cache expired - there should be more than one http call now
|
||||||
|
self.reactor.advance(360)
|
||||||
|
token_2 = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
|
||||||
|
self.assertEqual(self.http_client.request.call_count, 2)
|
||||||
|
self.assertEqual(token_2["jti"], access_token)
|
||||||
|
|
||||||
|
# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
|
||||||
|
# token with a soon-to-expire `exp` field to the cache
|
||||||
|
self.http_client.request = simple_async_mock(
|
||||||
|
return_value=FakeResponse.json(
|
||||||
|
code=200,
|
||||||
|
payload={
|
||||||
|
"active": "true",
|
||||||
|
"scope": "guest",
|
||||||
|
"jti": "stale",
|
||||||
|
"exp": self.clock.time() + 100,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.get_success(
|
||||||
|
self.auth._introspect_token("stale") # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
introspection_token = self.auth._token_cache.get("stale") # type: ignore[attr-defined]
|
||||||
|
self.assertEqual(introspection_token["jti"], "stale")
|
||||||
|
self.assertEqual(self.http_client.request.call_count, 1)
|
||||||
|
|
||||||
|
# advance the reactor past the token expiry but less than the cache expiry
|
||||||
|
self.reactor.advance(120)
|
||||||
|
self.assertEqual(self.auth._token_cache.get("stale"), introspection_token) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# check that the next call causes another http request (which will fail because the token is technically expired
|
||||||
|
# but the important thing is we discard the token from the cache and try the network)
|
||||||
|
self.get_failure(
|
||||||
|
self.auth._introspect_token("stale"), InvalidClientTokenError # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
self.assertEqual(self.http_client.request.call_count, 2)
|
||||||
|
|
||||||
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
|
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
|
||||||
# We only generate a master key to simplify the test.
|
# We only generate a master key to simplify the test.
|
||||||
master_signing_key = generate_signing_key(device_id)
|
master_signing_key = generate_signing_key(device_id)
|
||||||
|
|
|
@ -514,7 +514,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||||
self.assertEqual(response.code, 200)
|
self.assertEqual(response.code, 200)
|
||||||
|
|
||||||
# Send the body
|
# Send the body
|
||||||
request.write('{ "a": 1 }'.encode("ascii"))
|
request.write(b'{ "a": 1 }')
|
||||||
request.finish()
|
request.finish()
|
||||||
|
|
||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
|
|
|
@ -757,7 +757,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
||||||
self.assertEqual(channel.json_body["creator"], user_id)
|
self.assertEqual(channel.json_body["creator"], user_id)
|
||||||
|
|
||||||
# Check room alias.
|
# Check room alias.
|
||||||
self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}")
|
self.assertEqual(room_alias, f"#foo-bar:{self.module_api.server_name}")
|
||||||
|
|
||||||
# Let's try a room with no alias.
|
# Let's try a room with no alias.
|
||||||
room_id, room_alias = self.get_success(
|
room_id, room_alias = self.get_success(
|
||||||
|
|
|
@ -116,7 +116,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
self.assertEqual(request.method, b"GET")
|
self.assertEqual(request.method, b"GET")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
request.path,
|
request.path,
|
||||||
f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"),
|
f"/_matrix/media/r0/download/{target}/{media_id}".encode(),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
|
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
|
||||||
|
|
|
@ -29,7 +29,16 @@ from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
|
||||||
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
|
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.media.filepath import MediaFilePaths
|
from synapse.media.filepath import MediaFilePaths
|
||||||
from synapse.rest.client import devices, login, logout, profile, register, room, sync
|
from synapse.rest.client import (
|
||||||
|
devices,
|
||||||
|
login,
|
||||||
|
logout,
|
||||||
|
profile,
|
||||||
|
register,
|
||||||
|
room,
|
||||||
|
sync,
|
||||||
|
user_directory,
|
||||||
|
)
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, UserID, create_requester
|
from synapse.types import JsonDict, UserID, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -1477,6 +1486,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
sync.register_servlets,
|
sync.register_servlets,
|
||||||
register.register_servlets,
|
register.register_servlets,
|
||||||
|
user_directory.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
@ -2464,6 +2474,105 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
# This key was removed intentionally. Ensure it is not accidentally re-included.
|
# This key was removed intentionally. Ensure it is not accidentally re-included.
|
||||||
self.assertNotIn("password_hash", channel.json_body)
|
self.assertNotIn("password_hash", channel.json_body)
|
||||||
|
|
||||||
|
def test_locked_user(self) -> None:
|
||||||
|
# User can sync
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/client/v3/sync",
|
||||||
|
access_token=self.other_user_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
|
|
||||||
|
# Lock user
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url_other_user,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
content={"locked": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# User is not authorized to sync anymore
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/client/v3/sync",
|
||||||
|
access_token=self.other_user_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(401, channel.code, msg=channel.json_body)
|
||||||
|
self.assertEqual(Codes.USER_LOCKED, channel.json_body["errcode"])
|
||||||
|
self.assertTrue(channel.json_body["soft_logout"])
|
||||||
|
|
||||||
|
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
|
||||||
|
def test_locked_user_not_in_user_dir(self) -> None:
|
||||||
|
# User is available in the user dir
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/v3/user_directory/search",
|
||||||
|
{"search_term": self.other_user},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
|
self.assertIn("results", channel.json_body)
|
||||||
|
self.assertEqual(1, len(channel.json_body["results"]))
|
||||||
|
|
||||||
|
# Lock user
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url_other_user,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
content={"locked": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# User is not available anymore in the user dir
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/v3/user_directory/search",
|
||||||
|
{"search_term": self.other_user},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
|
self.assertIn("results", channel.json_body)
|
||||||
|
self.assertEqual(0, len(channel.json_body["results"]))
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"user_directory": {
|
||||||
|
"enabled": True,
|
||||||
|
"search_all_users": True,
|
||||||
|
"show_locked_users": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_locked_user_in_user_dir_with_show_locked_users_option(self) -> None:
|
||||||
|
# User is available in the user dir
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/v3/user_directory/search",
|
||||||
|
{"search_term": self.other_user},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
|
self.assertIn("results", channel.json_body)
|
||||||
|
self.assertEqual(1, len(channel.json_body["results"]))
|
||||||
|
|
||||||
|
# Lock user
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url_other_user,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
content={"locked": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
# User is still available in the user dir
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/v3/user_directory/search",
|
||||||
|
{"search_term": self.other_user},
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
|
self.assertIn("results", channel.json_body)
|
||||||
|
self.assertEqual(1, len(channel.json_body["results"]))
|
||||||
|
|
||||||
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
|
@override_config({"user_directory": {"enabled": True, "search_all_users": True}})
|
||||||
def test_change_name_deactivate_user_user_directory(self) -> None:
|
def test_change_name_deactivate_user_user_directory(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError
|
||||||
from synapse.rest import admin, devices, room, sync
|
from synapse.rest import admin, devices, room, sync
|
||||||
from synapse.rest.client import account, keys, login, register
|
from synapse.rest.client import account, keys, login, register
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, create_requester
|
from synapse.types import JsonDict, UserID, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -282,6 +282,17 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
|
||||||
"<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
|
"<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"fallback_keys": {
|
||||||
|
"alg1:device1": "f4llb4ckk3y",
|
||||||
|
"signed_<algorithm>:<device_id>": {
|
||||||
|
"fallback": "true",
|
||||||
|
"key": "f4llb4ckk3y",
|
||||||
|
"signatures": {
|
||||||
|
"<user_id>": {"<algorithm>:<device_id>": "<key_base64>"}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"one_time_keys": {"alg1:k1": "0net1m3k3y"},
|
||||||
}
|
}
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
|
@ -312,6 +323,55 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.assertEqual(device_data, expected_device_data)
|
self.assertEqual(device_data, expected_device_data)
|
||||||
|
|
||||||
|
# test that the keys are correctly uploaded
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/r0/keys/query",
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
user: ["device1"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["device_keys"][user][device_id]["keys"],
|
||||||
|
content["device_keys"]["keys"],
|
||||||
|
)
|
||||||
|
# first claim should return the onetime key we uploaded
|
||||||
|
res = self.get_success(
|
||||||
|
self.hs.get_e2e_keys_handler().claim_one_time_keys(
|
||||||
|
{user: {device_id: {"alg1": 1}}},
|
||||||
|
UserID.from_string(user),
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
res,
|
||||||
|
{
|
||||||
|
"failures": {},
|
||||||
|
"one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# second claim should return fallback key
|
||||||
|
res2 = self.get_success(
|
||||||
|
self.hs.get_e2e_keys_handler().claim_one_time_keys(
|
||||||
|
{user: {device_id: {"alg1": 1}}},
|
||||||
|
UserID.from_string(user),
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
res2,
|
||||||
|
{
|
||||||
|
"failures": {},
|
||||||
|
"one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# create another device for the user
|
# create another device for the user
|
||||||
(
|
(
|
||||||
new_device_id,
|
new_device_id,
|
||||||
|
@ -348,10 +408,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
expected_content = {"body": "test_message"}
|
expected_content = {"body": "test_message"}
|
||||||
self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
|
self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
|
||||||
|
|
||||||
|
# fetch messages again and make sure that the message was not deleted
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
|
||||||
|
content={},
|
||||||
|
access_token=token,
|
||||||
|
shorthand=False,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
|
||||||
next_batch_token = channel.json_body.get("next_batch")
|
next_batch_token = channel.json_body.get("next_batch")
|
||||||
|
|
||||||
# fetch messages again and make sure that the message was deleted and we are returned an
|
# make sure fetching messages with next batch token works - there are no unfetched
|
||||||
# empty array
|
# messages so we should receive an empty array
|
||||||
content = {"next_batch": next_batch_token}
|
content = {"next_batch": next_batch_token}
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
|
|
|
@ -627,8 +627,8 @@ class RedactionsTestCase(HomeserverTestCase):
|
||||||
redact_event = timeline[-1]
|
redact_event = timeline[-1]
|
||||||
self.assertEqual(redact_event["type"], EventTypes.Redaction)
|
self.assertEqual(redact_event["type"], EventTypes.Redaction)
|
||||||
# The redacts key should be in the content and the redacts keys.
|
# The redacts key should be in the content and the redacts keys.
|
||||||
self.assertEquals(redact_event["content"]["redacts"], event_id)
|
self.assertEqual(redact_event["content"]["redacts"], event_id)
|
||||||
self.assertEquals(redact_event["redacts"], event_id)
|
self.assertEqual(redact_event["redacts"], event_id)
|
||||||
|
|
||||||
# But it isn't actually part of the event.
|
# But it isn't actually part of the event.
|
||||||
def get_event(txn: LoggingTransaction) -> JsonDict:
|
def get_event(txn: LoggingTransaction) -> JsonDict:
|
||||||
|
@ -642,10 +642,10 @@ class RedactionsTestCase(HomeserverTestCase):
|
||||||
event_json = self.get_success(
|
event_json = self.get_success(
|
||||||
main_datastore.db_pool.runInteraction("get_event", get_event)
|
main_datastore.db_pool.runInteraction("get_event", get_event)
|
||||||
)
|
)
|
||||||
self.assertEquals(event_json["type"], EventTypes.Redaction)
|
self.assertEqual(event_json["type"], EventTypes.Redaction)
|
||||||
if expect_content:
|
if expect_content:
|
||||||
self.assertNotIn("redacts", event_json)
|
self.assertNotIn("redacts", event_json)
|
||||||
self.assertEquals(event_json["content"]["redacts"], event_id)
|
self.assertEqual(event_json["content"]["redacts"], event_id)
|
||||||
else:
|
else:
|
||||||
self.assertEquals(event_json["redacts"], event_id)
|
self.assertEqual(event_json["redacts"], event_id)
|
||||||
self.assertNotIn("redacts", event_json["content"])
|
self.assertNotIn("redacts", event_json["content"])
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue