Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
commit
d1b33ae118
|
@ -129,7 +129,7 @@ body:
|
|||
attributes:
|
||||
label: Relevant log output
|
||||
description: |
|
||||
Please copy and paste any relevant log output, ideally at INFO or DEBUG log level.
|
||||
Please copy and paste any relevant log output as text (not images), ideally at INFO or DEBUG log level.
|
||||
This will be automatically formatted into code, so there is no need for backticks (`\``).
|
||||
|
||||
Please be careful to remove any personal or private data.
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
Synapse 1.82.0 (2023-04-25)
|
||||
===========================
|
||||
|
||||
No significant changes since 1.82.0rc1.
|
||||
|
||||
|
||||
Synapse 1.82.0rc1 (2023-04-18)
|
||||
==============================
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Speedup tests by caching HomeServerConfig instances.
|
|
@ -0,0 +1 @@
|
|||
Experimental support for MSC3970: Scope transaction IDs to devices.
|
|
@ -0,0 +1 @@
|
|||
Add denormalised event stream ordering column to membership state tables for future use. Contributed by Nick @ Beeper (@fizzadar).
|
|
@ -0,0 +1 @@
|
|||
Fix a long-standing bug where cached key results which were directly fetched would not be properly re-used.
|
|
@ -0,0 +1 @@
|
|||
Always use multi-user device resync replication endpoints.
|
|
@ -0,0 +1 @@
|
|||
Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request.
|
|
@ -0,0 +1 @@
|
|||
Improve type hints.
|
|
@ -0,0 +1 @@
|
|||
Update the check_schema_delta script to account for when the schema version has been bumped locally.
|
|
@ -0,0 +1 @@
|
|||
Bump types-pyyaml from 6.0.12.8 to 6.0.12.9.
|
|
@ -0,0 +1 @@
|
|||
Bump pyasn1-modules from 0.2.8 to 0.3.0.
|
|
@ -0,0 +1 @@
|
|||
Bump cryptography from 40.0.1 to 40.0.2.
|
|
@ -0,0 +1 @@
|
|||
Bump types-netaddr from 0.8.0.7 to 0.8.0.8.
|
|
@ -0,0 +1 @@
|
|||
Bump types-jsonschema from 4.17.0.6 to 4.17.0.7.
|
|
@ -0,0 +1 @@
|
|||
Ask bug reporters to provide logs as text.
|
|
@ -1,3 +1,9 @@
|
|||
matrix-synapse-py3 (1.82.0) stable; urgency=medium
|
||||
|
||||
* New Synapse release 1.82.0.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Tue, 25 Apr 2023 11:56:06 +0100
|
||||
|
||||
matrix-synapse-py3 (1.82.0~rc1) stable; urgency=medium
|
||||
|
||||
* New Synapse release 1.82.0rc1.
|
||||
|
|
6
mypy.ini
6
mypy.ini
|
@ -33,12 +33,6 @@ exclude = (?x)
|
|||
|synapse/storage/schema/
|
||||
)$
|
||||
|
||||
[mypy-synapse.federation.transport.client]
|
||||
disallow_untyped_defs = False
|
||||
|
||||
[mypy-synapse.http.matrixfederationclient]
|
||||
disallow_untyped_defs = False
|
||||
|
||||
[mypy-synapse.metrics._reactor_metrics]
|
||||
disallow_untyped_defs = False
|
||||
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
|
||||
|
|
|
@ -481,31 +481,31 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
version = "40.0.1"
|
||||
version = "40.0.2"
|
||||
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "cryptography-40.0.1-cp36-abi3-macosx_10_12_universal2.whl", hash = "sha256:918cb89086c7d98b1b86b9fdb70c712e5a9325ba6f7d7cfb509e784e0cfc6917"},
|
||||
{file = "cryptography-40.0.1-cp36-abi3-macosx_10_12_x86_64.whl", hash = "sha256:9618a87212cb5200500e304e43691111570e1f10ec3f35569fdfcd17e28fd797"},
|
||||
{file = "cryptography-40.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a4805a4ca729d65570a1b7cac84eac1e431085d40387b7d3bbaa47e39890b88"},
|
||||
{file = "cryptography-40.0.1-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63dac2d25c47f12a7b8aa60e528bfb3c51c5a6c5a9f7c86987909c6c79765554"},
|
||||
{file = "cryptography-40.0.1-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:0a4e3406cfed6b1f6d6e87ed243363652b2586b2d917b0609ca4f97072994405"},
|
||||
{file = "cryptography-40.0.1-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1e0af458515d5e4028aad75f3bb3fe7a31e46ad920648cd59b64d3da842e4356"},
|
||||
{file = "cryptography-40.0.1-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:d8aa3609d337ad85e4eb9bb0f8bcf6e4409bfb86e706efa9a027912169e89122"},
|
||||
{file = "cryptography-40.0.1-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:cf91e428c51ef692b82ce786583e214f58392399cf65c341bc7301d096fa3ba2"},
|
||||
{file = "cryptography-40.0.1-cp36-abi3-win32.whl", hash = "sha256:650883cc064297ef3676b1db1b7b1df6081794c4ada96fa457253c4cc40f97db"},
|
||||
{file = "cryptography-40.0.1-cp36-abi3-win_amd64.whl", hash = "sha256:a805a7bce4a77d51696410005b3e85ae2839bad9aa38894afc0aa99d8e0c3160"},
|
||||
{file = "cryptography-40.0.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:cd033d74067d8928ef00a6b1327c8ea0452523967ca4463666eeba65ca350d4c"},
|
||||
{file = "cryptography-40.0.1-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d36bbeb99704aabefdca5aee4eba04455d7a27ceabd16f3b3ba9bdcc31da86c4"},
|
||||
{file = "cryptography-40.0.1-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:32057d3d0ab7d4453778367ca43e99ddb711770477c4f072a51b3ca69602780a"},
|
||||
{file = "cryptography-40.0.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:f5d7b79fa56bc29580faafc2ff736ce05ba31feaa9d4735048b0de7d9ceb2b94"},
|
||||
{file = "cryptography-40.0.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7c872413353c70e0263a9368c4993710070e70ab3e5318d85510cc91cce77e7c"},
|
||||
{file = "cryptography-40.0.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:28d63d75bf7ae4045b10de5413fb1d6338616e79015999ad9cf6fc538f772d41"},
|
||||
{file = "cryptography-40.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6f2bbd72f717ce33100e6467572abaedc61f1acb87b8d546001328d7f466b778"},
|
||||
{file = "cryptography-40.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cc3a621076d824d75ab1e1e530e66e7e8564e357dd723f2533225d40fe35c60c"},
|
||||
{file = "cryptography-40.0.1.tar.gz", hash = "sha256:2803f2f8b1e95f614419926c7e6f55d828afc614ca5ed61543877ae668cc3472"},
|
||||
{file = "cryptography-40.0.2-cp36-abi3-macosx_10_12_universal2.whl", hash = "sha256:8f79b5ff5ad9d3218afb1e7e20ea74da5f76943ee5edb7f76e56ec5161ec782b"},
|
||||
{file = "cryptography-40.0.2-cp36-abi3-macosx_10_12_x86_64.whl", hash = "sha256:05dc219433b14046c476f6f09d7636b92a1c3e5808b9a6536adf4932b3b2c440"},
|
||||
{file = "cryptography-40.0.2-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4df2af28d7bedc84fe45bd49bc35d710aede676e2a4cb7fc6d103a2adc8afe4d"},
|
||||
{file = "cryptography-40.0.2-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dcca15d3a19a66e63662dc8d30f8036b07be851a8680eda92d079868f106288"},
|
||||
{file = "cryptography-40.0.2-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:a04386fb7bc85fab9cd51b6308633a3c271e3d0d3eae917eebab2fac6219b6d2"},
|
||||
{file = "cryptography-40.0.2-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:adc0d980fd2760c9e5de537c28935cc32b9353baaf28e0814df417619c6c8c3b"},
|
||||
{file = "cryptography-40.0.2-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:d5a1bd0e9e2031465761dfa920c16b0065ad77321d8a8c1f5ee331021fda65e9"},
|
||||
{file = "cryptography-40.0.2-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:a95f4802d49faa6a674242e25bfeea6fc2acd915b5e5e29ac90a32b1139cae1c"},
|
||||
{file = "cryptography-40.0.2-cp36-abi3-win32.whl", hash = "sha256:aecbb1592b0188e030cb01f82d12556cf72e218280f621deed7d806afd2113f9"},
|
||||
{file = "cryptography-40.0.2-cp36-abi3-win_amd64.whl", hash = "sha256:b12794f01d4cacfbd3177b9042198f3af1c856eedd0a98f10f141385c809a14b"},
|
||||
{file = "cryptography-40.0.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:142bae539ef28a1c76794cca7f49729e7c54423f615cfd9b0b1fa90ebe53244b"},
|
||||
{file = "cryptography-40.0.2-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:956ba8701b4ffe91ba59665ed170a2ebbdc6fc0e40de5f6059195d9f2b33ca0e"},
|
||||
{file = "cryptography-40.0.2-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:4f01c9863da784558165f5d4d916093737a75203a5c5286fde60e503e4276c7a"},
|
||||
{file = "cryptography-40.0.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:3daf9b114213f8ba460b829a02896789751626a2a4e7a43a28ee77c04b5e4958"},
|
||||
{file = "cryptography-40.0.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48f388d0d153350f378c7f7b41497a54ff1513c816bcbbcafe5b829e59b9ce5b"},
|
||||
{file = "cryptography-40.0.2-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c0764e72b36a3dc065c155e5b22f93df465da9c39af65516fe04ed3c68c92636"},
|
||||
{file = "cryptography-40.0.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:cbaba590180cba88cb99a5f76f90808a624f18b169b90a4abb40c1fd8c19420e"},
|
||||
{file = "cryptography-40.0.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7a38250f433cd41df7fcb763caa3ee9362777fdb4dc642b9a349721d2bf47404"},
|
||||
{file = "cryptography-40.0.2.tar.gz", hash = "sha256:c33c0d32b8594fa647d2e01dbccc303478e16fdd7cf98652d5b3ed11aa5e5c99"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -1860,18 +1860,18 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "pyasn1-modules"
|
||||
version = "0.2.8"
|
||||
description = "A collection of ASN.1-based protocols modules."
|
||||
version = "0.3.0"
|
||||
description = "A collection of ASN.1-based protocols modules"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
|
||||
files = [
|
||||
{file = "pyasn1-modules-0.2.8.tar.gz", hash = "sha256:905f84c712230b2c592c19470d3ca8d552de726050d1d1716282a1f6146be65e"},
|
||||
{file = "pyasn1_modules-0.2.8-py2.py3-none-any.whl", hash = "sha256:a50b808ffeb97cb3601dd25981f6b016cbb3d31fbf57a8b8a87428e6158d0c74"},
|
||||
{file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"},
|
||||
{file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pyasn1 = ">=0.4.6,<0.5.0"
|
||||
pyasn1 = ">=0.4.6,<0.6.0"
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
|
@ -3022,26 +3022,26 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "types-jsonschema"
|
||||
version = "4.17.0.6"
|
||||
version = "4.17.0.7"
|
||||
description = "Typing stubs for jsonschema"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "types-jsonschema-4.17.0.6.tar.gz", hash = "sha256:e9b15e34b4f2fd5587bd68530fa0eb2a17c73ead212f4471d71eea032d231c46"},
|
||||
{file = "types_jsonschema-4.17.0.6-py3-none-any.whl", hash = "sha256:ecef99bc64848f3798ad18922dfb2b40da25f17796fafcee50da984a21c5d6e6"},
|
||||
{file = "types-jsonschema-4.17.0.7.tar.gz", hash = "sha256:130e57c5f1ca755f95775d0822ad7a3907294e1461306af54baf804f317fd54c"},
|
||||
{file = "types_jsonschema-4.17.0.7-py3-none-any.whl", hash = "sha256:e129b52be6df841d97a98f087631dd558f7812eb91ff7b733c3301bd2446271b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-netaddr"
|
||||
version = "0.8.0.7"
|
||||
version = "0.8.0.8"
|
||||
description = "Typing stubs for netaddr"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "types-netaddr-0.8.0.7.tar.gz", hash = "sha256:3362864fa0258782d449b91707f37e55f62290b4f438974a08758b498169e109"},
|
||||
{file = "types_netaddr-0.8.0.7-py3-none-any.whl", hash = "sha256:a540cdfb2f858a0509ce5a4e4fcc80ef11b19f10a2473e48d32217af517818c0"},
|
||||
{file = "types-netaddr-0.8.0.8.tar.gz", hash = "sha256:db7e8cd16b1244e7c4541edd0df99d1039fc05fd5387c21840f0b958fc52aabc"},
|
||||
{file = "types_netaddr-0.8.0.8-py3-none-any.whl", hash = "sha256:6741b3824e2ec3f7a74842b394439b71107c7675f8ae42bb2b5e7a8ebfe8cf18"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -3097,14 +3097,14 @@ cryptography = ">=35.0.0"
|
|||
|
||||
[[package]]
|
||||
name = "types-pyyaml"
|
||||
version = "6.0.12.8"
|
||||
version = "6.0.12.9"
|
||||
description = "Typing stubs for PyYAML"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "types-PyYAML-6.0.12.8.tar.gz", hash = "sha256:19304869a89d49af00be681e7b267414df213f4eb89634c4495fa62e8f942b9f"},
|
||||
{file = "types_PyYAML-6.0.12.8-py3-none-any.whl", hash = "sha256:5314a4b2580999b2ea06b2e5f9a7763d860d6e09cdf21c0e9561daa9cbd60178"},
|
||||
{file = "types-PyYAML-6.0.12.9.tar.gz", hash = "sha256:c51b1bd6d99ddf0aa2884a7a328810ebf70a4262c292195d3f4f9a0005f9eeb6"},
|
||||
{file = "types_PyYAML-6.0.12.9-py3-none-any.whl", hash = "sha256:5aed5aa66bd2d2e158f75dda22b059570ede988559f030cf294871d3b647e3e8"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
@ -89,7 +89,7 @@ manifest-path = "rust/Cargo.toml"
|
|||
|
||||
[tool.poetry]
|
||||
name = "matrix-synapse"
|
||||
version = "1.82.0rc1"
|
||||
version = "1.82.0"
|
||||
description = "Homeserver for the Matrix decentralised comms protocol"
|
||||
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
|
||||
license = "Apache-2.0"
|
||||
|
|
|
@ -40,10 +40,32 @@ def main(force_colors: bool) -> None:
|
|||
exec(r, locals)
|
||||
current_schema_version = locals["SCHEMA_VERSION"]
|
||||
|
||||
click.secho(f"Current schema version: {current_schema_version}")
|
||||
|
||||
diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None)
|
||||
|
||||
# Get the schema version of the local file to check against current schema on develop
|
||||
with open("synapse/storage/schema/__init__.py", "r") as file:
|
||||
local_schema = file.read()
|
||||
new_locals: Dict[str, Any] = {}
|
||||
exec(local_schema, new_locals)
|
||||
local_schema_version = new_locals["SCHEMA_VERSION"]
|
||||
|
||||
if local_schema_version != current_schema_version:
|
||||
# local schema version must be +/-1 the current schema version on develop
|
||||
if abs(local_schema_version - current_schema_version) != 1:
|
||||
click.secho(
|
||||
"The proposed schema version has diverged more than one version from develop, please fix!",
|
||||
fg="red",
|
||||
bold=True,
|
||||
color=force_colors,
|
||||
)
|
||||
click.get_current_context().exit(1)
|
||||
|
||||
# right, we've changed the schema version within the allowable tolerance so
|
||||
# let's now use the local version as the canonical version
|
||||
current_schema_version = local_schema_version
|
||||
|
||||
click.secho(f"Current schema version: {current_schema_version}")
|
||||
|
||||
seen_deltas = False
|
||||
bad_files = []
|
||||
for diff in diffs:
|
||||
|
|
|
@ -191,3 +191,6 @@ class ExperimentalConfig(Config):
|
|||
|
||||
# MSC2659: Application service ping endpoint
|
||||
self.msc2659_enabled = experimental.get("msc2659_enabled", False)
|
||||
|
||||
# MSC3970: Scope transaction IDs to devices
|
||||
self.msc3970_enabled = experimental.get("msc3970_enabled", False)
|
||||
|
|
|
@ -150,18 +150,19 @@ class Keyring:
|
|||
def __init__(
|
||||
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
|
||||
):
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
if key_fetchers is None:
|
||||
key_fetchers = (
|
||||
# Fetch keys from the database.
|
||||
StoreKeyFetcher(hs),
|
||||
# Fetch keys from a configured Perspectives server.
|
||||
PerspectivesKeyFetcher(hs),
|
||||
# Fetch keys from the origin server directly.
|
||||
ServerKeyFetcher(hs),
|
||||
)
|
||||
self._key_fetchers = key_fetchers
|
||||
# Always fetch keys from the database.
|
||||
mutable_key_fetchers: List[KeyFetcher] = [StoreKeyFetcher(hs)]
|
||||
# Fetch keys from configured trusted key servers, if any exist.
|
||||
key_servers = hs.config.key.key_servers
|
||||
if key_servers:
|
||||
mutable_key_fetchers.append(PerspectivesKeyFetcher(hs))
|
||||
# Finally, fetch keys from the origin server directly.
|
||||
mutable_key_fetchers.append(ServerKeyFetcher(hs))
|
||||
|
||||
self._key_fetchers: Iterable[KeyFetcher] = tuple(mutable_key_fetchers)
|
||||
else:
|
||||
self._key_fetchers = key_fetchers
|
||||
|
||||
self._fetch_keys_queue: BatchingQueue[
|
||||
_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
|
||||
|
@ -510,7 +511,7 @@ class StoreKeyFetcher(KeyFetcher):
|
|||
for key_id in queue_value.key_ids
|
||||
)
|
||||
|
||||
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
|
||||
res = await self.store.get_server_keys_json(key_ids_to_fetch)
|
||||
keys: Dict[str, Dict[str, FetchKeyResult]] = {}
|
||||
for (server_name, key_id), key in res.items():
|
||||
keys.setdefault(server_name, {})[key_id] = key
|
||||
|
@ -522,7 +523,6 @@ class BaseV2KeyFetcher(KeyFetcher):
|
|||
super().__init__(hs)
|
||||
|
||||
self.store = hs.get_datastores().main
|
||||
self.config = hs.config
|
||||
|
||||
async def process_v2_response(
|
||||
self, from_server: str, response_json: JsonDict, time_added_ms: int
|
||||
|
@ -626,7 +626,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
super().__init__(hs)
|
||||
self.clock = hs.get_clock()
|
||||
self.client = hs.get_federation_http_client()
|
||||
self.key_servers = self.config.key.key_servers
|
||||
self.key_servers = hs.config.key.key_servers
|
||||
|
||||
async def _fetch_keys(
|
||||
self, keys_to_fetch: List[_FetchKeyRequest]
|
||||
|
@ -775,7 +775,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
|
|||
|
||||
keys.setdefault(server_name, {}).update(processed_response)
|
||||
|
||||
await self.store.store_server_verify_keys(
|
||||
await self.store.store_server_signature_keys(
|
||||
perspective_name, time_now_ms, added_keys
|
||||
)
|
||||
|
||||
|
|
|
@ -198,10 +198,17 @@ class _EventInternalMetadata:
|
|||
soft_failed: DictProperty[bool] = DictProperty("soft_failed")
|
||||
proactively_send: DictProperty[bool] = DictProperty("proactively_send")
|
||||
redacted: DictProperty[bool] = DictProperty("redacted")
|
||||
txn_id: DictProperty[str] = DictProperty("txn_id")
|
||||
token_id: DictProperty[int] = DictProperty("token_id")
|
||||
historical: DictProperty[bool] = DictProperty("historical")
|
||||
|
||||
txn_id: DictProperty[str] = DictProperty("txn_id")
|
||||
"""The transaction ID, if it was set when the event was created."""
|
||||
|
||||
token_id: DictProperty[int] = DictProperty("token_id")
|
||||
"""The access token ID of the user who sent this event, if any."""
|
||||
|
||||
device_id: DictProperty[str] = DictProperty("device_id")
|
||||
"""The device ID of the user who sent this event, if any."""
|
||||
|
||||
# XXX: These are set by StreamWorkerStore._set_before_and_after.
|
||||
# I'm pretty sure that these are never persisted to the database, so shouldn't
|
||||
# be here
|
||||
|
|
|
@ -339,6 +339,7 @@ def serialize_event(
|
|||
time_now_ms: int,
|
||||
*,
|
||||
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
|
||||
msc3970_enabled: bool = False,
|
||||
) -> JsonDict:
|
||||
"""Serialize event for clients
|
||||
|
||||
|
@ -346,6 +347,8 @@ def serialize_event(
|
|||
e
|
||||
time_now_ms
|
||||
config: Event serialization config
|
||||
msc3970_enabled: Whether MSC3970 is enabled. It changes whether we should
|
||||
include the `transaction_id` in the event's `unsigned` section.
|
||||
|
||||
Returns:
|
||||
The serialized event dictionary.
|
||||
|
@ -368,27 +371,43 @@ def serialize_event(
|
|||
|
||||
if "redacted_because" in e.unsigned:
|
||||
d["unsigned"]["redacted_because"] = serialize_event(
|
||||
e.unsigned["redacted_because"], time_now_ms, config=config
|
||||
e.unsigned["redacted_because"],
|
||||
time_now_ms,
|
||||
config=config,
|
||||
msc3970_enabled=msc3970_enabled,
|
||||
)
|
||||
|
||||
# If we have a txn_id saved in the internal_metadata, we should include it in the
|
||||
# unsigned section of the event if it was sent by the same session as the one
|
||||
# requesting the event.
|
||||
# There is a special case for guests, because they only have one access token
|
||||
# without associated access_token_id, so we always include the txn_id for events
|
||||
# they sent.
|
||||
txn_id = getattr(e.internal_metadata, "txn_id", None)
|
||||
txn_id: Optional[str] = getattr(e.internal_metadata, "txn_id", None)
|
||||
if txn_id is not None and config.requester is not None:
|
||||
event_token_id = getattr(e.internal_metadata, "token_id", None)
|
||||
if config.requester.user.to_string() == e.sender and (
|
||||
(
|
||||
event_token_id is not None
|
||||
and config.requester.access_token_id is not None
|
||||
and event_token_id == config.requester.access_token_id
|
||||
# For the MSC3970 rules to be applied, we *need* to have the device ID in the
|
||||
# event internal metadata. Since we were not recording them before, if it hasn't
|
||||
# been recorded, we fallback to the old behaviour.
|
||||
event_device_id: Optional[str] = getattr(e.internal_metadata, "device_id", None)
|
||||
if msc3970_enabled and event_device_id is not None:
|
||||
if event_device_id == config.requester.device_id:
|
||||
d["unsigned"]["transaction_id"] = txn_id
|
||||
|
||||
else:
|
||||
# The pre-MSC3970 behaviour is to only include the transaction ID if the
|
||||
# event was sent from the same access token. For regular users, we can use
|
||||
# the access token ID to determine this. For guests, we can't, but since
|
||||
# each guest only has one access token, we can just check that the event was
|
||||
# sent by the same user as the one requesting the event.
|
||||
event_token_id: Optional[int] = getattr(
|
||||
e.internal_metadata, "token_id", None
|
||||
)
|
||||
or config.requester.is_guest
|
||||
):
|
||||
d["unsigned"]["transaction_id"] = txn_id
|
||||
if config.requester.user.to_string() == e.sender and (
|
||||
(
|
||||
event_token_id is not None
|
||||
and config.requester.access_token_id is not None
|
||||
and event_token_id == config.requester.access_token_id
|
||||
)
|
||||
or config.requester.is_guest
|
||||
):
|
||||
d["unsigned"]["transaction_id"] = txn_id
|
||||
|
||||
# invite_room_state and knock_room_state are a list of stripped room state events
|
||||
# that are meant to provide metadata about a room to an invitee/knocker. They are
|
||||
|
@ -419,6 +438,9 @@ class EventClientSerializer:
|
|||
clients.
|
||||
"""
|
||||
|
||||
def __init__(self, *, msc3970_enabled: bool = False):
|
||||
self._msc3970_enabled = msc3970_enabled
|
||||
|
||||
def serialize_event(
|
||||
self,
|
||||
event: Union[JsonDict, EventBase],
|
||||
|
@ -443,7 +465,9 @@ class EventClientSerializer:
|
|||
if not isinstance(event, EventBase):
|
||||
return event
|
||||
|
||||
serialized_event = serialize_event(event, time_now, config=config)
|
||||
serialized_event = serialize_event(
|
||||
event, time_now, config=config, msc3970_enabled=self._msc3970_enabled
|
||||
)
|
||||
|
||||
# Check if there are any bundled aggregations to include with the event.
|
||||
if bundle_aggregations:
|
||||
|
@ -501,7 +525,9 @@ class EventClientSerializer:
|
|||
# `sender` of the edit; however MSC3925 proposes extending it to the whole
|
||||
# of the edit, which is what we do here.
|
||||
serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event(
|
||||
event_aggregations.replace, time_now, config=config
|
||||
event_aggregations.replace,
|
||||
time_now,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Include any threaded replies to this event.
|
||||
|
|
|
@ -280,15 +280,11 @@ class FederationClient(FederationBase):
|
|||
logger.debug("backfill transaction_data=%r", transaction_data)
|
||||
|
||||
if not isinstance(transaction_data, dict):
|
||||
# TODO we probably want an exception type specific to federation
|
||||
# client validation.
|
||||
raise TypeError("Backfill transaction_data is not a dict.")
|
||||
raise InvalidResponseError("Backfill transaction_data is not a dict.")
|
||||
|
||||
transaction_data_pdus = transaction_data.get("pdus")
|
||||
if not isinstance(transaction_data_pdus, list):
|
||||
# TODO we probably want an exception type specific to federation
|
||||
# client validation.
|
||||
raise TypeError("transaction_data.pdus is not a list.")
|
||||
raise InvalidResponseError("transaction_data.pdus is not a list.")
|
||||
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
|
|
|
@ -1005,7 +1005,7 @@ class FederationServer(FederationBase):
|
|||
|
||||
@trace
|
||||
async def on_claim_client_keys(
|
||||
self, origin: str, content: JsonDict
|
||||
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
|
||||
) -> Dict[str, Any]:
|
||||
query = []
|
||||
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
||||
|
@ -1013,7 +1013,9 @@ class FederationServer(FederationBase):
|
|||
query.append((user_id, device_id, algorithm))
|
||||
|
||||
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
||||
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
|
||||
results = await self._e2e_keys_handler.claim_local_one_time_keys(
|
||||
query, always_include_fallback_keys=always_include_fallback_keys
|
||||
)
|
||||
|
||||
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
for result in results:
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import logging
|
||||
import urllib
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
|
@ -42,18 +43,21 @@ from synapse.api.urls import (
|
|||
)
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.federation.units import Transaction
|
||||
from synapse.http.matrixfederationclient import ByteParser
|
||||
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
|
||||
from synapse.http.types import QueryParams
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import ExceptionBundle
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransportLayerClient:
|
||||
"""Sends federation HTTP requests to other servers"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.server_name = hs.hostname
|
||||
self.client = hs.get_federation_http_client()
|
||||
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
|
||||
|
@ -133,7 +137,7 @@ class TransportLayerClient:
|
|||
|
||||
async def backfill(
|
||||
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
|
||||
) -> Optional[JsonDict]:
|
||||
) -> Optional[Union[JsonDict, list]]:
|
||||
"""Requests `limit` previous PDUs in a given context before list of
|
||||
PDUs.
|
||||
|
||||
|
@ -388,6 +392,7 @@ class TransportLayerClient:
|
|||
# server was just having a momentary blip, the room will be out of
|
||||
# sync.
|
||||
ignore_backoff=True,
|
||||
parser=LegacyJsonSendParser(),
|
||||
)
|
||||
|
||||
async def send_leave_v2(
|
||||
|
@ -445,7 +450,11 @@ class TransportLayerClient:
|
|||
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
|
||||
|
||||
return await self.client.put_json(
|
||||
destination=destination, path=path, data=content, ignore_backoff=True
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
parser=LegacyJsonSendParser(),
|
||||
)
|
||||
|
||||
async def send_invite_v2(
|
||||
|
|
|
@ -25,6 +25,7 @@ from synapse.federation.transport.server._base import (
|
|||
from synapse.federation.transport.server.federation import (
|
||||
FEDERATION_SERVLET_CLASSES,
|
||||
FederationAccountStatusServlet,
|
||||
FederationUnstableClientKeysClaimServlet,
|
||||
)
|
||||
from synapse.http.server import HttpServer, JsonResource
|
||||
from synapse.http.servlet import (
|
||||
|
@ -298,6 +299,11 @@ def register_servlets(
|
|||
and not hs.config.experimental.msc3720_enabled
|
||||
):
|
||||
continue
|
||||
if (
|
||||
servletclass == FederationUnstableClientKeysClaimServlet
|
||||
and not hs.config.experimental.msc3983_appservice_otk_claims
|
||||
):
|
||||
continue
|
||||
|
||||
servletclass(
|
||||
hs=hs,
|
||||
|
|
|
@ -577,7 +577,28 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
|
|||
async def on_POST(
|
||||
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
response = await self.handler.on_claim_client_keys(origin, content)
|
||||
response = await self.handler.on_claim_client_keys(
|
||||
origin, content, always_include_fallback_keys=False
|
||||
)
|
||||
return 200, response
|
||||
|
||||
|
||||
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
|
||||
"""
|
||||
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
|
||||
always includes fallback keys in the response.
|
||||
"""
|
||||
|
||||
PREFIX = FEDERATION_UNSTABLE_PREFIX
|
||||
PATH = "/user/keys/claim"
|
||||
CATEGORY = "Federation requests"
|
||||
|
||||
async def on_POST(
|
||||
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
response = await self.handler.on_claim_client_keys(
|
||||
origin, content, always_include_fallback_keys=True
|
||||
)
|
||||
return 200, response
|
||||
|
||||
|
||||
|
|
|
@ -842,9 +842,7 @@ class ApplicationServicesHandler:
|
|||
|
||||
async def claim_e2e_one_time_keys(
|
||||
self, query: Iterable[Tuple[str, str, str]]
|
||||
) -> Tuple[
|
||||
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
|
||||
]:
|
||||
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
|
||||
"""Claim one time keys from application services.
|
||||
|
||||
Users which are exclusively owned by an application service are sent a
|
||||
|
@ -856,7 +854,7 @@ class ApplicationServicesHandler:
|
|||
|
||||
Returns:
|
||||
A tuple of:
|
||||
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||
|
||||
A copy of the input which has not been fulfilled (either because
|
||||
they are not appservice users or the appservice does not support
|
||||
|
@ -897,12 +895,11 @@ class ApplicationServicesHandler:
|
|||
)
|
||||
|
||||
# Patch together the results -- they are all independent (since they
|
||||
# require exclusive control over the users). They get returned as a list
|
||||
# and the caller combines them.
|
||||
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
|
||||
# require exclusive control over the users, which is the outermost key).
|
||||
claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
for success, result in results:
|
||||
if success:
|
||||
claimed_keys.append(result[0])
|
||||
claimed_keys.update(result[0])
|
||||
missing.extend(result[1])
|
||||
|
||||
return claimed_keys, missing
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
|
@ -921,12 +920,8 @@ class DeviceListWorkerUpdater:
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
from synapse.replication.http.devices import (
|
||||
ReplicationMultiUserDevicesResyncRestServlet,
|
||||
ReplicationUserDevicesResyncRestServlet,
|
||||
)
|
||||
|
||||
self._user_device_resync_client = (
|
||||
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
||||
)
|
||||
self._multi_user_device_resync_client = (
|
||||
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
|
||||
)
|
||||
|
@ -948,37 +943,7 @@ class DeviceListWorkerUpdater:
|
|||
# Shortcut empty requests
|
||||
return {}
|
||||
|
||||
try:
|
||||
return await self._multi_user_device_resync_client(user_ids=user_ids)
|
||||
except SynapseError as err:
|
||||
if not (
|
||||
err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED
|
||||
):
|
||||
raise
|
||||
|
||||
# Fall back to single requests
|
||||
result: Dict[str, Optional[JsonDict]] = {}
|
||||
for user_id in user_ids:
|
||||
result[user_id] = await self._user_device_resync_client(user_id=user_id)
|
||||
return result
|
||||
|
||||
async def user_device_resync(
|
||||
self, user_id: str, mark_failed_as_stale: bool = True
|
||||
) -> Optional[JsonDict]:
|
||||
"""Fetches all devices for a user and updates the device cache with them.
|
||||
|
||||
Args:
|
||||
user_id: The user's id whose device_list will be updated.
|
||||
mark_failed_as_stale: Whether to mark the user's device list as stale
|
||||
if the attempt to resync failed.
|
||||
Returns:
|
||||
A dict with device info as under the "devices" in the result of this
|
||||
request:
|
||||
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
|
||||
None when we weren't able to fetch the device info for some reason,
|
||||
e.g. due to a connection problem.
|
||||
"""
|
||||
return (await self.multi_user_device_resync([user_id]))[user_id]
|
||||
return await self._multi_user_device_resync_client(user_ids=user_ids)
|
||||
|
||||
|
||||
class DeviceListUpdater(DeviceListWorkerUpdater):
|
||||
|
@ -1131,7 +1096,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
|
|||
)
|
||||
|
||||
if resync:
|
||||
await self.user_device_resync(user_id)
|
||||
await self.multi_user_device_resync([user_id])
|
||||
else:
|
||||
# Simply update the single device, since we know that is the only
|
||||
# change (because of the single prev_id matching the current cache)
|
||||
|
@ -1198,10 +1163,9 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
|
|||
for user_id in need_resync:
|
||||
try:
|
||||
# Try to resync the current user's devices list.
|
||||
result = await self.user_device_resync(
|
||||
user_id=user_id,
|
||||
mark_failed_as_stale=False,
|
||||
)
|
||||
result = (await self.multi_user_device_resync([user_id], False))[
|
||||
user_id
|
||||
]
|
||||
|
||||
# user_device_resync only returns a result if it managed to
|
||||
# successfully resync and update the database. Updating the table
|
||||
|
@ -1260,18 +1224,6 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
|
|||
|
||||
return result
|
||||
|
||||
async def user_device_resync(
|
||||
self, user_id: str, mark_failed_as_stale: bool = True
|
||||
) -> Optional[JsonDict]:
|
||||
result, failed = await self._user_device_resync_returning_failed(user_id)
|
||||
|
||||
if failed and mark_failed_as_stale:
|
||||
# Mark the remote user's device list as stale so we know we need to retry
|
||||
# it later.
|
||||
await self.store.mark_remote_users_device_caches_as_stale((user_id,))
|
||||
|
||||
return result
|
||||
|
||||
async def _user_device_resync_returning_failed(
|
||||
self, user_id: str
|
||||
) -> Tuple[Optional[JsonDict], bool]:
|
||||
|
|
|
@ -25,7 +25,9 @@ from synapse.logging.opentracing import (
|
|||
log_kv,
|
||||
set_tag,
|
||||
)
|
||||
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
|
||||
from synapse.replication.http.devices import (
|
||||
ReplicationMultiUserDevicesResyncRestServlet,
|
||||
)
|
||||
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.stringutils import random_string
|
||||
|
@ -71,12 +73,12 @@ class DeviceMessageHandler:
|
|||
# sync. We do all device list resyncing on the master instance, so if
|
||||
# we're on a worker we hit the device resync replication API.
|
||||
if hs.config.worker.worker_app is None:
|
||||
self._user_device_resync = (
|
||||
hs.get_device_handler().device_list_updater.user_device_resync
|
||||
self._multi_user_device_resync = (
|
||||
hs.get_device_handler().device_list_updater.multi_user_device_resync
|
||||
)
|
||||
else:
|
||||
self._user_device_resync = (
|
||||
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
||||
self._multi_user_device_resync = (
|
||||
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
|
||||
)
|
||||
|
||||
# a rate limiter for room key requests. The keys are
|
||||
|
@ -198,7 +200,7 @@ class DeviceMessageHandler:
|
|||
await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))
|
||||
|
||||
# Immediately attempt a resync in the background
|
||||
run_in_background(self._user_device_resync, user_id=sender_user_id)
|
||||
run_in_background(self._multi_user_device_resync, user_ids=[sender_user_id])
|
||||
|
||||
async def send_device_message(
|
||||
self,
|
||||
|
|
|
@ -563,7 +563,9 @@ class E2eKeysHandler:
|
|||
return ret
|
||||
|
||||
async def claim_local_one_time_keys(
|
||||
self, local_query: List[Tuple[str, str, str]]
|
||||
self,
|
||||
local_query: List[Tuple[str, str, str]],
|
||||
always_include_fallback_keys: bool,
|
||||
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
|
||||
"""Claim one time keys for local users.
|
||||
|
||||
|
@ -573,6 +575,7 @@ class E2eKeysHandler:
|
|||
|
||||
Args:
|
||||
local_query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
always_include_fallback_keys: True to always include fallback keys.
|
||||
|
||||
Returns:
|
||||
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||
|
@ -583,24 +586,73 @@ class E2eKeysHandler:
|
|||
# If the application services have not provided any keys via the C-S
|
||||
# API, query it directly for one-time keys.
|
||||
if self._query_appservices_for_otks:
|
||||
# TODO Should this query for fallback keys of uploaded OTKs if
|
||||
# always_include_fallback_keys is True? The MSC is ambiguous.
|
||||
(
|
||||
appservice_results,
|
||||
not_found,
|
||||
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
|
||||
else:
|
||||
appservice_results = []
|
||||
appservice_results = {}
|
||||
|
||||
# Calculate which user ID / device ID / algorithm tuples to get fallback
|
||||
# keys for. This can be either only missing results *or* all results
|
||||
# (which don't already have a fallback key).
|
||||
if always_include_fallback_keys:
|
||||
# Build the fallback query as any part of the original query where
|
||||
# the appservice didn't respond with a fallback key.
|
||||
fallback_query = []
|
||||
|
||||
# Iterate each item in the original query and search the results
|
||||
# from the appservice for that user ID / device ID. If it is found,
|
||||
# check if any of the keys match the requested algorithm & are a
|
||||
# fallback key.
|
||||
for user_id, device_id, algorithm in local_query:
|
||||
# Check if the appservice responded for this query.
|
||||
as_result = appservice_results.get(user_id, {}).get(device_id, {})
|
||||
found_otk = False
|
||||
for key_id, key_json in as_result.items():
|
||||
if key_id.startswith(f"{algorithm}:"):
|
||||
# A OTK or fallback key was found for this query.
|
||||
found_otk = True
|
||||
# A fallback key was found for this query, no need to
|
||||
# query further.
|
||||
if key_json.get("fallback", False):
|
||||
break
|
||||
|
||||
else:
|
||||
# No fallback key was found from appservices, query for it.
|
||||
# Only mark the fallback key as used if no OTK was found
|
||||
# (from either the database or appservices).
|
||||
mark_as_used = not found_otk and not any(
|
||||
key_id.startswith(f"{algorithm}:")
|
||||
for key_id in otk_results.get(user_id, {})
|
||||
.get(device_id, {})
|
||||
.keys()
|
||||
)
|
||||
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
|
||||
|
||||
else:
|
||||
# All fallback keys get marked as used.
|
||||
fallback_query = [
|
||||
(user_id, device_id, algorithm, True)
|
||||
for user_id, device_id, algorithm in not_found
|
||||
]
|
||||
|
||||
# For each user that does not have a one-time keys available, see if
|
||||
# there is a fallback key.
|
||||
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
|
||||
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
|
||||
|
||||
# Return the results in order, each item from the input query should
|
||||
# only appear once in the combined list.
|
||||
return (otk_results, *appservice_results, fallback_results)
|
||||
return (otk_results, appservice_results, fallback_results)
|
||||
|
||||
@trace
|
||||
async def claim_one_time_keys(
|
||||
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
|
||||
self,
|
||||
query: Dict[str, Dict[str, Dict[str, str]]],
|
||||
timeout: Optional[int],
|
||||
always_include_fallback_keys: bool,
|
||||
) -> JsonDict:
|
||||
local_query: List[Tuple[str, str, str]] = []
|
||||
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||
|
@ -617,7 +669,9 @@ class E2eKeysHandler:
|
|||
set_tag("local_key_query", str(local_query))
|
||||
set_tag("remote_key_query", str(remote_queries))
|
||||
|
||||
results = await self.claim_local_one_time_keys(local_query)
|
||||
results = await self.claim_local_one_time_keys(
|
||||
local_query, always_include_fallback_keys
|
||||
)
|
||||
|
||||
# A map of user ID -> device ID -> key ID -> key.
|
||||
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
|
@ -625,7 +679,9 @@ class E2eKeysHandler:
|
|||
for user_id, device_keys in result.items():
|
||||
for device_id, keys in device_keys.items():
|
||||
for key_id, key in keys.items():
|
||||
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
|
||||
json_result.setdefault(user_id, {}).setdefault(
|
||||
device_id, {}
|
||||
).update({key_id: key})
|
||||
|
||||
# Remote failures.
|
||||
failures: Dict[str, JsonDict] = {}
|
||||
|
|
|
@ -70,7 +70,9 @@ from synapse.logging.opentracing import (
|
|||
trace,
|
||||
)
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
|
||||
from synapse.replication.http.devices import (
|
||||
ReplicationMultiUserDevicesResyncRestServlet,
|
||||
)
|
||||
from synapse.replication.http.federation import (
|
||||
ReplicationFederationSendEventsRestServlet,
|
||||
)
|
||||
|
@ -167,8 +169,8 @@ class FederationEventHandler:
|
|||
|
||||
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
|
||||
if hs.config.worker.worker_app:
|
||||
self._user_device_resync = (
|
||||
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
||||
self._multi_user_device_resync = (
|
||||
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
|
||||
)
|
||||
else:
|
||||
self._device_list_updater = hs.get_device_handler().device_list_updater
|
||||
|
@ -1487,9 +1489,11 @@ class FederationEventHandler:
|
|||
|
||||
# Immediately attempt a resync in the background
|
||||
if self._config.worker.worker_app:
|
||||
await self._user_device_resync(user_id=sender)
|
||||
await self._multi_user_device_resync(user_ids=[sender])
|
||||
else:
|
||||
await self._device_list_updater.user_device_resync(sender)
|
||||
await self._device_list_updater.multi_user_device_resync(
|
||||
user_ids=[sender]
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to resync device for %s", sender)
|
||||
|
||||
|
|
|
@ -561,6 +561,8 @@ class EventCreationHandler:
|
|||
expiry_ms=30 * 60 * 1000,
|
||||
)
|
||||
|
||||
self._msc3970_enabled = hs.config.experimental.msc3970_enabled
|
||||
|
||||
async def create_event(
|
||||
self,
|
||||
requester: Requester,
|
||||
|
@ -701,9 +703,16 @@ class EventCreationHandler:
|
|||
if require_consent and not is_exempt:
|
||||
await self.assert_accepted_privacy_policy(requester)
|
||||
|
||||
# Save the access token ID, the device ID and the transaction ID in the event
|
||||
# internal metadata. This is useful to determine if we should echo the
|
||||
# transaction_id in events.
|
||||
# See `synapse.events.utils.EventClientSerializer.serialize_event`
|
||||
if requester.access_token_id is not None:
|
||||
builder.internal_metadata.token_id = requester.access_token_id
|
||||
|
||||
if requester.device_id is not None:
|
||||
builder.internal_metadata.device_id = requester.device_id
|
||||
|
||||
if txn_id is not None:
|
||||
builder.internal_metadata.txn_id = txn_id
|
||||
|
||||
|
@ -897,12 +906,31 @@ class EventCreationHandler:
|
|||
Returns:
|
||||
An event if one could be found, None otherwise.
|
||||
"""
|
||||
|
||||
if self._msc3970_enabled and requester.device_id:
|
||||
# When MSC3970 is enabled, we lookup for events sent by the same device first,
|
||||
# and fallback to the old behaviour if none were found.
|
||||
existing_event_id = (
|
||||
await self.store.get_event_id_from_transaction_id_and_device_id(
|
||||
room_id,
|
||||
requester.user.to_string(),
|
||||
requester.device_id,
|
||||
txn_id,
|
||||
)
|
||||
)
|
||||
if existing_event_id:
|
||||
return await self.store.get_event(existing_event_id)
|
||||
|
||||
# Pre-MSC3970, we looked up for events that were sent by the same session by
|
||||
# using the access token ID.
|
||||
if requester.access_token_id:
|
||||
existing_event_id = await self.store.get_event_id_from_transaction_id(
|
||||
room_id,
|
||||
requester.user.to_string(),
|
||||
requester.access_token_id,
|
||||
txn_id,
|
||||
existing_event_id = (
|
||||
await self.store.get_event_id_from_transaction_id_and_token_id(
|
||||
room_id,
|
||||
requester.user.to_string(),
|
||||
requester.access_token_id,
|
||||
txn_id,
|
||||
)
|
||||
)
|
||||
if existing_event_id:
|
||||
return await self.store.get_event(existing_event_id)
|
||||
|
|
|
@ -169,6 +169,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
self.request_ratelimiter = hs.get_request_ratelimiter()
|
||||
hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
|
||||
|
||||
self._msc3970_enabled = hs.config.experimental.msc3970_enabled
|
||||
|
||||
def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
|
||||
"""Notify the rate limiter that a room join has occurred.
|
||||
|
||||
|
@ -399,13 +401,30 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
# Check if we already have an event with a matching transaction ID. (We
|
||||
# do this check just before we persist an event as well, but may as well
|
||||
# do it up front for efficiency.)
|
||||
if txn_id and requester.access_token_id:
|
||||
existing_event_id = await self.store.get_event_id_from_transaction_id(
|
||||
room_id,
|
||||
requester.user.to_string(),
|
||||
requester.access_token_id,
|
||||
txn_id,
|
||||
)
|
||||
if txn_id:
|
||||
existing_event_id = None
|
||||
if self._msc3970_enabled and requester.device_id:
|
||||
# When MSC3970 is enabled, we lookup for events sent by the same device
|
||||
# first, and fallback to the old behaviour if none were found.
|
||||
existing_event_id = (
|
||||
await self.store.get_event_id_from_transaction_id_and_device_id(
|
||||
room_id,
|
||||
requester.user.to_string(),
|
||||
requester.device_id,
|
||||
txn_id,
|
||||
)
|
||||
)
|
||||
|
||||
if requester.access_token_id and not existing_event_id:
|
||||
existing_event_id = (
|
||||
await self.store.get_event_id_from_transaction_id_and_token_id(
|
||||
room_id,
|
||||
requester.user.to_string(),
|
||||
requester.access_token_id,
|
||||
txn_id,
|
||||
)
|
||||
)
|
||||
|
||||
if existing_event_id:
|
||||
event_pos = await self.store.get_position_for_event(existing_event_id)
|
||||
return existing_event_id, event_pos.stream
|
||||
|
|
|
@ -17,7 +17,6 @@ import codecs
|
|||
import logging
|
||||
import random
|
||||
import sys
|
||||
import typing
|
||||
import urllib.parse
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO, StringIO
|
||||
|
@ -30,9 +29,11 @@ from typing import (
|
|||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
TextIO,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
|
@ -183,20 +184,61 @@ class MatrixFederationRequest:
|
|||
return self.json
|
||||
|
||||
|
||||
class JsonParser(ByteParser[Union[JsonDict, list]]):
|
||||
class _BaseJsonParser(ByteParser[T]):
|
||||
"""A parser that buffers the response and tries to parse it as JSON."""
|
||||
|
||||
CONTENT_TYPE = "application/json"
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self, validator: Optional[Callable[[Optional[object]], bool]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
validator: A callable which takes the parsed JSON value and returns
|
||||
true if the value is valid.
|
||||
"""
|
||||
self._buffer = StringIO()
|
||||
self._binary_wrapper = BinaryIOWrapper(self._buffer)
|
||||
self._validator = validator
|
||||
|
||||
def write(self, data: bytes) -> int:
|
||||
return self._binary_wrapper.write(data)
|
||||
|
||||
def finish(self) -> Union[JsonDict, list]:
|
||||
return json_decoder.decode(self._buffer.getvalue())
|
||||
def finish(self) -> T:
|
||||
result = json_decoder.decode(self._buffer.getvalue())
|
||||
if self._validator is not None and not self._validator(result):
|
||||
raise ValueError(
|
||||
f"Received incorrect JSON value: {result.__class__.__name__}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class JsonParser(_BaseJsonParser[JsonDict]):
|
||||
"""A parser that buffers the response and tries to parse it as a JSON object."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(self._validate)
|
||||
|
||||
@staticmethod
|
||||
def _validate(v: Any) -> bool:
|
||||
return isinstance(v, dict)
|
||||
|
||||
|
||||
class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
|
||||
"""Ensure the legacy responses of /send_join & /send_leave are correct."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(self._validate)
|
||||
|
||||
@staticmethod
|
||||
def _validate(v: Any) -> bool:
|
||||
# Match [integer, JSON dict]
|
||||
return (
|
||||
isinstance(v, list)
|
||||
and len(v) == 2
|
||||
and type(v[0]) == int
|
||||
and isinstance(v[1], dict)
|
||||
)
|
||||
|
||||
|
||||
async def _handle_response(
|
||||
|
@ -313,9 +355,7 @@ async def _handle_response(
|
|||
class BinaryIOWrapper:
|
||||
"""A wrapper for a TextIO which converts from bytes on the fly."""
|
||||
|
||||
def __init__(
|
||||
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
|
||||
):
|
||||
def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
|
||||
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
|
||||
self.file = file
|
||||
|
||||
|
@ -793,7 +833,7 @@ class MatrixFederationHttpClient:
|
|||
backoff_on_404: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Literal[None] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
) -> JsonDict:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
@ -825,8 +865,8 @@ class MatrixFederationHttpClient:
|
|||
ignore_backoff: bool = False,
|
||||
backoff_on_404: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Optional[ByteParser] = None,
|
||||
):
|
||||
parser: Optional[ByteParser[T]] = None,
|
||||
) -> Union[JsonDict, T]:
|
||||
"""Sends the specified json data using PUT
|
||||
|
||||
Args:
|
||||
|
@ -902,7 +942,7 @@ class MatrixFederationHttpClient:
|
|||
_sec_timeout = self.default_timeout
|
||||
|
||||
if parser is None:
|
||||
parser = JsonParser()
|
||||
parser = cast(ByteParser[T], JsonParser())
|
||||
|
||||
body = await _handle_response(
|
||||
self.reactor,
|
||||
|
@ -924,7 +964,7 @@ class MatrixFederationHttpClient:
|
|||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
args: Optional[QueryParams] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
) -> JsonDict:
|
||||
"""Sends the specified json data using POST
|
||||
|
||||
Args:
|
||||
|
@ -998,7 +1038,7 @@ class MatrixFederationHttpClient:
|
|||
ignore_backoff: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Literal[None] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
) -> JsonDict:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
@ -1024,8 +1064,8 @@ class MatrixFederationHttpClient:
|
|||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Optional[ByteParser] = None,
|
||||
):
|
||||
parser: Optional[ByteParser[T]] = None,
|
||||
) -> Union[JsonDict, T]:
|
||||
"""GETs some json from the given host homeserver and path
|
||||
|
||||
Args:
|
||||
|
@ -1091,7 +1131,7 @@ class MatrixFederationHttpClient:
|
|||
_sec_timeout = self.default_timeout
|
||||
|
||||
if parser is None:
|
||||
parser = JsonParser()
|
||||
parser = cast(ByteParser[T], JsonParser())
|
||||
|
||||
body = await _handle_response(
|
||||
self.reactor,
|
||||
|
@ -1112,7 +1152,7 @@ class MatrixFederationHttpClient:
|
|||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
args: Optional[QueryParams] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
) -> JsonDict:
|
||||
"""Send a DELETE request to the remote expecting some json response
|
||||
|
||||
Args:
|
||||
|
|
|
@ -28,62 +28,6 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
|
||||
"""Ask master to resync the device list for a user by contacting their
|
||||
server.
|
||||
|
||||
This must happen on master so that the results can be correctly cached in
|
||||
the database and streamed to workers.
|
||||
|
||||
Request format:
|
||||
|
||||
POST /_synapse/replication/user_device_resync/:user_id
|
||||
|
||||
{}
|
||||
|
||||
Response is equivalent to ` /_matrix/federation/v1/user/devices/:user_id`
|
||||
response, e.g.:
|
||||
|
||||
{
|
||||
"user_id": "@alice:example.org",
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "JLAFKJWSCS",
|
||||
"keys": { ... },
|
||||
"device_display_name": "Alice's Mobile Phone"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
NAME = "user_device_resync"
|
||||
PATH_ARGS = ("user_id",)
|
||||
CACHE = False
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
|
||||
handler = hs.get_device_handler()
|
||||
assert isinstance(handler, DeviceHandler)
|
||||
self.device_list_updater = handler.device_list_updater
|
||||
|
||||
self.store = hs.get_datastores().main
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
|
||||
return {}
|
||||
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request, content: JsonDict, user_id: str
|
||||
) -> Tuple[int, Optional[JsonDict]]:
|
||||
user_devices = await self.device_list_updater.user_device_resync(user_id)
|
||||
|
||||
return 200, user_devices
|
||||
|
||||
|
||||
class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
|
||||
"""Ask master to resync the device list for multiple users from the same
|
||||
remote server by contacting their server.
|
||||
|
@ -216,6 +160,5 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
|
|||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
|
||||
ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server)
|
||||
ReplicationUploadKeysForUserRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import InvalidAPICallError, SynapseError
|
||||
|
@ -288,7 +289,33 @@ class OneTimeKeyServlet(RestServlet):
|
|||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
|
||||
result = await self.e2e_keys_handler.claim_one_time_keys(
|
||||
body, timeout, always_include_fallback_keys=False
|
||||
)
|
||||
return 200, result
|
||||
|
||||
|
||||
class UnstableOneTimeKeyServlet(RestServlet):
|
||||
"""
|
||||
Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
|
||||
fallback keys in the response.
|
||||
"""
|
||||
|
||||
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
|
||||
CATEGORY = "Encryption requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = await self.e2e_keys_handler.claim_one_time_keys(
|
||||
body, timeout, always_include_fallback_keys=True
|
||||
)
|
||||
return 200, result
|
||||
|
||||
|
||||
|
@ -394,6 +421,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
|||
KeyQueryServlet(hs).register(http_server)
|
||||
KeyChangesServlet(hs).register(http_server)
|
||||
OneTimeKeyServlet(hs).register(http_server)
|
||||
if hs.config.experimental.msc3983_appservice_otk_claims:
|
||||
UnstableOneTimeKeyServlet(hs).register(http_server)
|
||||
if hs.config.worker.worker_app is None:
|
||||
SigningKeyUploadServlet(hs).register(http_server)
|
||||
SignaturesUploadServlet(hs).register(http_server)
|
||||
|
|
|
@ -50,6 +50,8 @@ class HttpTransactionCache:
|
|||
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
|
||||
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
|
||||
|
||||
self._msc3970_enabled = hs.config.experimental.msc3970_enabled
|
||||
|
||||
def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable:
|
||||
"""A helper function which returns a transaction key that can be used
|
||||
with TransactionCache for idempotent requests.
|
||||
|
@ -58,6 +60,7 @@ class HttpTransactionCache:
|
|||
requests to the same endpoint. The key is formed from the HTTP request
|
||||
path and attributes from the requester: the access_token_id for regular users,
|
||||
the user ID for guest users, and the appservice ID for appservice users.
|
||||
With MSC3970, for regular users, the key is based on the user ID and device ID.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
|
@ -67,11 +70,21 @@ class HttpTransactionCache:
|
|||
"""
|
||||
assert request.path is not None
|
||||
path: str = request.path.decode("utf8")
|
||||
|
||||
if requester.is_guest:
|
||||
assert requester.user is not None, "Guest requester must have a user ID set"
|
||||
return (path, "guest", requester.user)
|
||||
|
||||
elif requester.app_service is not None:
|
||||
return (path, "appservice", requester.app_service.id)
|
||||
|
||||
# With MSC3970, we use the user ID and device ID as the transaction key
|
||||
elif self._msc3970_enabled:
|
||||
assert requester.user, "Requester must have a user"
|
||||
assert requester.device_id, "Requester must have a device_id"
|
||||
return (path, "user", requester.user, requester.device_id)
|
||||
|
||||
# Otherwise, the pre-MSC3970 behaviour is to use the access token ID
|
||||
else:
|
||||
assert (
|
||||
requester.access_token_id is not None
|
||||
|
|
|
@ -155,7 +155,7 @@ class RemoteKey(RestServlet):
|
|||
for key_id in key_ids:
|
||||
store_queries.append((server_name, key_id, None))
|
||||
|
||||
cached = await self.store.get_server_keys_json(store_queries)
|
||||
cached = await self.store.get_server_keys_json_for_remote(store_queries)
|
||||
|
||||
json_results: Set[bytes] = set()
|
||||
|
||||
|
|
|
@ -762,7 +762,9 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
|
||||
@cache_in_self
|
||||
def get_event_client_serializer(self) -> EventClientSerializer:
|
||||
return EventClientSerializer()
|
||||
return EventClientSerializer(
|
||||
msc3970_enabled=self.config.experimental.msc3970_enabled
|
||||
)
|
||||
|
||||
@cache_in_self
|
||||
def get_password_policy_handler(self) -> PasswordPolicyHandler:
|
||||
|
|
|
@ -1149,18 +1149,19 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
return results, missing
|
||||
|
||||
async def claim_e2e_fallback_keys(
|
||||
self, query_list: Iterable[Tuple[str, str, str]]
|
||||
self, query_list: Iterable[Tuple[str, str, str, bool]]
|
||||
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
|
||||
"""Take a list of fallback keys out of the database.
|
||||
|
||||
Args:
|
||||
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
query_list: An iterable of tuples of
|
||||
(user ID, device ID, algorithm, whether the key should be marked as used).
|
||||
|
||||
Returns:
|
||||
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||
"""
|
||||
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
for user_id, device_id, algorithm in query_list:
|
||||
for user_id, device_id, algorithm, mark_as_used in query_list:
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="e2e_fallback_keys_json",
|
||||
keyvalues={
|
||||
|
@ -1180,7 +1181,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
used = row["used"]
|
||||
|
||||
# Mark fallback key as used if not already.
|
||||
if not used:
|
||||
if not used and mark_as_used:
|
||||
await self.db_pool.simple_update_one(
|
||||
table="e2e_fallback_keys_json",
|
||||
keyvalues={
|
||||
|
|
|
@ -127,6 +127,8 @@ class PersistEventsStore:
|
|||
self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
|
||||
self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
|
||||
|
||||
self._msc3970_enabled = hs.config.experimental.msc3970_enabled
|
||||
|
||||
@trace
|
||||
async def _persist_events_and_state_updates(
|
||||
self,
|
||||
|
@ -977,23 +979,43 @@ class PersistEventsStore:
|
|||
) -> None:
|
||||
"""Persist the mapping from transaction IDs to event IDs (if defined)."""
|
||||
|
||||
to_insert = []
|
||||
inserted_ts = self._clock.time_msec()
|
||||
to_insert_token_id: List[Tuple[str, str, str, int, str, int]] = []
|
||||
to_insert_device_id: List[Tuple[str, str, str, str, str, int]] = []
|
||||
for event, _ in events_and_contexts:
|
||||
token_id = getattr(event.internal_metadata, "token_id", None)
|
||||
txn_id = getattr(event.internal_metadata, "txn_id", None)
|
||||
if token_id and txn_id:
|
||||
to_insert.append(
|
||||
(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
event.sender,
|
||||
token_id,
|
||||
txn_id,
|
||||
self._clock.time_msec(),
|
||||
)
|
||||
)
|
||||
token_id = getattr(event.internal_metadata, "token_id", None)
|
||||
device_id = getattr(event.internal_metadata, "device_id", None)
|
||||
|
||||
if to_insert:
|
||||
if txn_id is not None:
|
||||
if token_id is not None:
|
||||
to_insert_token_id.append(
|
||||
(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
event.sender,
|
||||
token_id,
|
||||
txn_id,
|
||||
inserted_ts,
|
||||
)
|
||||
)
|
||||
|
||||
if device_id is not None:
|
||||
to_insert_device_id.append(
|
||||
(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
event.sender,
|
||||
device_id,
|
||||
txn_id,
|
||||
inserted_ts,
|
||||
)
|
||||
)
|
||||
|
||||
# Pre-MSC3970, we rely on the access_token_id to scope the txn_id for events.
|
||||
# Since this is an experimental flag, we still store the mapping even if the
|
||||
# flag is disabled.
|
||||
if to_insert_token_id:
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_txn_id",
|
||||
|
@ -1005,7 +1027,25 @@ class PersistEventsStore:
|
|||
"txn_id",
|
||||
"inserted_ts",
|
||||
),
|
||||
values=to_insert,
|
||||
values=to_insert_token_id,
|
||||
)
|
||||
|
||||
# With MSC3970, we rely on the device_id instead to scope the txn_id for events.
|
||||
# We're only inserting if MSC3970 is *enabled*, because else the pre-MSC3970
|
||||
# behaviour would allow for a UNIQUE constraint violation on this table
|
||||
if to_insert_device_id and self._msc3970_enabled:
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_txn_id_device_id",
|
||||
keys=(
|
||||
"event_id",
|
||||
"room_id",
|
||||
"user_id",
|
||||
"device_id",
|
||||
"txn_id",
|
||||
"inserted_ts",
|
||||
),
|
||||
values=to_insert_device_id,
|
||||
)
|
||||
|
||||
async def update_current_state(
|
||||
|
@ -1127,11 +1167,15 @@ class PersistEventsStore:
|
|||
# been inserted into room_memberships.
|
||||
txn.execute_batch(
|
||||
"""INSERT INTO current_state_events
|
||||
(room_id, type, state_key, event_id, membership)
|
||||
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
|
||||
(room_id, type, state_key, event_id, membership, event_stream_ordering)
|
||||
VALUES (
|
||||
?, ?, ?, ?,
|
||||
(SELECT membership FROM room_memberships WHERE event_id = ?),
|
||||
(SELECT stream_ordering FROM events WHERE event_id = ?)
|
||||
)
|
||||
""",
|
||||
[
|
||||
(room_id, key[0], key[1], ev_id, ev_id)
|
||||
(room_id, key[0], key[1], ev_id, ev_id, ev_id)
|
||||
for key, ev_id in to_insert.items()
|
||||
],
|
||||
)
|
||||
|
@ -1158,11 +1202,15 @@ class PersistEventsStore:
|
|||
if to_insert:
|
||||
txn.execute_batch(
|
||||
"""INSERT INTO local_current_membership
|
||||
(room_id, user_id, event_id, membership)
|
||||
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
|
||||
(room_id, user_id, event_id, membership, event_stream_ordering)
|
||||
VALUES (
|
||||
?, ?, ?,
|
||||
(SELECT membership FROM room_memberships WHERE event_id = ?),
|
||||
(SELECT stream_ordering FROM events WHERE event_id = ?)
|
||||
)
|
||||
""",
|
||||
[
|
||||
(room_id, key[1], ev_id, ev_id)
|
||||
(room_id, key[1], ev_id, ev_id, ev_id)
|
||||
for key, ev_id in to_insert.items()
|
||||
if key[0] == EventTypes.Member and self.is_mine_id(key[1])
|
||||
],
|
||||
|
@ -1768,6 +1816,7 @@ class PersistEventsStore:
|
|||
table="room_memberships",
|
||||
keys=(
|
||||
"event_id",
|
||||
"event_stream_ordering",
|
||||
"user_id",
|
||||
"sender",
|
||||
"room_id",
|
||||
|
@ -1778,6 +1827,7 @@ class PersistEventsStore:
|
|||
values=[
|
||||
(
|
||||
event.event_id,
|
||||
event.internal_metadata.stream_ordering,
|
||||
event.state_key,
|
||||
event.user_id,
|
||||
event.room_id,
|
||||
|
@ -1810,6 +1860,7 @@ class PersistEventsStore:
|
|||
keyvalues={"room_id": event.room_id, "user_id": event.state_key},
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"event_stream_ordering": event.internal_metadata.stream_ordering,
|
||||
"membership": event.membership,
|
||||
},
|
||||
)
|
||||
|
|
|
@ -2022,7 +2022,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
)
|
||||
|
||||
async def get_event_id_from_transaction_id(
|
||||
async def get_event_id_from_transaction_id_and_token_id(
|
||||
self, room_id: str, user_id: str, token_id: int, txn_id: str
|
||||
) -> Optional[str]:
|
||||
"""Look up if we have already persisted an event for the transaction ID,
|
||||
|
@ -2038,7 +2038,26 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
},
|
||||
retcol="event_id",
|
||||
allow_none=True,
|
||||
desc="get_event_id_from_transaction_id",
|
||||
desc="get_event_id_from_transaction_id_and_token_id",
|
||||
)
|
||||
|
||||
async def get_event_id_from_transaction_id_and_device_id(
|
||||
self, room_id: str, user_id: str, device_id: str, txn_id: str
|
||||
) -> Optional[str]:
|
||||
"""Look up if we have already persisted an event for the transaction ID,
|
||||
returning the event ID if so.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="event_txn_id_device_id",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"txn_id": txn_id,
|
||||
},
|
||||
retcol="event_id",
|
||||
allow_none=True,
|
||||
desc="get_event_id_from_transaction_id_and_device_id",
|
||||
)
|
||||
|
||||
async def get_already_persisted_events(
|
||||
|
@ -2068,7 +2087,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
# Check if this is a duplicate of an event we've already
|
||||
# persisted.
|
||||
existing = await self.get_event_id_from_transaction_id(
|
||||
existing = await self.get_event_id_from_transaction_id_and_token_id(
|
||||
event.room_id, event.sender, token_id, txn_id
|
||||
)
|
||||
if existing:
|
||||
|
@ -2084,11 +2103,17 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
"""Cleans out transaction id mappings older than 24hrs."""
|
||||
|
||||
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
|
||||
one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
|
||||
sql = """
|
||||
DELETE FROM event_txn_id
|
||||
WHERE inserted_ts < ?
|
||||
"""
|
||||
one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
|
||||
txn.execute(sql, (one_day_ago,))
|
||||
|
||||
sql = """
|
||||
DELETE FROM event_txn_id_device_id
|
||||
WHERE inserted_ts < ?
|
||||
"""
|
||||
txn.execute(sql, (one_day_ago,))
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
|
|
@ -14,10 +14,12 @@
|
|||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
|
@ -36,15 +38,16 @@ class KeyStore(SQLBaseStore):
|
|||
"""Persistence for signature verification keys"""
|
||||
|
||||
@cached()
|
||||
def _get_server_verify_key(
|
||||
def _get_server_signature_key(
|
||||
self, server_name_and_key_id: Tuple[str, str]
|
||||
) -> FetchKeyResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
|
||||
cached_method_name="_get_server_signature_key",
|
||||
list_name="server_name_and_key_ids",
|
||||
)
|
||||
async def get_server_verify_keys(
|
||||
async def get_server_signature_keys(
|
||||
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
|
||||
) -> Dict[Tuple[str, str], FetchKeyResult]:
|
||||
"""
|
||||
|
@ -62,10 +65,12 @@ class KeyStore(SQLBaseStore):
|
|||
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
|
||||
|
||||
# batch_iter always returns tuples so it's safe to do len(batch)
|
||||
sql = (
|
||||
"SELECT server_name, key_id, verify_key, ts_valid_until_ms "
|
||||
"FROM server_signature_keys WHERE 1=0"
|
||||
) + " OR (server_name=? AND key_id=?)" * len(batch)
|
||||
sql = """
|
||||
SELECT server_name, key_id, verify_key, ts_valid_until_ms
|
||||
FROM server_signature_keys WHERE 1=0
|
||||
""" + " OR (server_name=? AND key_id=?)" * len(
|
||||
batch
|
||||
)
|
||||
|
||||
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
|
||||
|
||||
|
@ -89,9 +94,9 @@ class KeyStore(SQLBaseStore):
|
|||
_get_keys(txn, batch)
|
||||
return keys
|
||||
|
||||
return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
|
||||
return await self.db_pool.runInteraction("get_server_signature_keys", _txn)
|
||||
|
||||
async def store_server_verify_keys(
|
||||
async def store_server_signature_keys(
|
||||
self,
|
||||
from_server: str,
|
||||
ts_added_ms: int,
|
||||
|
@ -119,7 +124,7 @@ class KeyStore(SQLBaseStore):
|
|||
)
|
||||
)
|
||||
# invalidate takes a tuple corresponding to the params of
|
||||
# _get_server_verify_key. _get_server_verify_key only takes one
|
||||
# _get_server_signature_key. _get_server_signature_key only takes one
|
||||
# param, which is itself the 2-tuple (server_name, key_id).
|
||||
invalidations.append((server_name, key_id))
|
||||
|
||||
|
@ -134,10 +139,10 @@ class KeyStore(SQLBaseStore):
|
|||
"verify_key",
|
||||
),
|
||||
value_values=value_values,
|
||||
desc="store_server_verify_keys",
|
||||
desc="store_server_signature_keys",
|
||||
)
|
||||
|
||||
invalidate = self._get_server_verify_key.invalidate
|
||||
invalidate = self._get_server_signature_key.invalidate
|
||||
for i in invalidations:
|
||||
invalidate((i,))
|
||||
|
||||
|
@ -180,7 +185,75 @@ class KeyStore(SQLBaseStore):
|
|||
desc="store_server_keys_json",
|
||||
)
|
||||
|
||||
# invalidate takes a tuple corresponding to the params of
|
||||
# _get_server_keys_json. _get_server_keys_json only takes one
|
||||
# param, which is itself the 2-tuple (server_name, key_id).
|
||||
self._get_server_keys_json.invalidate((((server_name, key_id),)))
|
||||
|
||||
@cached()
|
||||
def _get_server_keys_json(
|
||||
self, server_name_and_key_id: Tuple[str, str]
|
||||
) -> FetchKeyResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="_get_server_keys_json", list_name="server_name_and_key_ids"
|
||||
)
|
||||
async def get_server_keys_json(
|
||||
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
|
||||
) -> Dict[Tuple[str, str], FetchKeyResult]:
|
||||
"""
|
||||
Args:
|
||||
server_name_and_key_ids:
|
||||
iterable of (server_name, key-id) tuples to fetch keys for
|
||||
|
||||
Returns:
|
||||
A map from (server_name, key_id) -> FetchKeyResult, or None if the
|
||||
key is unknown
|
||||
"""
|
||||
keys = {}
|
||||
|
||||
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
|
||||
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
|
||||
|
||||
# batch_iter always returns tuples so it's safe to do len(batch)
|
||||
sql = """
|
||||
SELECT server_name, key_id, key_json, ts_valid_until_ms
|
||||
FROM server_keys_json WHERE 1=0
|
||||
""" + " OR (server_name=? AND key_id=?)" * len(
|
||||
batch
|
||||
)
|
||||
|
||||
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
|
||||
|
||||
for server_name, key_id, key_json_bytes, ts_valid_until_ms in txn:
|
||||
if ts_valid_until_ms is None:
|
||||
# Old keys may be stored with a ts_valid_until_ms of null,
|
||||
# in which case we treat this as if it was set to `0`, i.e.
|
||||
# it won't match key requests that define a minimum
|
||||
# `ts_valid_until_ms`.
|
||||
ts_valid_until_ms = 0
|
||||
|
||||
# The entire signed JSON response is stored in server_keys_json,
|
||||
# fetch out the bits needed.
|
||||
key_json = json.loads(bytes(key_json_bytes))
|
||||
key_base64 = key_json["verify_keys"][key_id]["key"]
|
||||
|
||||
keys[(server_name, key_id)] = FetchKeyResult(
|
||||
verify_key=decode_verify_key_bytes(
|
||||
key_id, decode_base64(key_base64)
|
||||
),
|
||||
valid_until_ts=ts_valid_until_ms,
|
||||
)
|
||||
|
||||
def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
|
||||
for batch in batch_iter(server_name_and_key_ids, 50):
|
||||
_get_keys(txn, batch)
|
||||
return keys
|
||||
|
||||
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
|
||||
|
||||
async def get_server_keys_json_for_remote(
|
||||
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
|
||||
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
|
||||
"""Retrieve the key json for a list of server_keys and key ids.
|
||||
|
@ -188,8 +261,10 @@ class KeyStore(SQLBaseStore):
|
|||
that server, key_id, and source triplet entry will be an empty list.
|
||||
The JSON is returned as a byte array so that it can be efficiently
|
||||
used in an HTTP response.
|
||||
|
||||
Args:
|
||||
server_keys: List of (server_name, key_id, source) triplets.
|
||||
|
||||
Returns:
|
||||
A mapping from (server_name, key_id, source) triplets to a list of dicts
|
||||
"""
|
||||
|
|
|
@ -428,14 +428,16 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
|||
"partial_state_events",
|
||||
"partial_state_rooms_servers",
|
||||
"partial_state_rooms",
|
||||
# Note: the _membership(s) tables have foreign keys to the `events` table
|
||||
# so must be deleted first.
|
||||
"local_current_membership",
|
||||
"room_memberships",
|
||||
"events",
|
||||
"federation_inbound_events_staging",
|
||||
"local_current_membership",
|
||||
"receipts_graph",
|
||||
"receipts_linearized",
|
||||
"room_aliases",
|
||||
"room_depth",
|
||||
"room_memberships",
|
||||
"room_stats_state",
|
||||
"room_stats_current",
|
||||
"room_stats_earliest_token",
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SCHEMA_VERSION = 74 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 75 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
|
@ -91,13 +91,19 @@ Changes in SCHEMA_VERSION = 74:
|
|||
- A query on `event_stream_ordering` column has now been disambiguated (i.e. the
|
||||
codebase can handle the `current_state_events`, `local_current_memberships` and
|
||||
`room_memberships` tables having an `event_stream_ordering` column).
|
||||
|
||||
Changes in SCHEMA_VERSION = 75:
|
||||
- The `event_stream_ordering` column in membership tables (`current_state_events`,
|
||||
`local_current_membership` & `room_memberships`) is now being populated for new
|
||||
rows. When the background job to populate historical rows lands this will
|
||||
become the compat schema version.
|
||||
"""
|
||||
|
||||
|
||||
SCHEMA_COMPAT_VERSION = (
|
||||
# The threads_id column must exist for event_push_actions, event_push_summary,
|
||||
# receipts_linearized, and receipts_graph.
|
||||
73
|
||||
# Queries against `event_stream_ordering` columns in membership tables must
|
||||
# be disambiguated.
|
||||
74
|
||||
)
|
||||
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
|
||||
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/* Copyright 2022 Beeper
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Each of these are denormalised copies of `stream_ordering` from the corresponding row in` events` which
|
||||
-- we use to improve database performance by reduring JOINs.
|
||||
|
||||
-- NOTE: these are set to NOT VALID to prevent locks while adding the column on large existing tables,
|
||||
-- which will be validated in a later migration. For all new/updated rows the FKEY will be checked.
|
||||
|
||||
ALTER TABLE current_state_events ADD COLUMN event_stream_ordering BIGINT;
|
||||
ALTER TABLE current_state_events ADD CONSTRAINT event_stream_ordering_fkey FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering) NOT VALID;
|
||||
|
||||
ALTER TABLE local_current_membership ADD COLUMN event_stream_ordering BIGINT;
|
||||
ALTER TABLE local_current_membership ADD CONSTRAINT event_stream_ordering_fkey FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering) NOT VALID;
|
||||
|
||||
ALTER TABLE room_memberships ADD COLUMN event_stream_ordering BIGINT;
|
||||
ALTER TABLE room_memberships ADD CONSTRAINT event_stream_ordering_fkey FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering) NOT VALID;
|
|
@ -0,0 +1,23 @@
|
|||
/* Copyright 2022 Beeper
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Each of these are denormalised copies of `stream_ordering` from the corresponding row in` events` which
|
||||
-- we use to improve database performance by reduring JOINs.
|
||||
|
||||
-- NOTE: sqlite does not support ADD CONSTRAINT so we add the new columns with FK constraint as-is
|
||||
|
||||
ALTER TABLE current_state_events ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering);
|
||||
ALTER TABLE local_current_membership ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering);
|
||||
ALTER TABLE room_memberships ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering);
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2022 Beeper
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This migration adds triggers to the room membership tables to enforce consistency.
|
||||
Triggers cannot be expressed in .sql files, so we have to use a separate file.
|
||||
"""
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.types import Cursor
|
||||
|
||||
|
||||
def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
|
||||
# Complain if the `event_stream_ordering` in membership tables doesn't match
|
||||
# the `stream_ordering` row with the same `event_id` in `events`.
|
||||
if isinstance(database_engine, Sqlite3Engine):
|
||||
for table in (
|
||||
"current_state_events",
|
||||
"local_current_membership",
|
||||
"room_memberships",
|
||||
):
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TRIGGER IF NOT EXISTS {table}_bad_event_stream_ordering
|
||||
BEFORE INSERT ON {table}
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
SELECT RAISE(ABORT, 'Incorrect event_stream_ordering in {table}')
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM events
|
||||
WHERE events.event_id = NEW.event_id
|
||||
AND events.stream_ordering != NEW.event_stream_ordering
|
||||
);
|
||||
END;
|
||||
"""
|
||||
)
|
||||
elif isinstance(database_engine, PostgresEngine):
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE OR REPLACE FUNCTION check_event_stream_ordering() RETURNS trigger AS $BODY$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM events
|
||||
WHERE events.event_id = NEW.event_id
|
||||
AND events.stream_ordering != NEW.event_stream_ordering
|
||||
) THEN
|
||||
RAISE EXCEPTION 'Incorrect event_stream_ordering';
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$BODY$ LANGUAGE plpgsql;
|
||||
"""
|
||||
)
|
||||
|
||||
for table in (
|
||||
"current_state_events",
|
||||
"local_current_membership",
|
||||
"room_memberships",
|
||||
):
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TRIGGER check_event_stream_ordering BEFORE INSERT OR UPDATE ON {table}
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE check_event_stream_ordering()
|
||||
"""
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unknown database engine")
|
|
@ -0,0 +1,53 @@
|
|||
/* Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- For MSC3970, in addition to the (room_id, user_id, token_id, txn_id) -> event_id mapping for each local event,
|
||||
-- we also store the (room_id, user_id, device_id, txn_id) -> event_id mapping.
|
||||
--
|
||||
-- This adds a new event_txn_id_device_id table.
|
||||
|
||||
-- A map of recent events persisted with transaction IDs. Used to deduplicate
|
||||
-- send event requests with the same transaction ID.
|
||||
--
|
||||
-- Note: with MSC3970, transaction IDs are scoped to the
|
||||
-- room ID/user ID/device ID that was used to make the request.
|
||||
--
|
||||
-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the
|
||||
-- event or device we don't want to try and de-duplicate the event.
|
||||
CREATE TABLE IF NOT EXISTS event_txn_id_device_id (
|
||||
event_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
txn_id TEXT NOT NULL,
|
||||
inserted_ts BIGINT NOT NULL,
|
||||
FOREIGN KEY (event_id)
|
||||
REFERENCES events (event_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (user_id, device_id)
|
||||
REFERENCES devices (user_id, device_id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- This ensures that there is only one mapping per event_id.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_device_id_event_id
|
||||
ON event_txn_id_device_id(event_id);
|
||||
|
||||
-- This ensures that there is only one mapping per (room_id, user_id, device_id, txn_id) tuple.
|
||||
-- Events are usually looked up using this index.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_device_id_txn_id
|
||||
ON event_txn_id_device_id(room_id, user_id, device_id, txn_id);
|
||||
|
||||
-- This table is cleaned up regularly, removing the oldest entries, hence this index.
|
||||
CREATE INDEX IF NOT EXISTS event_txn_id_device_id_ts
|
||||
ON event_txn_id_device_id(inserted_ts);
|
|
@ -190,10 +190,23 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
kr = keyring.Keyring(self.hs)
|
||||
|
||||
key1 = signedjson.key.generate_signing_key("1")
|
||||
r = self.hs.get_datastores().main.store_server_verify_keys(
|
||||
r = self.hs.get_datastores().main.store_server_keys_json(
|
||||
"server9",
|
||||
int(time.time() * 1000),
|
||||
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), 1000)},
|
||||
get_key_id(key1),
|
||||
from_server="test",
|
||||
ts_now_ms=int(time.time() * 1000),
|
||||
ts_expires_ms=1000,
|
||||
# The entire response gets signed & stored, just include the bits we
|
||||
# care about.
|
||||
key_json_bytes=canonicaljson.encode_canonical_json(
|
||||
{
|
||||
"verify_keys": {
|
||||
get_key_id(key1): {
|
||||
"key": encode_verify_key_base64(get_verify_key(key1))
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
self.get_success(r)
|
||||
|
||||
|
@ -280,17 +293,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
mock_fetcher = Mock()
|
||||
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
|
||||
|
||||
kr = keyring.Keyring(
|
||||
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
|
||||
)
|
||||
|
||||
key1 = signedjson.key.generate_signing_key("1")
|
||||
r = self.hs.get_datastores().main.store_server_verify_keys(
|
||||
r = self.hs.get_datastores().main.store_server_signature_keys(
|
||||
"server9",
|
||||
int(time.time() * 1000),
|
||||
# None is not a valid value in FetchKeyResult, but we're abusing this
|
||||
# API to insert null values into the database. The nulls get converted
|
||||
# to 0 when fetched in KeyStore.get_server_verify_keys.
|
||||
# to 0 when fetched in KeyStore.get_server_signature_keys.
|
||||
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type]
|
||||
)
|
||||
self.get_success(r)
|
||||
|
@ -298,27 +307,12 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
|||
json1: JsonDict = {}
|
||||
signedjson.sign.sign_json(json1, "server9", key1)
|
||||
|
||||
# should fail immediately on an unsigned object
|
||||
d = kr.verify_json_for_server("server9", {}, 0)
|
||||
self.get_failure(d, SynapseError)
|
||||
|
||||
# should fail on a signed object with a non-zero minimum_valid_until_ms,
|
||||
# as it tries to refetch the keys and fails.
|
||||
d = kr.verify_json_for_server("server9", json1, 500)
|
||||
self.get_failure(d, SynapseError)
|
||||
|
||||
# We expect the keyring tried to refetch the key once.
|
||||
mock_fetcher.get_keys.assert_called_once_with(
|
||||
"server9", [get_key_id(key1)], 500
|
||||
)
|
||||
|
||||
# should succeed on a signed object with a 0 minimum_valid_until_ms
|
||||
d = kr.verify_json_for_server(
|
||||
"server9",
|
||||
json1,
|
||||
0,
|
||||
d = self.hs.get_datastores().main.get_server_signature_keys(
|
||||
[("server9", get_key_id(key1))]
|
||||
)
|
||||
self.get_success(d)
|
||||
result = self.get_success(d)
|
||||
self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0)
|
||||
|
||||
def test_verify_json_dedupes_key_requests(self) -> None:
|
||||
"""Two requests for the same key should be deduped."""
|
||||
|
@ -464,7 +458,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
# check that the perspectives store is correctly updated
|
||||
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
|
||||
key_json = self.get_success(
|
||||
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
|
||||
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
||||
[lookup_triplet]
|
||||
)
|
||||
)
|
||||
res_keys = key_json[lookup_triplet]
|
||||
self.assertEqual(len(res_keys), 1)
|
||||
|
@ -582,7 +578,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
# check that the perspectives store is correctly updated
|
||||
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
|
||||
key_json = self.get_success(
|
||||
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
|
||||
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
||||
[lookup_triplet]
|
||||
)
|
||||
)
|
||||
res_keys = key_json[lookup_triplet]
|
||||
self.assertEqual(len(res_keys), 1)
|
||||
|
@ -703,7 +701,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
|||
# check that the perspectives store is correctly updated
|
||||
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
|
||||
key_json = self.get_success(
|
||||
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
|
||||
self.hs.get_datastores().main.get_server_keys_json_for_remote(
|
||||
[lookup_triplet]
|
||||
)
|
||||
)
|
||||
res_keys = key_json[lookup_triplet]
|
||||
self.assertEqual(len(res_keys), 1)
|
||||
|
|
|
@ -75,7 +75,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
@ -106,7 +106,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
@ -143,7 +143,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
@ -200,7 +200,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
@ -230,7 +230,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
|
|
@ -160,7 +160,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
res2 = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -203,7 +205,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# key
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -220,7 +224,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# claiming an OTK again should return the same fallback key
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -267,7 +273,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -277,7 +285,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -296,7 +306,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -304,6 +316,75 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
||||
)
|
||||
|
||||
def test_fallback_key_always_returned(self) -> None:
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
fallback_key = {"alg1:k1": "fallback_key1"}
|
||||
otk = {"alg1:k2": "key2"}
|
||||
|
||||
# we shouldn't have any unused fallback keys yet
|
||||
res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||
)
|
||||
self.assertEqual(res, [])
|
||||
|
||||
# Upload a OTK & fallback key.
|
||||
self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id,
|
||||
{"one_time_keys": otk, "fallback_keys": fallback_key},
|
||||
)
|
||||
)
|
||||
|
||||
# we should now have an unused alg1 key
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||
)
|
||||
self.assertEqual(fallback_res, ["alg1"])
|
||||
|
||||
# Claiming an OTK and requesting to always return the fallback key should
|
||||
# return both.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}},
|
||||
},
|
||||
)
|
||||
|
||||
# This should not mark the key as used.
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||
)
|
||||
self.assertEqual(fallback_res, ["alg1"])
|
||||
|
||||
# Claiming an OTK again should return only the fallback key.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
|
||||
)
|
||||
|
||||
# And mark it as used.
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||
)
|
||||
self.assertEqual(fallback_res, [])
|
||||
|
||||
def test_replace_master_key(self) -> None:
|
||||
"""uploading a new signing key should make the old signing key unavailable"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
|
@ -1004,6 +1085,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
}
|
||||
},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -1016,6 +1098,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
@override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
|
||||
def test_query_appservice_with_fallback(self) -> None:
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id_1 = "xyz"
|
||||
fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}}
|
||||
otk = {"alg1:k2": {"desc": "key2"}}
|
||||
as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}}
|
||||
as_otk = {"alg1:k4": {"desc": "key4"}}
|
||||
|
||||
# Inject an appservice interested in this user.
|
||||
appservice = ApplicationService(
|
||||
token="i_am_an_app_service",
|
||||
id="1234",
|
||||
namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
|
||||
# Note: this user does not have to match the regex above
|
||||
sender="@as_main:test",
|
||||
)
|
||||
self.hs.get_datastores().main.services_cache = [appservice]
|
||||
self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
|
||||
[appservice]
|
||||
)
|
||||
|
||||
# Setup a response.
|
||||
self.appservice_api.claim_client_keys.return_value = make_awaitable(
|
||||
({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
|
||||
)
|
||||
|
||||
# Claim OTKs, which will ask the appservice and do nothing else.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {
|
||||
local_user: {device_id_1: {**as_otk, **as_fallback_key}}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Now upload a fallback key.
|
||||
res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||
)
|
||||
self.assertEqual(res, [])
|
||||
|
||||
self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id_1,
|
||||
{"fallback_keys": fallback_key},
|
||||
)
|
||||
)
|
||||
|
||||
# we should now have an unused alg1 key
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||
)
|
||||
self.assertEqual(fallback_res, ["alg1"])
|
||||
|
||||
# The appservice will return only the OTK.
|
||||
self.appservice_api.claim_client_keys.return_value = make_awaitable(
|
||||
({local_user: {device_id_1: as_otk}}, [])
|
||||
)
|
||||
|
||||
# Claim OTKs, which should return the OTK from the appservice and the
|
||||
# uploaded fallback key.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {
|
||||
local_user: {device_id_1: {**as_otk, **fallback_key}}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# But the fallback key should not be marked as used.
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||
)
|
||||
self.assertEqual(fallback_res, ["alg1"])
|
||||
|
||||
# Now upload a OTK.
|
||||
self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id_1,
|
||||
{"one_time_keys": otk},
|
||||
)
|
||||
)
|
||||
|
||||
# Claim OTKs, which will return information only from the database.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}},
|
||||
},
|
||||
)
|
||||
|
||||
# But the fallback key should not be marked as used.
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||
)
|
||||
self.assertEqual(fallback_res, ["alg1"])
|
||||
|
||||
# Finally, return only the fallback key from the appservice.
|
||||
self.appservice_api.claim_client_keys.return_value = make_awaitable(
|
||||
({local_user: {device_id_1: as_fallback_key}}, [])
|
||||
)
|
||||
|
||||
# Claim OTKs, which will return only the fallback key from the database.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {local_user: {device_id_1: as_fallback_key}},
|
||||
},
|
||||
)
|
||||
|
||||
@override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
|
||||
def test_query_local_devices_appservice(self) -> None:
|
||||
"""Test that querying of appservices for keys overrides responses from the database."""
|
||||
|
|
|
@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel
|
|||
|
||||
from synapse.api.errors import RequestSendFailed
|
||||
from synapse.http.matrixfederationclient import (
|
||||
JsonParser,
|
||||
ByteParser,
|
||||
MatrixFederationHttpClient,
|
||||
MatrixFederationRequest,
|
||||
)
|
||||
|
@ -618,9 +618,9 @@ class FederationClientTests(HomeserverTestCase):
|
|||
while not test_d.called:
|
||||
protocol.dataReceived(b"a" * chunk_size)
|
||||
sent += chunk_size
|
||||
self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
|
||||
self.assertLessEqual(sent, ByteParser.MAX_RESPONSE_SIZE)
|
||||
|
||||
self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
|
||||
self.assertEqual(sent, ByteParser.MAX_RESPONSE_SIZE)
|
||||
|
||||
f = self.failureResultOf(test_d)
|
||||
self.assertIsInstance(f.value, RequestSendFailed)
|
||||
|
|
|
@ -37,13 +37,13 @@ KEY_2 = decode_verify_key_base64(
|
|||
|
||||
|
||||
class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||
def test_get_server_verify_keys(self) -> None:
|
||||
def test_get_server_signature_keys(self) -> None:
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
key_id_1 = "ed25519:key1"
|
||||
key_id_2 = "ed25519:KEY_ID_2"
|
||||
self.get_success(
|
||||
store.store_server_verify_keys(
|
||||
store.store_server_signature_keys(
|
||||
"from_server",
|
||||
10,
|
||||
{
|
||||
|
@ -54,7 +54,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
res = self.get_success(
|
||||
store.get_server_verify_keys(
|
||||
store.get_server_signature_keys(
|
||||
[
|
||||
("server1", key_id_1),
|
||||
("server1", key_id_2),
|
||||
|
@ -87,7 +87,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
key_id_2 = "ed25519:key2"
|
||||
|
||||
self.get_success(
|
||||
store.store_server_verify_keys(
|
||||
store.store_server_signature_keys(
|
||||
"from_server",
|
||||
0,
|
||||
{
|
||||
|
@ -98,7 +98,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
res = self.get_success(
|
||||
store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
|
||||
store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
|
||||
)
|
||||
self.assertEqual(len(res.keys()), 2)
|
||||
|
||||
|
@ -111,20 +111,20 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
self.assertEqual(res2.valid_until_ts, 200)
|
||||
|
||||
# we should be able to look up the same thing again without a db hit
|
||||
res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)]))
|
||||
res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)]))
|
||||
self.assertEqual(len(res.keys()), 1)
|
||||
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
|
||||
|
||||
new_key_2 = signedjson.key.get_verify_key(
|
||||
signedjson.key.generate_signing_key("key2")
|
||||
)
|
||||
d = store.store_server_verify_keys(
|
||||
d = store.store_server_signature_keys(
|
||||
"from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)}
|
||||
)
|
||||
self.get_success(d)
|
||||
|
||||
res = self.get_success(
|
||||
store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
|
||||
store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
|
||||
)
|
||||
self.assertEqual(len(res.keys()), 2)
|
||||
|
||||
|
|
|
@ -267,7 +267,9 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
|||
# Resync the device list.
|
||||
device_handler = self.hs.get_device_handler()
|
||||
self.get_success(
|
||||
device_handler.device_list_updater.user_device_resync(remote_user_id),
|
||||
device_handler.device_list_updater.multi_user_device_resync(
|
||||
[remote_user_id]
|
||||
),
|
||||
)
|
||||
|
||||
# Retrieve the cross-signing keys for this user.
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import gc
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
|
@ -53,6 +54,7 @@ from twisted.web.server import Request
|
|||
from synapse import events
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.config._base import Config, RootConfig
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
|
@ -67,7 +69,6 @@ from synapse.logging.context import (
|
|||
)
|
||||
from synapse.rest import RegisterServletsFunc
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.keys import FetchKeyResult
|
||||
from synapse.types import JsonDict, Requester, UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
|
@ -124,6 +125,63 @@ def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
|
|||
return _around
|
||||
|
||||
|
||||
_TConfig = TypeVar("_TConfig", Config, RootConfig)
|
||||
|
||||
|
||||
def deepcopy_config(config: _TConfig) -> _TConfig:
|
||||
new_config: _TConfig
|
||||
|
||||
if isinstance(config, RootConfig):
|
||||
new_config = config.__class__(config.config_files) # type: ignore[arg-type]
|
||||
else:
|
||||
new_config = config.__class__(config.root)
|
||||
|
||||
for attr_name in config.__dict__:
|
||||
if attr_name.startswith("__") or attr_name == "root":
|
||||
continue
|
||||
attr = getattr(config, attr_name)
|
||||
if isinstance(attr, Config):
|
||||
new_attr = deepcopy_config(attr)
|
||||
else:
|
||||
new_attr = attr
|
||||
|
||||
setattr(new_config, attr_name, new_attr)
|
||||
|
||||
return new_config
|
||||
|
||||
|
||||
_make_homeserver_config_obj_cache: Dict[str, Union[RootConfig, Config]] = {}
|
||||
|
||||
|
||||
def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig:
|
||||
"""Creates a :class:`HomeServerConfig` instance with the given configuration dict.
|
||||
|
||||
This is equivalent to::
|
||||
|
||||
config_obj = HomeServerConfig()
|
||||
config_obj.parse_config_dict(config, "", "")
|
||||
|
||||
but it keeps a cache of `HomeServerConfig` instances and deepcopies them as needed,
|
||||
to avoid validating the whole configuration every time.
|
||||
"""
|
||||
cache_key = json.dumps(config)
|
||||
|
||||
if cache_key in _make_homeserver_config_obj_cache:
|
||||
# Cache hit: reuse the existing instance
|
||||
config_obj = _make_homeserver_config_obj_cache[cache_key]
|
||||
else:
|
||||
# Cache miss; create the actual instance
|
||||
config_obj = HomeServerConfig()
|
||||
config_obj.parse_config_dict(config, "", "")
|
||||
|
||||
# Add to the cache
|
||||
_make_homeserver_config_obj_cache[cache_key] = config_obj
|
||||
|
||||
assert isinstance(config_obj, RootConfig)
|
||||
|
||||
return deepcopy_config(config_obj)
|
||||
|
||||
|
||||
class TestCase(unittest.TestCase):
|
||||
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
|
||||
attributes on both itself and its individual test methods, to override the
|
||||
|
@ -528,8 +586,7 @@ class HomeserverTestCase(TestCase):
|
|||
config = kwargs["config"]
|
||||
|
||||
# Parse the config from a config dict into a HomeServerConfig
|
||||
config_obj = HomeServerConfig()
|
||||
config_obj.parse_config_dict(config, "", "")
|
||||
config_obj = make_homeserver_config_obj(config)
|
||||
kwargs["config"] = config_obj
|
||||
|
||||
async def run_bg_updates() -> None:
|
||||
|
@ -790,15 +847,23 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
|||
verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
|
||||
|
||||
self.get_success(
|
||||
hs.get_datastores().main.store_server_verify_keys(
|
||||
hs.get_datastores().main.store_server_keys_json(
|
||||
self.OTHER_SERVER_NAME,
|
||||
verify_key_id,
|
||||
from_server=self.OTHER_SERVER_NAME,
|
||||
ts_added_ms=clock.time_msec(),
|
||||
verify_keys={
|
||||
(self.OTHER_SERVER_NAME, verify_key_id): FetchKeyResult(
|
||||
verify_key=verify_key,
|
||||
valid_until_ts=clock.time_msec() + 10000,
|
||||
),
|
||||
},
|
||||
ts_now_ms=clock.time_msec(),
|
||||
ts_expires_ms=clock.time_msec() + 10000,
|
||||
key_json_bytes=canonicaljson.encode_canonical_json(
|
||||
{
|
||||
"verify_keys": {
|
||||
verify_key_id: {
|
||||
"key": signedjson.key.encode_verify_key_base64(
|
||||
verify_key
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -131,6 +131,9 @@ def default_config(
|
|||
# the test signing key is just an arbitrary ed25519 key to keep the config
|
||||
# parser happy
|
||||
"signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
|
||||
# Disable trusted key servers, otherwise unit tests might try to actually
|
||||
# reach out to matrix.org.
|
||||
"trusted_key_servers": [],
|
||||
"event_cache_size": 1,
|
||||
"enable_registration": True,
|
||||
"enable_registration_captcha": False,
|
||||
|
|
Loading…
Reference in New Issue