diff --git a/.github/workflows/docs-pr-netlify.yaml b/.github/workflows/docs-pr-netlify.yaml
index ef7a38144e..1704b3ce93 100644
--- a/.github/workflows/docs-pr-netlify.yaml
+++ b/.github/workflows/docs-pr-netlify.yaml
@@ -14,7 +14,7 @@ jobs:
# There's a 'download artifact' action, but it hasn't been updated for the workflow_run action
# (https://github.com/actions/download-artifact/issues/60) so instead we get this mess:
- name: 📥 Download artifact
- uses: dawidd6/action-download-artifact@bd10f381a96414ce2b13a11bfa89902ba7cea07f # v2.24.3
+ uses: dawidd6/action-download-artifact@b59d8c6a6c5c6c6437954f470d963c0b20ea7415 # v2.25.0
with:
workflow: docs-pr.yaml
run_id: ${{ github.event.workflow_run.id }}
diff --git a/.github/workflows/latest_deps.yml b/.github/workflows/latest_deps.yml
index 99fc2cee08..6da7c22e4c 100644
--- a/.github/workflows/latest_deps.yml
+++ b/.github/workflows/latest_deps.yml
@@ -27,7 +27,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Install Rust
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: stable
- uses: Swatinem/rust-cache@v2
@@ -61,7 +61,7 @@ jobs:
- uses: actions/checkout@v3
- name: Install Rust
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: stable
- uses: Swatinem/rust-cache@v2
@@ -134,7 +134,7 @@ jobs:
- uses: actions/checkout@v3
- name: Install Rust
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: stable
- uses: Swatinem/rust-cache@v2
diff --git a/.github/workflows/poetry_lockfile.yaml b/.github/workflows/poetry_lockfile.yaml
new file mode 100644
index 0000000000..ae4d27f2de
--- /dev/null
+++ b/.github/workflows/poetry_lockfile.yaml
@@ -0,0 +1,24 @@
+on:
+ push:
+ branches: ["develop", "release-*"]
+ paths:
+ - poetry.lock
+ pull_request:
+ paths:
+ - poetry.lock
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ check-sdists:
+ name: "Check locked dependencies have sdists"
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
+ with:
+ python-version: '3.x'
+ - run: pip install tomli
+ - run: ./scripts-dev/check_locked_deps_have_sdists.py
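The workflow's final step runs a repository script. As a rough illustration of what such a check might do — this is a hypothetical sketch, not the contents of `scripts-dev/check_locked_deps_have_sdists.py` — one can parse `poetry.lock` with `tomli` (installed in the preceding step) and fail if any locked package lists no source distribution:

```python
# Illustrative sketch only -- not the real scripts-dev/check_locked_deps_have_sdists.py.
# Parse poetry.lock and flag any locked package whose file list contains no
# source distribution (sdists are typically .tar.gz archives).
import sys

import tomli


def main() -> None:
    with open("poetry.lock", "rb") as f:  # tomli requires binary mode
        lockfile = tomli.load(f)

    missing_sdists = [
        package["name"]
        for package in lockfile.get("package", [])
        if not any(
            entry["file"].endswith(".tar.gz")
            for entry in package.get("files", [])
        )
    ]

    if missing_sdists:
        print("Locked packages without sdists:", ", ".join(missing_sdists))
        sys.exit(1)


if __name__ == "__main__":
    main()
```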
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index e945ffe7f3..cfafeaadc9 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -112,7 +112,7 @@ jobs:
# There don't seem to be versioned releases of this action per se: for each rust
# version there is a branch which gets constantly rebased on top of master.
# We pin to a specific commit for paranoia's sake.
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: 1.58.1
components: clippy
@@ -134,7 +134,7 @@ jobs:
# There don't seem to be versioned releases of this action per se: for each rust
# version there is a branch which gets constantly rebased on top of master.
# We pin to a specific commit for paranoia's sake.
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: nightly-2022-12-01
components: clippy
@@ -154,7 +154,7 @@ jobs:
# There don't seem to be versioned releases of this action per se: for each rust
# version there is a branch which gets constantly rebased on top of master.
# We pin to a specific commit for paranoia's sake.
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: 1.58.1
components: rustfmt
@@ -221,7 +221,7 @@ jobs:
# There don't seem to be versioned releases of this action per se: for each rust
# version there is a branch which gets constantly rebased on top of master.
# We pin to a specific commit for paranoia's sake.
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: 1.58.1
- uses: Swatinem/rust-cache@v2
@@ -266,7 +266,7 @@ jobs:
# There don't seem to be versioned releases of this action per se: for each rust
# version there is a branch which gets constantly rebased on top of master.
# We pin to a specific commit for paranoia's sake.
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: 1.58.1
- uses: Swatinem/rust-cache@v2
@@ -386,7 +386,7 @@ jobs:
# There don't seem to be versioned releases of this action per se: for each rust
# version there is a branch which gets constantly rebased on top of master.
# We pin to a specific commit for paranoia's sake.
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: 1.58.1
- uses: Swatinem/rust-cache@v2
@@ -531,7 +531,7 @@ jobs:
# There don't seem to be versioned releases of this action per se: for each rust
# version there is a branch which gets constantly rebased on top of master.
# We pin to a specific commit for paranoia's sake.
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: 1.58.1
- uses: Swatinem/rust-cache@v2
@@ -562,7 +562,7 @@ jobs:
# There don't seem to be versioned releases of this action per se: for each rust
# version there is a branch which gets constantly rebased on top of master.
# We pin to a specific commit for paranoia's sake.
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: 1.58.1
- uses: Swatinem/rust-cache@v2
@@ -585,7 +585,7 @@ jobs:
# There don't seem to be versioned releases of this action per se: for each rust
# version there is a branch which gets constantly rebased on top of master.
# We pin to a specific commit for paranoia's sake.
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: nightly-2022-12-01
- uses: Swatinem/rust-cache@v2
diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml
index a59c8dac09..db514571c4 100644
--- a/.github/workflows/twisted_trunk.yml
+++ b/.github/workflows/twisted_trunk.yml
@@ -18,7 +18,7 @@ jobs:
- uses: actions/checkout@v3
- name: Install Rust
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: stable
- uses: Swatinem/rust-cache@v2
@@ -43,7 +43,7 @@ jobs:
- run: sudo apt-get -qq install xmlsec1
- name: Install Rust
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: stable
- uses: Swatinem/rust-cache@v2
@@ -82,7 +82,7 @@ jobs:
- uses: actions/checkout@v3
- name: Install Rust
- uses: dtolnay/rust-toolchain@9cd00a88a73addc8617065438eff914dd08d0955
+ uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295
with:
toolchain: stable
- uses: Swatinem/rust-cache@v2
diff --git a/CHANGES.md b/CHANGES.md
index a2cb957f16..01b81fe174 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,98 @@
+Synapse 1.78.0rc1 (2023-02-21)
+==============================
+
+Features
+--------
+
+- Implement the experimental `exact_event_match` push rule condition from [MSC3758](https://github.com/matrix-org/matrix-spec-proposals/pull/3758). ([\#14964](https://github.com/matrix-org/synapse/issues/14964))
+- Add account data to the command line [user data export tool](https://matrix-org.github.io/synapse/v1.78/usage/administration/admin_faq.html#how-can-i-export-user-data). ([\#14969](https://github.com/matrix-org/synapse/issues/14969))
+- Implement [MSC3873](https://github.com/matrix-org/matrix-spec-proposals/pull/3873) to disambiguate push rule keys with dots in them. ([\#15004](https://github.com/matrix-org/synapse/issues/15004))
+- Allow Synapse to use a specific Redis [logical database](https://redis.io/commands/select/) in worker-mode deployments. ([\#15034](https://github.com/matrix-org/synapse/issues/15034))
+- Tag opentracing spans for federation requests with the name of the worker serving the request. ([\#15042](https://github.com/matrix-org/synapse/issues/15042))
+- Experimental support for [MSC3966](https://github.com/matrix-org/matrix-spec-proposals/pull/3966): the `exact_event_property_contains` push rule condition. ([\#15045](https://github.com/matrix-org/synapse/issues/15045))
+- Remove spurious `dont_notify` action from the defaults for the `.m.rule.reaction` push rule. ([\#15073](https://github.com/matrix-org/synapse/issues/15073))
+- Update the error code returned when a user sends a duplicate annotation. ([\#15075](https://github.com/matrix-org/synapse/issues/15075))
+
+
+Bugfixes
+--------
+
+- Prevent clients from reporting nonexistent events. ([\#13779](https://github.com/matrix-org/synapse/issues/13779))
+- Return spec-compliant JSON errors when unknown endpoints are requested. ([\#14605](https://github.com/matrix-org/synapse/issues/14605))
+- Fix a long-standing bug where the room aliases returned could be corrupted. ([\#15038](https://github.com/matrix-org/synapse/issues/15038))
+- Fix a bug introduced in Synapse 1.76.0 where partially-joined rooms could not be deleted using the [purge room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api). ([\#15068](https://github.com/matrix-org/synapse/issues/15068))
+- Fix a long-standing bug where federated joins would fail if the first server in the list of servers to try is not in the room. ([\#15074](https://github.com/matrix-org/synapse/issues/15074))
+- Fix a bug introduced in Synapse v1.74.0 where searching with colons when using ICU for search term tokenisation would fail with an error. ([\#15079](https://github.com/matrix-org/synapse/issues/15079))
+- Reduce the likelihood of a rare race condition where rejoining a restricted room over federation would fail. ([\#15080](https://github.com/matrix-org/synapse/issues/15080))
+- Fix a bug introduced in Synapse 1.76 where workers would fail to start if the `health` listener was configured. ([\#15096](https://github.com/matrix-org/synapse/issues/15096))
+- Fix a bug introduced in Synapse 1.75 where the [portdb script](https://matrix-org.github.io/synapse/release-v1.78/postgres.html#porting-from-sqlite) would fail to run after a room had been faster-joined. ([\#15108](https://github.com/matrix-org/synapse/issues/15108))
+
+
+Improved Documentation
+----------------------
+
+- Document how to start Synapse with Poetry. Contributed by @thezaidbintariq. ([\#14892](https://github.com/matrix-org/synapse/issues/14892))
+- Update delegation documentation to clarify that SRV DNS delegation does not eliminate all needs to serve files from .well-known locations. Contributed by @williamkray. ([\#14959](https://github.com/matrix-org/synapse/issues/14959))
+- Document how to start Synapse in the contributing guide. ([\#15022](https://github.com/matrix-org/synapse/issues/15022))
+- Fix a mistake in the `registration_shared_secret_path` docs. ([\#15078](https://github.com/matrix-org/synapse/issues/15078))
+- Refer to a more recent blog post on the [Database Maintenance Tools](https://matrix-org.github.io/synapse/latest/usage/administration/database_maintenance_tools.html) page. Contributed by @jahway603. ([\#15083](https://github.com/matrix-org/synapse/issues/15083))
+
+
+Internal Changes
+----------------
+
+- Re-type hint some collections as read-only. ([\#13755](https://github.com/matrix-org/synapse/issues/13755))
+- Faster joins: don't stall when another user joins during a partial-state room resync. ([\#14606](https://github.com/matrix-org/synapse/issues/14606))
+- Add a class `UnpersistedEventContext` to allow for the batching up of storing state groups. ([\#14675](https://github.com/matrix-org/synapse/issues/14675))
+- Add a check to ensure that locked dependencies have source distributions available. ([\#14742](https://github.com/matrix-org/synapse/issues/14742))
+- Tweak comment on `_is_local_room_accessible` as part of room visibility in `/hierarchy` to clarify the condition for a room being visible. ([\#14834](https://github.com/matrix-org/synapse/issues/14834))
+- Prevent 'WARNING: there is already a transaction in progress' lines appearing in PostgreSQL's logs on some occasions. ([\#14840](https://github.com/matrix-org/synapse/issues/14840))
+- Use `StrCollection` to avoid potential bugs with `Collection[str]`. ([\#14929](https://github.com/matrix-org/synapse/issues/14929))
+- Improve performance of `/sync` in a few situations. ([\#14973](https://github.com/matrix-org/synapse/issues/14973))
+- Limit concurrent event creation for a room to avoid state resolution when sending bursts of events to a local room. ([\#14977](https://github.com/matrix-org/synapse/issues/14977))
+- Skip calculating unread push actions in `/sync` when `enable_push` is false. ([\#14980](https://github.com/matrix-org/synapse/issues/14980))
+- Add schema dump symlinks inside `contrib`, to make it easier for IDEs to interrogate Synapse's database schema. ([\#14982](https://github.com/matrix-org/synapse/issues/14982))
+- Improve type hints. ([\#15008](https://github.com/matrix-org/synapse/issues/15008), [\#15026](https://github.com/matrix-org/synapse/issues/15026), [\#15027](https://github.com/matrix-org/synapse/issues/15027), [\#15028](https://github.com/matrix-org/synapse/issues/15028), [\#15031](https://github.com/matrix-org/synapse/issues/15031), [\#15035](https://github.com/matrix-org/synapse/issues/15035), [\#15052](https://github.com/matrix-org/synapse/issues/15052), [\#15072](https://github.com/matrix-org/synapse/issues/15072), [\#15084](https://github.com/matrix-org/synapse/issues/15084))
+- Update [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952) support based on changes to the MSC. ([\#15037](https://github.com/matrix-org/synapse/issues/15037))
+- Avoid mutating a cached value in `get_user_devices_from_cache`. ([\#15040](https://github.com/matrix-org/synapse/issues/15040))
+- Fix a rare exception in logs on start up. ([\#15041](https://github.com/matrix-org/synapse/issues/15041))
+- Update pyo3-log to v0.8.1. ([\#15043](https://github.com/matrix-org/synapse/issues/15043))
+- Avoid mutating cached values in `_generate_sync_entry_for_account_data`. ([\#15047](https://github.com/matrix-org/synapse/issues/15047))
+- Refactor arguments of `try_unbind_threepid` and `_try_unbind_threepid_with_id_server` to not use dictionaries. ([\#15053](https://github.com/matrix-org/synapse/issues/15053))
+- Merge debug logging from the hotfixes branch. ([\#15054](https://github.com/matrix-org/synapse/issues/15054))
+- Faster joins: omit device list updates originating from partial state rooms in `/sync` responses when lazy loading of members is not enabled. ([\#15069](https://github.com/matrix-org/synapse/issues/15069))
+- Fix clashing database transaction name. ([\#15070](https://github.com/matrix-org/synapse/issues/15070))
+- Upper-bound frozendict dependency. This works around us being unable to test installing our wheels against Python 3.11 in CI. ([\#15114](https://github.com/matrix-org/synapse/issues/15114))
+- Tweak logging for when a worker waits for its view of a replication stream to catch up. ([\#15120](https://github.com/matrix-org/synapse/issues/15120))
+
+<details><summary>Locked dependency updates</summary>
+
+- Bump bleach from 5.0.1 to 6.0.0. ([\#15059](https://github.com/matrix-org/synapse/issues/15059))
+- Bump cryptography from 38.0.4 to 39.0.1. ([\#15020](https://github.com/matrix-org/synapse/issues/15020))
+- Bump ruff version from 0.0.230 to 0.0.237. ([\#15033](https://github.com/matrix-org/synapse/issues/15033))
+- Bump dtolnay/rust-toolchain from 9cd00a88a73addc8617065438eff914dd08d0955 to 25dc93b901a87e864900a8aec6c12e9aa794c0c3. ([\#15060](https://github.com/matrix-org/synapse/issues/15060))
+- Bump systemd-python from 234 to 235. ([\#15061](https://github.com/matrix-org/synapse/issues/15061))
+- Bump serde_json from 1.0.92 to 1.0.93. ([\#15062](https://github.com/matrix-org/synapse/issues/15062))
+- Bump types-requests from 2.28.11.8 to 2.28.11.12. ([\#15063](https://github.com/matrix-org/synapse/issues/15063))
+- Bump types-pillow from 9.4.0.5 to 9.4.0.10. ([\#15064](https://github.com/matrix-org/synapse/issues/15064))
+- Bump sentry-sdk from 1.13.0 to 1.15.0. ([\#15065](https://github.com/matrix-org/synapse/issues/15065))
+- Bump types-jsonschema from 4.17.0.3 to 4.17.0.5. ([\#15099](https://github.com/matrix-org/synapse/issues/15099))
+- Bump types-bleach from 5.0.3.1 to 6.0.0.0. ([\#15100](https://github.com/matrix-org/synapse/issues/15100))
+- Bump dtolnay/rust-toolchain from 25dc93b901a87e864900a8aec6c12e9aa794c0c3 to e12eda571dc9a5ee5d58eecf4738ec291c66f295. ([\#15101](https://github.com/matrix-org/synapse/issues/15101))
+- Bump dawidd6/action-download-artifact from 2.24.3 to 2.25.0. ([\#15102](https://github.com/matrix-org/synapse/issues/15102))
+- Bump types-pillow from 9.4.0.10 to 9.4.0.13. ([\#15104](https://github.com/matrix-org/synapse/issues/15104))
+- Bump types-setuptools from 67.1.0.0 to 67.3.0.1. ([\#15105](https://github.com/matrix-org/synapse/issues/15105))
+
+
+</details>
+
+
+Synapse 1.77.0 (2023-02-14)
+===========================
+
+No significant changes since 1.77.0rc2.
+
+
Synapse 1.77.0rc2 (2023-02-10)
==============================
@@ -57,7 +152,7 @@ Internal Changes
- Preparatory work for adding a denormalised event stream ordering column in the future. Contributed by Nick @ Beeper (@fizzadar). ([\#14979](https://github.com/matrix-org/synapse/issues/14979), [9cd7610](https://github.com/matrix-org/synapse/commit/9cd7610f86ab5051c9365dd38d1eec405a5f8ca6), [f10caa7](https://github.com/matrix-org/synapse/commit/f10caa73eee0caa91cf373966104d1ededae2aee); see [\#15014](https://github.com/matrix-org/synapse/issues/15014))
- Add tests for `_flatten_dict`. ([\#14981](https://github.com/matrix-org/synapse/issues/14981), [\#15002](https://github.com/matrix-org/synapse/issues/15002))
-<details><summary>Dependabot updates</summary>
+<details><summary>Locked dependency updates</summary>
- Bump dtolnay/rust-toolchain from e645b0cf01249a964ec099494d38d2da0f0b349f to 9cd00a88a73addc8617065438eff914dd08d0955. ([\#14968](https://github.com/matrix-org/synapse/issues/14968))
- Bump docker/build-push-action from 3 to 4. ([\#14952](https://github.com/matrix-org/synapse/issues/14952))
diff --git a/Cargo.lock b/Cargo.lock
index a9219eac11..1bf76cb863 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -232,9 +232,9 @@ dependencies = [
[[package]]
name = "pyo3-log"
-version = "0.7.0"
+version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e5695ccff5060c13ca1751cf8c857a12da9b0bf0378cb071c5e0326f7c7e4c1b"
+checksum = "f9c8b57fe71fb5dcf38970ebedc2b1531cf1c14b1b9b4c560a182a57e115575c"
dependencies = [
"arc-swap",
"log",
@@ -343,9 +343,9 @@ dependencies = [
[[package]]
name = "serde_json"
-version = "1.0.92"
+version = "1.0.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7434af0dc1cbd59268aa98b4c22c131c0584d2232f6fb166efb993e2832e896a"
+checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76"
dependencies = [
"itoa",
"ryu",
diff --git a/contrib/datagrip/README.md b/contrib/datagrip/README.md
new file mode 100644
index 0000000000..bbe4f3a5a3
--- /dev/null
+++ b/contrib/datagrip/README.md
@@ -0,0 +1,28 @@
+# Schema symlinks
+
+This directory contains symlinks to the latest dump of the postgres full schema. This is useful to have, as it allows IDEs to understand our schema and provide autocomplete, linters, inspections, etc.
+
+In particular, the DataGrip functionality in IntelliJ's products seems to only consider files called `*.sql` when defining a schema from DDL; `*.sql.postgres` will be ignored. To get around this we symlink those files to ones ending in `.sql`. We've chosen to ignore the `.sql.sqlite` schema dumps here, as they're not intended for production use (and are much quicker to test against).
+
+## Example
+
+![Screenshot of DataGrip with the Synapse schema dump loaded](datagrip-aware-of-schema.png)
+
+## Caveats
+
+- Doesn't include temporary tables created ad-hoc by Synapse.
+- Postgres only. IDEs will likely be confused by SQLite-specific queries.
+- Will not include migrations created after the latest schema dump.
+- Symlinks might confuse checkouts on Windows systems.
+
+## Instructions
+
+### Jetbrains IDEs with DataGrip plugin
+
+- View -> Tool Windows -> Database
+- `+` Icon -> DDL Data Source
+- Pick a name, e.g. `Synapse schema dump`
+- Under sources, click `+`.
+- Add an entry with Path pointing to this directory, and dialect set to PostgreSQL.
+- OK, and OK.
+- The IDE should now be aware of the schema.
+- Try control-clicking on a table name in a bit of SQL, e.g. in `_get_forgotten_rooms_for_user_txn`.
\ No newline at end of file
diff --git a/contrib/datagrip/common.sql b/contrib/datagrip/common.sql
new file mode 120000
index 0000000000..28c5aa8a1b
--- /dev/null
+++ b/contrib/datagrip/common.sql
@@ -0,0 +1 @@
+../../synapse/storage/schema/common/full_schemas/72/full.sql.postgres
\ No newline at end of file
diff --git a/contrib/datagrip/datagrip-aware-of-schema.png b/contrib/datagrip/datagrip-aware-of-schema.png
new file mode 100644
index 0000000000..653642da91
Binary files /dev/null and b/contrib/datagrip/datagrip-aware-of-schema.png differ
diff --git a/contrib/datagrip/main.sql b/contrib/datagrip/main.sql
new file mode 120000
index 0000000000..eec0a2fb6d
--- /dev/null
+++ b/contrib/datagrip/main.sql
@@ -0,0 +1 @@
+../../synapse/storage/schema/main/full_schemas/72/full.sql.postgres
\ No newline at end of file
diff --git a/contrib/datagrip/schema_version.sql b/contrib/datagrip/schema_version.sql
new file mode 120000
index 0000000000..e1b0985d74
--- /dev/null
+++ b/contrib/datagrip/schema_version.sql
@@ -0,0 +1 @@
+../../synapse/storage/schema/common/schema_version.sql
\ No newline at end of file
diff --git a/contrib/datagrip/state.sql b/contrib/datagrip/state.sql
new file mode 120000
index 0000000000..4de4fbbdf7
--- /dev/null
+++ b/contrib/datagrip/state.sql
@@ -0,0 +1 @@
+../../synapse/storage/schema/state/full_schemas/72/full.sql.postgres
\ No newline at end of file
diff --git a/contrib/docker_compose_workers/README.md b/contrib/docker_compose_workers/README.md
index bdd3dd32e0..d3cdfe5614 100644
--- a/contrib/docker_compose_workers/README.md
+++ b/contrib/docker_compose_workers/README.md
@@ -68,6 +68,7 @@ redis:
enabled: true
host: redis
port: 6379
+ # dbid:
# password:
```
diff --git a/debian/changelog b/debian/changelog
index 461953742b..f9e95ee5e2 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,16 @@
+matrix-synapse-py3 (1.78.0~rc1) stable; urgency=medium
+
+ * Add `matrix-org-archive-keyring` package as recommended.
+ * New Synapse release 1.78.0rc1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 21 Feb 2023 14:29:19 +0000
+
+matrix-synapse-py3 (1.77.0) stable; urgency=medium
+
+ * New Synapse release 1.77.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 14 Feb 2023 12:59:02 +0100
+
matrix-synapse-py3 (1.77.0~rc2) stable; urgency=medium
* New Synapse release 1.77.0rc2.
diff --git a/debian/control b/debian/control
index bc628cec08..2ff55db5de 100644
--- a/debian/control
+++ b/debian/control
@@ -37,6 +37,7 @@ Depends:
# so we put perl:Depends in Suggests rather than Depends.
Recommends:
${shlibs1:Recommends},
+ matrix-org-archive-keyring,
Suggests:
sqlite3,
${perl:Depends},
diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
index 7f8c8e22c1..30833f3109 100644
--- a/docs/admin_api/media_admin_api.md
+++ b/docs/admin_api/media_admin_api.md
@@ -235,6 +235,14 @@ The following fields are returned in the JSON response body:
Request:
+```
+POST /_synapse/admin/v1/media/delete?before_ts=<before_ts>
+
+{}
+```
+
+*Deprecated in Synapse v1.78.0:* This API is available at the deprecated endpoint:
+
```
POST /_synapse/admin/v1/media/<server_name>/delete?before_ts=<before_ts>
@@ -243,7 +251,7 @@ POST /_synapse/admin/v1/media/<server_name>/delete?before_ts=<before_ts>
URL Parameters
-* `server_name`: string - The name of your local server (e.g `matrix.org`).
+* `server_name`: string - The name of your local server (e.g. `matrix.org`). *Deprecated in Synapse v1.78.0.*
* `before_ts`: string representing a positive integer - Unix timestamp in milliseconds.
Files that were last used before this timestamp will be deleted. It is the timestamp of
last access, not the timestamp when the file was created.
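To make the migration concrete, here is a hedged sketch of calling the new endpoint with `requests`; the homeserver URL and admin access token are hypothetical placeholders, and `before_ts` must be a Unix timestamp in milliseconds, as described above:

```python
# Sketch of calling the new (non-deprecated) endpoint; the base URL and the
# admin token below are hypothetical placeholders.
import time

import requests

base_url = "https://homeserver.example.com"  # hypothetical homeserver URL
admin_token = "<admin access token>"  # hypothetical admin access token

resp = requests.post(
    f"{base_url}/_synapse/admin/v1/media/delete",
    params={"before_ts": int(time.time() * 1000)},  # milliseconds, as documented
    headers={"Authorization": f"Bearer {admin_token}"},
    json={},  # the request body is an empty JSON object
)
resp.raise_for_status()
print(resp.json())
```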
diff --git a/docs/delegate.md b/docs/delegate.md
index ee9cbb3b1c..aee82fcb9a 100644
--- a/docs/delegate.md
+++ b/docs/delegate.md
@@ -73,6 +73,15 @@ It is also possible to do delegation using a SRV DNS record. However, that is ge
not recommended, as it can be difficult to configure the TLS certificates correctly in
this case, and it offers little advantage over `.well-known` delegation.
+Please keep in mind that server delegation is a function of server-server communication,
+and as such using SRV DNS records will not cover use cases involving client-server comms.
+This means setting global client settings (such as a Jitsi endpoint, or disabling
+creating new rooms as encrypted by default, etc) will still require that you serve a file
+from the `https://<server_name>/.well-known/` endpoints defined in the spec! If you are
+considering using SRV DNS delegation to avoid serving files from this endpoint, consider
+the impact that you will not be able to change those client-based default values globally,
+and will be relegated to the featureset of the configuration of each individual client.
+
However, if you really need it, you can find some documentation on what such a
record should look like and how Synapse will use it in [the Matrix
specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names).
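The distinction drawn above can be demonstrated with a short script. This is a sketch assuming `dnspython` and `requests` are available, with `example.com` standing in for your server name: the SRV record only redirects federation traffic, while clients still fetch the `.well-known` file over HTTPS.

```python
# Contrast the two delegation mechanisms: SRV covers server-server traffic
# only; clients rely on /.well-known/matrix/client served over HTTPS.
import dns.resolver
import requests

server_name = "example.com"  # hypothetical server name

# Server-server: federating servers may consult the SRV record.
for answer in dns.resolver.resolve(f"_matrix._tcp.{server_name}", "SRV"):
    print("federation delegated to:", answer.target, "port", answer.port)

# Client-server: clients only consult the well-known file, so it must be
# served even when SRV delegation is in place.
resp = requests.get(f"https://{server_name}/.well-known/matrix/client", timeout=10)
print("client well-known:", resp.json())
```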
diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md
index 36bc884684..925dcd8933 100644
--- a/docs/development/contributing_guide.md
+++ b/docs/development/contributing_guide.md
@@ -78,6 +78,19 @@ poetry install --extras all
This will install the runtime and developer dependencies for the project.
+## Running Synapse via poetry
+
+To start a local instance of Synapse in the locked poetry environment, create a config file:
+
+```sh
+cp docs/sample_config.yaml homeserver.yaml
+```
+
+Now edit `homeserver.yaml`, and run Synapse with:
+
+```sh
+poetry run python -m synapse.app.homeserver -c homeserver.yaml
+```
# 5. Get in touch.
diff --git a/docs/upgrade.md b/docs/upgrade.md
index bc143444be..15167b8c58 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -88,6 +88,15 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```
+# Upgrading to v1.78.0
+
+## Deprecate the `/_synapse/admin/v1/media/<server_name>/delete` admin API
+
+Synapse 1.78.0 replaces the `/_synapse/admin/v1/media/<server_name>/delete`
+admin API with an identical endpoint at `/_synapse/admin/v1/media/delete`. Please
+update your tooling to use the new endpoint. The deprecated version will be removed
+in a future release.
+
# Upgrading to v1.76.0
## Faster joins are enabled by default
@@ -137,6 +146,7 @@ and then do `pip install matrix-synapse[user-search]` for a PyPI install.
Docker images and Debian packages need nothing specific as they already
include or specify ICU as an explicit dependency.
+
# Upgrading to v1.73.0
## Legacy Prometheus metric names have now been removed
diff --git a/docs/usage/administration/admin_faq.md b/docs/usage/administration/admin_faq.md
index 7a27741199..925e1d175e 100644
--- a/docs/usage/administration/admin_faq.md
+++ b/docs/usage/administration/admin_faq.md
@@ -71,6 +71,9 @@ output-directory
│ ├───invite_state
│ └───knock_state
└───user_data
+ ├───account_data
+ │ ├───global
+ │ └───<room_id>
├───connections
├───devices
└───profile
diff --git a/docs/usage/administration/database_maintenance_tools.md b/docs/usage/administration/database_maintenance_tools.md
index 92b805d413..e19380db07 100644
--- a/docs/usage/administration/database_maintenance_tools.md
+++ b/docs/usage/administration/database_maintenance_tools.md
@@ -1,4 +1,4 @@
-This blog post by Victor Berger explains how to use many of the tools listed on this page: https://levans.fr/shrink-synapse-database.html
+_This [blog post by Jackson Chen](https://jacksonchen666.com/posts/2022-12-03/14-33-00/) (Dec 2022) explains how to use many of the tools listed on this page. There is also an [earlier blog by Victor Berger](https://levans.fr/shrink-synapse-database.html) (June 2020), though this may be outdated in places._
# List of useful tools and scripts for maintenance Synapse database:
@@ -15,4 +15,4 @@ The purge history API allows server admins to purge historic events from their d
Tool for compressing (deduplicating) `state_groups_state` table.
## [SQL for analyzing Synapse PostgreSQL database stats](useful_sql_for_admins.md)
-Some easy SQL that reports useful stats about your Synapse database.
\ No newline at end of file
+Some easy SQL that reports useful stats about your Synapse database.
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 2883f76a26..58c6955689 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -2232,7 +2232,7 @@ key on startup and store it in this file.
Example configuration:
```yaml
-registration_shared_secret_file: /path/to/secrets/file
+registration_shared_secret_path: /path/to/secrets/file
```
_Added in Synapse 1.67.0._
@@ -3927,6 +3927,9 @@ This setting has the following sub-options:
* `host` and `port`: Optional host and port to use to connect to redis. Defaults to
localhost and 6379
* `password`: Optional password if configured on the Redis instance.
+* `dbid`: Optional Redis dbid to use, if you need to connect to a specific Redis logical database.
+
+ _Added in Synapse 1.78.0._
Example configuration:
```yaml
@@ -3935,6 +3938,7 @@ redis:
host: localhost
port: 6379
password:
+ dbid:
```
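As a quick sanity check before pointing workers at a non-default logical database, something like the following sketch can confirm the database is reachable. It assumes `redis-py` is installed; its `db` argument plays the role of `dbid` and issues Redis's `SELECT` under the hood.

```python
# Hypothetical pre-flight check: connect to the logical database you intend
# to set as `dbid` and confirm it responds, and how many keys it holds.
import redis

r = redis.Redis(host="localhost", port=6379, db=2)  # db=2 corresponds to `dbid: 2`
r.ping()  # raises redis.ConnectionError if unreachable
print("keys in logical db 2:", r.dbsize())
```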
---
## Individual worker configuration
diff --git a/docs/workers.md b/docs/workers.md
index bc66f0e1bc..2eb970ffa6 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -160,7 +160,18 @@ recommend the use of `systemd` where available: for information on setting up
[Systemd with Workers](systemd-with-workers/). To use `synctl`, see
[Using synctl with Workers](synctl_workers.md).
+## Start Synapse with Poetry
+
+The following applies to Synapse installations installed from source using `poetry`.
+
+You can start the main Synapse process with Poetry by running the following command:
+```console
+poetry run synapse_homeserver -c [your homeserver.yaml]
+```
+For worker setups, you can run the following command:
+```console
+poetry run synapse_worker -c [your worker.yaml]
+```
## Available worker applications
### `synapse.app.generic_worker`
diff --git a/mypy.ini b/mypy.ini
index 0efafb26b6..94562d0bce 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -31,10 +31,6 @@ exclude = (?x)
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/schema/
-
- |tests/module_api/test_api.py
- |tests/rest/media/v1/test_media_storage.py
- |tests/server.py
)$
[mypy-synapse.federation.transport.client]
@@ -55,87 +51,12 @@ warn_unused_ignores = False
[mypy-synapse.util.caches.treecache]
disallow_untyped_defs = False
-[mypy-synapse.server]
-disallow_untyped_defs = False
-
[mypy-synapse.storage.database]
disallow_untyped_defs = False
-[mypy-tests.*]
-disallow_untyped_defs = False
-
-[mypy-tests.api.*]
-disallow_untyped_defs = True
-
-[mypy-tests.app.*]
-disallow_untyped_defs = True
-
-[mypy-tests.appservice.*]
-disallow_untyped_defs = True
-
-[mypy-tests.config.*]
-disallow_untyped_defs = True
-
-[mypy-tests.crypto.*]
-disallow_untyped_defs = True
-
-[mypy-tests.events.*]
-disallow_untyped_defs = True
-
-[mypy-tests.federation.*]
-disallow_untyped_defs = True
-
-[mypy-tests.handlers.*]
-disallow_untyped_defs = True
-
-[mypy-tests.http.*]
-disallow_untyped_defs = True
-
-[mypy-tests.logging.*]
-disallow_untyped_defs = True
-
-[mypy-tests.metrics.*]
-disallow_untyped_defs = True
-
-[mypy-tests.push.*]
-disallow_untyped_defs = True
-
-[mypy-tests.replication.*]
-disallow_untyped_defs = True
-
-[mypy-tests.rest.*]
-disallow_untyped_defs = True
-
-[mypy-tests.state.test_profile]
-disallow_untyped_defs = True
-
-[mypy-tests.storage.*]
-disallow_untyped_defs = True
-
-[mypy-tests.test_server]
-disallow_untyped_defs = True
-
-[mypy-tests.test_state]
-disallow_untyped_defs = True
-
-[mypy-tests.test_terms_auth]
-disallow_untyped_defs = True
-
-[mypy-tests.types.*]
-disallow_untyped_defs = True
-
-[mypy-tests.util.caches.*]
-disallow_untyped_defs = True
-
[mypy-tests.util.caches.test_descriptors]
disallow_untyped_defs = False
-[mypy-tests.util.*]
-disallow_untyped_defs = True
-
-[mypy-tests.utils]
-disallow_untyped_defs = True
-
;; Dependencies without annotations
;; Before ignoring a module, check to see if type stubs are available.
;; The `typeshed` project maintains stubs here:
diff --git a/poetry.lock b/poetry.lock
index 71095c21ed..4d724ab782 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -127,14 +127,14 @@ uvloop = ["uvloop (>=0.15.2)"]
[[package]]
name = "bleach"
-version = "5.0.1"
+version = "6.0.0"
description = "An easy safelist-based HTML-sanitizing tool."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
- {file = "bleach-5.0.1-py3-none-any.whl", hash = "sha256:085f7f33c15bd408dd9b17a4ad77c577db66d76203e5984b1bd59baeee948b2a"},
- {file = "bleach-5.0.1.tar.gz", hash = "sha256:0d03255c47eb9bd2f26aa9bb7f2107732e7e8fe195ca2f64709fcf3b0a4a085c"},
+ {file = "bleach-6.0.0-py3-none-any.whl", hash = "sha256:33c16e3353dbd13028ab4799a0f89a83f113405c766e9c122df8a06f5b85b3f4"},
+ {file = "bleach-6.0.0.tar.gz", hash = "sha256:1a1a85c1595e07d8db14c5f09f09e6433502c51c595970edc090551f0db99414"},
]
[package.dependencies]
@@ -143,18 +143,17 @@ webencodings = "*"
[package.extras]
css = ["tinycss2 (>=1.1.0,<1.2)"]
-dev = ["Sphinx (==4.3.2)", "black (==22.3.0)", "build (==0.8.0)", "flake8 (==4.0.1)", "hashin (==0.17.0)", "mypy (==0.961)", "pip-tools (==6.6.2)", "pytest (==7.1.2)", "tox (==3.25.0)", "twine (==4.0.1)", "wheel (==0.37.1)"]
[[package]]
name = "canonicaljson"
-version = "1.6.4"
+version = "1.6.5"
description = "Canonical JSON"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
- {file = "canonicaljson-1.6.4-py3-none-any.whl", hash = "sha256:55d282853b4245dbcd953fe54c39b91571813d7c44e1dbf66e3c4f97ff134a48"},
- {file = "canonicaljson-1.6.4.tar.gz", hash = "sha256:6c09b2119511f30eb1126cfcd973a10824e20f1cfd25039cde3d1218dd9c8d8f"},
+ {file = "canonicaljson-1.6.5-py3-none-any.whl", hash = "sha256:806ea6f2cbb7405d20259e1c36dd1214ba5c242fa9165f5bd0bf2081f82c23fb"},
+ {file = "canonicaljson-1.6.5.tar.gz", hash = "sha256:68dfc157b011e07d94bf74b5d4ccc01958584ed942d9dfd5fdd706609e81cd4b"},
]
[package.dependencies]
@@ -339,50 +338,49 @@ files = [
[[package]]
name = "cryptography"
-version = "38.0.4"
+version = "39.0.1"
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
category = "main"
optional = false
python-versions = ">=3.6"
files = [
- {file = "cryptography-38.0.4-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:2fa36a7b2cc0998a3a4d5af26ccb6273f3df133d61da2ba13b3286261e7efb70"},
- {file = "cryptography-38.0.4-cp36-abi3-macosx_10_10_x86_64.whl", hash = "sha256:1f13ddda26a04c06eb57119caf27a524ccae20533729f4b1e4a69b54e07035eb"},
- {file = "cryptography-38.0.4-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:2ec2a8714dd005949d4019195d72abed84198d877112abb5a27740e217e0ea8d"},
- {file = "cryptography-38.0.4-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50a1494ed0c3f5b4d07650a68cd6ca62efe8b596ce743a5c94403e6f11bf06c1"},
- {file = "cryptography-38.0.4-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a10498349d4c8eab7357a8f9aa3463791292845b79597ad1b98a543686fb1ec8"},
- {file = "cryptography-38.0.4-cp36-abi3-manylinux_2_24_x86_64.whl", hash = "sha256:10652dd7282de17990b88679cb82f832752c4e8237f0c714be518044269415db"},
- {file = "cryptography-38.0.4-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:bfe6472507986613dc6cc00b3d492b2f7564b02b3b3682d25ca7f40fa3fd321b"},
- {file = "cryptography-38.0.4-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:ce127dd0a6a0811c251a6cddd014d292728484e530d80e872ad9806cfb1c5b3c"},
- {file = "cryptography-38.0.4-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:53049f3379ef05182864d13bb9686657659407148f901f3f1eee57a733fb4b00"},
- {file = "cryptography-38.0.4-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:8a4b2bdb68a447fadebfd7d24855758fe2d6fecc7fed0b78d190b1af39a8e3b0"},
- {file = "cryptography-38.0.4-cp36-abi3-win32.whl", hash = "sha256:1d7e632804a248103b60b16fb145e8df0bc60eed790ece0d12efe8cd3f3e7744"},
- {file = "cryptography-38.0.4-cp36-abi3-win_amd64.whl", hash = "sha256:8e45653fb97eb2f20b8c96f9cd2b3a0654d742b47d638cf2897afbd97f80fa6d"},
- {file = "cryptography-38.0.4-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca57eb3ddaccd1112c18fc80abe41db443cc2e9dcb1917078e02dfa010a4f353"},
- {file = "cryptography-38.0.4-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:c9e0d79ee4c56d841bd4ac6e7697c8ff3c8d6da67379057f29e66acffcd1e9a7"},
- {file = "cryptography-38.0.4-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:0e70da4bdff7601b0ef48e6348339e490ebfb0cbe638e083c9c41fb49f00c8bd"},
- {file = "cryptography-38.0.4-pp38-pypy38_pp73-macosx_10_10_x86_64.whl", hash = "sha256:998cd19189d8a747b226d24c0207fdaa1e6658a1d3f2494541cb9dfbf7dcb6d2"},
- {file = "cryptography-38.0.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67461b5ebca2e4c2ab991733f8ab637a7265bb582f07c7c88914b5afb88cb95b"},
- {file = "cryptography-38.0.4-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:4eb85075437f0b1fd8cd66c688469a0c4119e0ba855e3fef86691971b887caf6"},
- {file = "cryptography-38.0.4-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3178d46f363d4549b9a76264f41c6948752183b3f587666aff0555ac50fd7876"},
- {file = "cryptography-38.0.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:6391e59ebe7c62d9902c24a4d8bcbc79a68e7c4ab65863536127c8a9cd94043b"},
- {file = "cryptography-38.0.4-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:78e47e28ddc4ace41dd38c42e6feecfdadf9c3be2af389abbfeef1ff06822285"},
- {file = "cryptography-38.0.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fb481682873035600b5502f0015b664abc26466153fab5c6bc92c1ea69d478b"},
- {file = "cryptography-38.0.4-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:4367da5705922cf7070462e964f66e4ac24162e22ab0a2e9d31f1b270dd78083"},
- {file = "cryptography-38.0.4-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:b4cad0cea995af760f82820ab4ca54e5471fc782f70a007f31531957f43e9dee"},
- {file = "cryptography-38.0.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:80ca53981ceeb3241998443c4964a387771588c4e4a5d92735a493af868294f9"},
- {file = "cryptography-38.0.4.tar.gz", hash = "sha256:175c1a818b87c9ac80bb7377f5520b7f31b3ef2a0004e2420319beadedb67290"},
+ {file = "cryptography-39.0.1-cp36-abi3-macosx_10_12_universal2.whl", hash = "sha256:6687ef6d0a6497e2b58e7c5b852b53f62142cfa7cd1555795758934da363a965"},
+ {file = "cryptography-39.0.1-cp36-abi3-macosx_10_12_x86_64.whl", hash = "sha256:706843b48f9a3f9b9911979761c91541e3d90db1ca905fd63fee540a217698bc"},
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:5d2d8b87a490bfcd407ed9d49093793d0f75198a35e6eb1a923ce1ee86c62b41"},
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83e17b26de248c33f3acffb922748151d71827d6021d98c70e6c1a25ddd78505"},
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e124352fd3db36a9d4a21c1aa27fd5d051e621845cb87fb851c08f4f75ce8be6"},
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_24_x86_64.whl", hash = "sha256:5aa67414fcdfa22cf052e640cb5ddc461924a045cacf325cd164e65312d99502"},
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:35f7c7d015d474f4011e859e93e789c87d21f6f4880ebdc29896a60403328f1f"},
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f24077a3b5298a5a06a8e0536e3ea9ec60e4c7ac486755e5fb6e6ea9b3500106"},
+ {file = "cryptography-39.0.1-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:f0c64d1bd842ca2633e74a1a28033d139368ad959872533b1bab8c80e8240a0c"},
+ {file = "cryptography-39.0.1-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:0f8da300b5c8af9f98111ffd512910bc792b4c77392a9523624680f7956a99d4"},
+ {file = "cryptography-39.0.1-cp36-abi3-win32.whl", hash = "sha256:fe913f20024eb2cb2f323e42a64bdf2911bb9738a15dba7d3cce48151034e3a8"},
+ {file = "cryptography-39.0.1-cp36-abi3-win_amd64.whl", hash = "sha256:ced4e447ae29ca194449a3f1ce132ded8fcab06971ef5f618605aacaa612beac"},
+ {file = "cryptography-39.0.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:807ce09d4434881ca3a7594733669bd834f5b2c6d5c7e36f8c00f691887042ad"},
+ {file = "cryptography-39.0.1-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c5caeb8188c24888c90b5108a441c106f7faa4c4c075a2bcae438c6e8ca73cef"},
+ {file = "cryptography-39.0.1-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:4789d1e3e257965e960232345002262ede4d094d1a19f4d3b52e48d4d8f3b885"},
+ {file = "cryptography-39.0.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:96f1157a7c08b5b189b16b47bc9db2332269d6680a196341bf30046330d15388"},
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e422abdec8b5fa8462aa016786680720d78bdce7a30c652b7fadf83a4ba35336"},
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:b0afd054cd42f3d213bf82c629efb1ee5f22eba35bf0eec88ea9ea7304f511a2"},
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:6f8ba7f0328b79f08bdacc3e4e66fb4d7aab0c3584e0bd41328dce5262e26b2e"},
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:ef8b72fa70b348724ff1218267e7f7375b8de4e8194d1636ee60510aae104cd0"},
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:aec5a6c9864be7df2240c382740fcf3b96928c46604eaa7f3091f58b878c0bb6"},
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdd188c8a6ef8769f148f88f859884507b954cc64db6b52f66ef199bb9ad660a"},
+ {file = "cryptography-39.0.1.tar.gz", hash = "sha256:d1f6198ee6d9148405e49887803907fe8962a23e6c6f83ea7d98f1c0de375695"},
]
[package.dependencies]
cffi = ">=1.12"
[package.extras]
-docs = ["sphinx (>=1.6.5,!=1.8.0,!=3.1.0,!=3.1.1)", "sphinx-rtd-theme"]
+docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"]
docstest = ["pyenchant (>=1.6.11)", "sphinxcontrib-spelling (>=4.0.1)", "twine (>=1.12.0)"]
-pep8test = ["black", "flake8", "flake8-import-order", "pep8-naming"]
+pep8test = ["black", "check-manifest", "mypy", "ruff", "types-pytz", "types-requests"]
sdist = ["setuptools-rust (>=0.11.4)"]
ssh = ["bcrypt (>=3.1.5)"]
-test = ["hypothesis (>=1.11.4,!=3.79.2)", "iso8601", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-subtests", "pytest-xdist", "pytz"]
+test = ["hypothesis (>=1.11.4,!=3.79.2)", "iso8601", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-shard (>=0.1.2)", "pytest-subtests", "pytest-xdist", "pytz"]
+test-randomorder = ["pytest-randomly"]
+tox = ["tox"]
[[package]]
name = "defusedxml"
@@ -1148,36 +1146,38 @@ files = [
[[package]]
name = "mypy"
-version = "0.981"
+version = "1.0.0"
description = "Optional static typing for Python"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
- {file = "mypy-0.981-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4bc460e43b7785f78862dab78674e62ec3cd523485baecfdf81a555ed29ecfa0"},
- {file = "mypy-0.981-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:756fad8b263b3ba39e4e204ee53042671b660c36c9017412b43af210ddee7b08"},
- {file = "mypy-0.981-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16a0145d6d7d00fbede2da3a3096dcc9ecea091adfa8da48fa6a7b75d35562d"},
- {file = "mypy-0.981-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce65f70b14a21fdac84c294cde75e6dbdabbcff22975335e20827b3b94bdbf49"},
- {file = "mypy-0.981-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e35d764784b42c3e256848fb8ed1d4292c9fc0098413adb28d84974c095b279"},
- {file = "mypy-0.981-cp310-cp310-win_amd64.whl", hash = "sha256:e53773073c864d5f5cec7f3fc72fbbcef65410cde8cc18d4f7242dea60dac52e"},
- {file = "mypy-0.981-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6ee196b1d10b8b215e835f438e06965d7a480f6fe016eddbc285f13955cca659"},
- {file = "mypy-0.981-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ad21d4c9d3673726cf986ea1d0c9fb66905258709550ddf7944c8f885f208be"},
- {file = "mypy-0.981-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1debb09043e1f5ee845fa1e96d180e89115b30e47c5d3ce53bc967bab53f62d"},
- {file = "mypy-0.981-cp37-cp37m-win_amd64.whl", hash = "sha256:9f362470a3480165c4c6151786b5379351b790d56952005be18bdbdd4c7ce0ae"},
- {file = "mypy-0.981-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c9e0efb95ed6ca1654951bd5ec2f3fa91b295d78bf6527e026529d4aaa1e0c30"},
- {file = "mypy-0.981-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e178eaffc3c5cd211a87965c8c0df6da91ed7d258b5fc72b8e047c3771317ddb"},
- {file = "mypy-0.981-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:06e1eac8d99bd404ed8dd34ca29673c4346e76dd8e612ea507763dccd7e13c7a"},
- {file = "mypy-0.981-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa38f82f53e1e7beb45557ff167c177802ba7b387ad017eab1663d567017c8ee"},
- {file = "mypy-0.981-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:64e1f6af81c003f85f0dfed52db632817dabb51b65c0318ffbf5ff51995bbb08"},
- {file = "mypy-0.981-cp38-cp38-win_amd64.whl", hash = "sha256:e1acf62a8c4f7c092462c738aa2c2489e275ed386320c10b2e9bff31f6f7e8d6"},
- {file = "mypy-0.981-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b6ede64e52257931315826fdbfc6ea878d89a965580d1a65638ef77cb551f56d"},
- {file = "mypy-0.981-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eb3978b191b9fa0488524bb4ffedf2c573340e8c2b4206fc191d44c7093abfb7"},
- {file = "mypy-0.981-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f8fcf7b4b3cc0c74fb33ae54a4cd00bb854d65645c48beccf65fa10b17882c"},
- {file = "mypy-0.981-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64d2ce043a209a297df322eb4054dfbaa9de9e8738291706eaafda81ab2b362"},
- {file = "mypy-0.981-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2ee3dbc53d4df7e6e3b1c68ac6a971d3a4fb2852bf10a05fda228721dd44fae1"},
- {file = "mypy-0.981-cp39-cp39-win_amd64.whl", hash = "sha256:8e8e49aa9cc23aa4c926dc200ce32959d3501c4905147a66ce032f05cb5ecb92"},
- {file = "mypy-0.981-py3-none-any.whl", hash = "sha256:794f385653e2b749387a42afb1e14c2135e18daeb027e0d97162e4b7031210f8"},
- {file = "mypy-0.981.tar.gz", hash = "sha256:ad77c13037d3402fbeffda07d51e3f228ba078d1c7096a73759c9419ea031bf4"},
+ {file = "mypy-1.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0626db16705ab9f7fa6c249c017c887baf20738ce7f9129da162bb3075fc1af"},
+ {file = "mypy-1.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1ace23f6bb4aec4604b86c4843276e8fa548d667dbbd0cb83a3ae14b18b2db6c"},
+ {file = "mypy-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87edfaf344c9401942883fad030909116aa77b0fa7e6e8e1c5407e14549afe9a"},
+ {file = "mypy-1.0.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0ab090d9240d6b4e99e1fa998c2d0aa5b29fc0fb06bd30e7ad6183c95fa07593"},
+ {file = "mypy-1.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:7cc2c01dfc5a3cbddfa6c13f530ef3b95292f926329929001d45e124342cd6b7"},
+ {file = "mypy-1.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14d776869a3e6c89c17eb943100f7868f677703c8a4e00b3803918f86aafbc52"},
+ {file = "mypy-1.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bb2782a036d9eb6b5a6efcdda0986774bf798beef86a62da86cb73e2a10b423d"},
+ {file = "mypy-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cfca124f0ac6707747544c127880893ad72a656e136adc935c8600740b21ff5"},
+ {file = "mypy-1.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8845125d0b7c57838a10fd8925b0f5f709d0e08568ce587cc862aacce453e3dd"},
+ {file = "mypy-1.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b1b9e1ed40544ef486fa8ac022232ccc57109f379611633ede8e71630d07d2"},
+ {file = "mypy-1.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c7cf862aef988b5fbaa17764ad1d21b4831436701c7d2b653156a9497d92c83c"},
+ {file = "mypy-1.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cd187d92b6939617f1168a4fe68f68add749902c010e66fe574c165c742ed88"},
+ {file = "mypy-1.0.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4e5175026618c178dfba6188228b845b64131034ab3ba52acaffa8f6c361f805"},
+ {file = "mypy-1.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2f6ac8c87e046dc18c7d1d7f6653a66787a4555085b056fe2d599f1f1a2a2d21"},
+ {file = "mypy-1.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7306edca1c6f1b5fa0bc9aa645e6ac8393014fa82d0fa180d0ebc990ebe15964"},
+ {file = "mypy-1.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3cfad08f16a9c6611e6143485a93de0e1e13f48cfb90bcad7d5fde1c0cec3d36"},
+ {file = "mypy-1.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67cced7f15654710386e5c10b96608f1ee3d5c94ca1da5a2aad5889793a824c1"},
+ {file = "mypy-1.0.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a86b794e8a56ada65c573183756eac8ac5b8d3d59daf9d5ebd72ecdbb7867a43"},
+ {file = "mypy-1.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:50979d5efff8d4135d9db293c6cb2c42260e70fb010cbc697b1311a4d7a39ddb"},
+ {file = "mypy-1.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3ae4c7a99e5153496243146a3baf33b9beff714464ca386b5f62daad601d87af"},
+ {file = "mypy-1.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e398652d005a198a7f3c132426b33c6b85d98aa7dc852137a2a3be8890c4072"},
+ {file = "mypy-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be78077064d016bc1b639c2cbcc5be945b47b4261a4f4b7d8923f6c69c5c9457"},
+ {file = "mypy-1.0.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92024447a339400ea00ac228369cd242e988dd775640755fa4ac0c126e49bb74"},
+ {file = "mypy-1.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:fe523fcbd52c05040c7bee370d66fee8373c5972171e4fbc323153433198592d"},
+ {file = "mypy-1.0.0-py3-none-any.whl", hash = "sha256:2efa963bdddb27cb4a0d42545cd137a8d2b883bd181bbc4525b568ef6eca258f"},
+ {file = "mypy-1.0.0.tar.gz", hash = "sha256:f34495079c8d9da05b183f9f7daec2878280c2ad7cc81da686ef0b484cea2ecf"},
]
[package.dependencies]
@@ -1188,6 +1188,7 @@ typing-extensions = ">=3.10"
[package.extras]
dmypy = ["psutil (>=4.0)"]
+install-types = ["pip"]
python2 = ["typed-ast (>=1.4.0,<2)"]
reports = ["lxml"]
@@ -1205,18 +1206,18 @@ files = [
[[package]]
name = "mypy-zope"
-version = "0.3.11"
+version = "0.9.0"
description = "Plugin for mypy to support zope interfaces"
category = "dev"
optional = false
python-versions = "*"
files = [
- {file = "mypy-zope-0.3.11.tar.gz", hash = "sha256:d4255f9f04d48c79083bbd4e2fea06513a6ac7b8de06f8c4ce563fd85142ca05"},
- {file = "mypy_zope-0.3.11-py3-none-any.whl", hash = "sha256:ec080a6508d1f7805c8d2054f9fdd13c849742ce96803519e1fdfa3d3cab7140"},
+ {file = "mypy-zope-0.9.0.tar.gz", hash = "sha256:88bf6cd056e38b338e6956055958a7805b4ff84404ccd99e29883a3647a1aeb3"},
+ {file = "mypy_zope-0.9.0-py3-none-any.whl", hash = "sha256:e1bb4b57084f76ff8a154a3e07880a1af2ac6536c491dad4b143d529f72c5d15"},
]
[package.dependencies]
-mypy = "0.981"
+mypy = "1.0.0"
"zope.interface" = "*"
"zope.schema" = "*"
@@ -1970,28 +1971,28 @@ jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
[[package]]
name = "ruff"
-version = "0.0.230"
+version = "0.0.237"
description = "An extremely fast Python linter, written in Rust."
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
- {file = "ruff-0.0.230-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:fcc31d02cebda0a85e2e13a44642aea7f84362cb4f589e2f6b864e3928e4a7db"},
- {file = "ruff-0.0.230-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:45a7f2c7155d520b8ca255a01235763d5c25fd5e7af055e50a78c6d91ece0ced"},
- {file = "ruff-0.0.230-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4eca8b185ab56cac67acc23287c3c8c62a0c0ffadc0787a3bef3a6e77eaed82f"},
- {file = "ruff-0.0.230-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec2bcdb5040efd8082a3a98369eec4bdc5fd05f53cc6714cb2b725d557d4abe8"},
- {file = "ruff-0.0.230-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:26571aee2b93b60e47e44478f72a9787b387f752e85b85f176739bd91b27cfd1"},
- {file = "ruff-0.0.230-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4b69c9883c3e264f8bb2d52bdabb88b8d9672750ea05f33e0ff52532824bd5c5"},
- {file = "ruff-0.0.230-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2b3dc88b83f200378a9b9c91036989f0285a10759514c42235ce02e5824ac8d0"},
- {file = "ruff-0.0.230-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:767716f008dd3a40ec2318396f648fda437c6968087a4526cde5879e382cf477"},
- {file = "ruff-0.0.230-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac27a0f9b96d9923cef7d911790a21a19b51aec0f08375ccc47ad735b1054d78"},
- {file = "ruff-0.0.230-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:729dfc7b7ad4f7d8761dc60c58f15372d6f5c2dd9b6c5952524f2bc3aec7de6a"},
- {file = "ruff-0.0.230-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ad086cf2e5fef274687121f673f0f9b60c8981ec07c2bb0448c459cbaef81bcb"},
- {file = "ruff-0.0.230-py3-none-musllinux_1_2_i686.whl", hash = "sha256:4feaed0978c24687133cd11c7380de20aa841f893e24430c735cc6c3faba4837"},
- {file = "ruff-0.0.230-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1d1046d0d43a0f24b2e9e61d76bb201b486ad02e9787d3432af43bd7d16f2c2e"},
- {file = "ruff-0.0.230-py3-none-win32.whl", hash = "sha256:4d627911c9ba57bcd2f2776f1c09a10d334db163cb5be8c892e7ec7b59ccf58c"},
- {file = "ruff-0.0.230-py3-none-win_amd64.whl", hash = "sha256:27fd4891a1d0642f5b2038ebf86f8169bc3d466964bdfaa0ce2a65149bc7cced"},
- {file = "ruff-0.0.230.tar.gz", hash = "sha256:a049f93af1057ac450e8c09559d44e371eda1c151b1b863c0013a1066fefddb0"},
+ {file = "ruff-0.0.237-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:2ea04d826ffca58a7ae926115a801960c757d53c9027f2ca9acbe84c9f2b2f04"},
+ {file = "ruff-0.0.237-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:8ed113937fab9f73f8c1a6c0350bb4fe03e951370139c6e0adb81f48a8dcf4c6"},
+ {file = "ruff-0.0.237-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9bcb71a3efb5fe886eb48d739cfae5df4a15617e7b5a7668aa45ebf74c0d3fa"},
+ {file = "ruff-0.0.237-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:80ce10718abbf502818c0d650ebab99fdcef5e937a1ded3884493ddff804373c"},
+ {file = "ruff-0.0.237-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0cc6cb7c1efcc260df5a939435649610a28f9f438b8b313384c8985ac6574f9f"},
+ {file = "ruff-0.0.237-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:7eef0c7a1e45a4e30328ae101613575944cbf47a3a11494bf9827722da6c66b3"},
+ {file = "ruff-0.0.237-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0d122433a21ce4a21fbba34b73fc3add0ccddd1643b3ff5abb8d2767952f872e"},
+ {file = "ruff-0.0.237-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b76311335adda4de3c1d471e64e89a49abfeebf02647e3db064e7740e7f36ed6"},
+ {file = "ruff-0.0.237-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46c5977b643aaf2b6f84641265f835b6c7f67fcca38dbae08c4f15602e084ca0"},
+ {file = "ruff-0.0.237-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3d6ed86d0d4d742360a262d52191581f12b669a68e59ae3b52e80d7483b3d7b3"},
+ {file = "ruff-0.0.237-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fedfb60f986c26cdb1809db02866e68508db99910c587d2c4066a5c07aa85593"},
+ {file = "ruff-0.0.237-py3-none-musllinux_1_2_i686.whl", hash = "sha256:bb96796be5919871fa9ae7e88968ba9e14306d9a3f217ca6c204f68a5abeccdd"},
+ {file = "ruff-0.0.237-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:ea239cfedf67b74ea4952e1074bb99a4281c2145441d70bc7e2f058d5c49f1c9"},
+ {file = "ruff-0.0.237-py3-none-win32.whl", hash = "sha256:8d6a1d21ae15da2b1dcffeee2606e90de0e6717e72957da7d16ab6ae18dd0058"},
+ {file = "ruff-0.0.237-py3-none-win_amd64.whl", hash = "sha256:525e5ec81cee29b993f77976026a6bf44528a14aa6edb1ef47bd8079147395ae"},
+ {file = "ruff-0.0.237.tar.gz", hash = "sha256:630c575f543733adf6c19a11d9a02ca9ecc364bd7140af8a4c854d4728be6b56"},
]
[[package]]
@@ -2028,14 +2029,14 @@ doc = ["Sphinx", "sphinx-rtd-theme"]
[[package]]
name = "sentry-sdk"
-version = "1.13.0"
+version = "1.15.0"
description = "Python client for Sentry (https://sentry.io)"
category = "main"
optional = true
python-versions = "*"
files = [
- {file = "sentry-sdk-1.13.0.tar.gz", hash = "sha256:72da0766c3069a3941eadbdfa0996f83f5a33e55902a19ba399557cfee1dddcc"},
- {file = "sentry_sdk-1.13.0-py2.py3-none-any.whl", hash = "sha256:b7ff6318183e551145b5c4766eb65b59ad5b63ff234dffddc5fb50340cad6729"},
+ {file = "sentry-sdk-1.15.0.tar.gz", hash = "sha256:69ecbb2e1ff4db02a06c4f20f6f69cb5dfe3ebfbc06d023e40d77cf78e9c37e7"},
+ {file = "sentry_sdk-1.15.0-py2.py3-none-any.whl", hash = "sha256:7ad4d37dd093f4a7cb5ad804c6efe9e8fab8873f7ffc06042dc3f3fd700a93ec"},
]
[package.dependencies]
@@ -2053,7 +2054,8 @@ falcon = ["falcon (>=1.4)"]
fastapi = ["fastapi (>=0.79.0)"]
flask = ["blinker (>=1.1)", "flask (>=0.11)"]
httpx = ["httpx (>=0.16.0)"]
-opentelemetry = ["opentelemetry-distro (>=0.350b0)"]
+huey = ["huey (>=2)"]
+opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
pure-eval = ["asttokens", "executing", "pure-eval"]
pymongo = ["pymongo (>=3.1)"]
pyspark = ["pyspark (>=2.4.4)"]
@@ -2255,13 +2257,13 @@ files = [
[[package]]
name = "systemd-python"
-version = "234"
+version = "235"
description = "Python interface for libsystemd"
category = "main"
optional = true
python-versions = "*"
files = [
- {file = "systemd-python-234.tar.gz", hash = "sha256:fd0e44bf70eadae45aadc292cb0a7eb5b0b6372cd1b391228047d33895db83e7"},
+ {file = "systemd-python-235.tar.gz", hash = "sha256:4e57f39797fd5d9e2d22b8806a252d7c0106c936039d1e71c8c6b8008e695c0a"},
]
[[package]]
@@ -2546,14 +2548,14 @@ files = [
[[package]]
name = "types-bleach"
-version = "5.0.3.1"
+version = "6.0.0.0"
description = "Typing stubs for bleach"
category = "dev"
optional = false
python-versions = "*"
files = [
- {file = "types-bleach-5.0.3.1.tar.gz", hash = "sha256:ce8772ea5126dab1883851b41e3aeff229aa5213ced36096990344e632e92373"},
- {file = "types_bleach-5.0.3.1-py3-none-any.whl", hash = "sha256:af5f1b3a54ff279f54c29eccb2e6988ebb6718bc4061469588a5fd4880a79287"},
+ {file = "types-bleach-6.0.0.0.tar.gz", hash = "sha256:770ce9c7ea6173743ef1a4a70f2619bb1819bf53c7cd0336d939af93f488fbe2"},
+ {file = "types_bleach-6.0.0.0-py3-none-any.whl", hash = "sha256:75f55f035837c5fce2cd0bd5162a2a90057680a89c9275588a5c12f5f597a14a"},
]
[[package]]
@@ -2622,14 +2624,14 @@ files = [
[[package]]
name = "types-jsonschema"
-version = "4.17.0.3"
+version = "4.17.0.5"
description = "Typing stubs for jsonschema"
category = "dev"
optional = false
python-versions = "*"
files = [
- {file = "types-jsonschema-4.17.0.3.tar.gz", hash = "sha256:746aa466ffed9a1acc7bdbd0ac0b5e068f00be2ee008c1d1e14b0944a8c8b24b"},
- {file = "types_jsonschema-4.17.0.3-py3-none-any.whl", hash = "sha256:c8d5b26b7c8da6a48d7fb1ce029b97e0ff6e74db3727efb968c69f39ad013685"},
+ {file = "types-jsonschema-4.17.0.5.tar.gz", hash = "sha256:7adc7bfca4afe291de0c93eca9367aa72a4fbe8ce87fe15642c600ad97d45dd6"},
+ {file = "types_jsonschema-4.17.0.5-py3-none-any.whl", hash = "sha256:79ac8a7763fe728947af90a24168b91621edf7e8425bf3670abd4ea0d4758fba"},
]
[[package]]
@@ -2646,14 +2648,14 @@ files = [
[[package]]
name = "types-pillow"
-version = "9.4.0.5"
+version = "9.4.0.13"
description = "Typing stubs for Pillow"
category = "dev"
optional = false
python-versions = "*"
files = [
- {file = "types-Pillow-9.4.0.5.tar.gz", hash = "sha256:941cefaac2f5297d7d2a9989633c95b4063112690dc21c965d46bd5a7fff3c76"},
- {file = "types_Pillow-9.4.0.5-py3-none-any.whl", hash = "sha256:a1d2b3e070b4d852af04f76f018d12bd51abb4abca3b725d91b35e01cda7a2de"},
+ {file = "types-Pillow-9.4.0.13.tar.gz", hash = "sha256:4510aa98a28947bf63f2b29edebbd11b7cff8647d90b867cec9b3674c0a8c321"},
+ {file = "types_Pillow-9.4.0.13-py3-none-any.whl", hash = "sha256:14a8a19021b8fe569a9fef9edc64a8d8a4aef340e38669d4fb3dc05cfd941130"},
]
[[package]]
@@ -2697,14 +2699,14 @@ files = [
[[package]]
name = "types-requests"
-version = "2.28.11.8"
+version = "2.28.11.12"
description = "Typing stubs for requests"
category = "dev"
optional = false
python-versions = "*"
files = [
- {file = "types-requests-2.28.11.8.tar.gz", hash = "sha256:e67424525f84adfbeab7268a159d3c633862dafae15c5b19547ce1b55954f0a3"},
- {file = "types_requests-2.28.11.8-py3-none-any.whl", hash = "sha256:61960554baca0008ae7e2db2bd3b322ca9a144d3e80ce270f5fb640817e40994"},
+ {file = "types-requests-2.28.11.12.tar.gz", hash = "sha256:fd530aab3fc4f05ee36406af168f0836e6f00f1ee51a0b96b7311f82cb675230"},
+ {file = "types_requests-2.28.11.12-py3-none-any.whl", hash = "sha256:dbc2933635860e553ffc59f5e264264981358baffe6342b925e3eb8261f866ee"},
]
[package.dependencies]
@@ -2712,14 +2714,14 @@ types-urllib3 = "<1.27"
[[package]]
name = "types-setuptools"
-version = "67.1.0.0"
+version = "67.3.0.1"
description = "Typing stubs for setuptools"
category = "dev"
optional = false
python-versions = "*"
files = [
- {file = "types-setuptools-67.1.0.0.tar.gz", hash = "sha256:162a39d22e3a5eb802197c84f16b19e798101bbd33d9437837fbb45627da5627"},
- {file = "types_setuptools-67.1.0.0-py3-none-any.whl", hash = "sha256:5bd7a10d93e468bfcb10d24cb8ea5e12ac4f4ac91267293959001f1448cf0619"},
+ {file = "types-setuptools-67.3.0.1.tar.gz", hash = "sha256:1a26d373036c720e566823b6edd664a2db4d138b6eeba856721ec1254203474f"},
+ {file = "types_setuptools-67.3.0.1-py3-none-any.whl", hash = "sha256:a7e0f0816b5b449f5bcdc0efa43da91ff81dbe6941f293a6490d68a450e130a1"},
]
[package.dependencies]
@@ -3028,4 +3030,4 @@ user-search = ["pyicu"]
[metadata]
lock-version = "2.0"
python-versions = "^3.7.1"
-content-hash = "2673ef0530a42dae1df998bacfcaf88a563529b39461003a980743a97f02996f"
+content-hash = "e12077711e5ff83f3c6038ea44c37bd49773799ec8245035b01094b7800c5c92"
diff --git a/pyproject.toml b/pyproject.toml
index 921a1fccbc..cef7d295c1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -89,7 +89,7 @@ manifest-path = "rust/Cargo.toml"
[tool.poetry]
name = "matrix-synapse"
-version = "1.77.0rc2"
+version = "1.78.0rc1"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors "]
license = "Apache-2.0"
@@ -154,7 +154,9 @@ python = "^3.7.1"
# we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0
jsonschema = ">=3.0.0"
# frozendict 2.1.2 is broken on Debian 10: https://github.com/Marco-Sulla/python-frozendict/issues/41
-frozendict = ">=1,!=2.1.2"
+# We cannot test our wheels against the 2.3.5 release in CI. Putting in an upper bound for this
+# because frozendict has been more trouble than it's worth; we would like to move to immutabledict.
+frozendict = ">=1,!=2.1.2,<2.3.5"
# We require 2.1.0 or higher for type hints. Previous guard was >= 1.1.0
unpaddedbase64 = ">=2.1.0"
# We require 1.5.0 to work around an issue when running against the C implementation of
@@ -311,7 +313,7 @@ all = [
# We pin black so that our tests don't start failing on new releases.
isort = ">=5.10.1"
black = ">=22.3.0"
-ruff = "0.0.230"
+ruff = "0.0.237"
# Typechecking
mypy = "*"
@@ -346,6 +348,9 @@ twine = "*"
# Towncrier min version comes from #3425. Rationale unclear.
towncrier = ">=18.6.0rc1"
+# Used for checking the Poetry lockfile
+tomli = ">=1.2.3"
+
[build-system]
# The upper bounds here are defensive, intended to prevent situations like
# #13849 and #14079 where we see buildtime or runtime errors caused by build
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 09e2bba5e5..533a8cc677 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -24,7 +24,7 @@ anyhow = "1.0.63"
lazy_static = "1.4.0"
log = "0.4.17"
pyo3 = { version = "0.17.1", features = ["macros", "anyhow", "abi3", "abi3-py37"] }
-pyo3-log = "0.7.0"
+pyo3-log = "0.8.1"
pythonize = "0.17.0"
regex = "1.6.0"
serde = { version = "1.0.144", features = ["derive"] }
diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs
index 35f7a50bce..efd19a2165 100644
--- a/rust/benches/evaluator.rs
+++ b/rust/benches/evaluator.rs
@@ -15,7 +15,8 @@
#![feature(test)]
use std::collections::BTreeSet;
use synapse::push::{
- evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules,
+ evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, JsonValue,
+ PushRules, SimpleJsonValue,
};
use test::Bencher;
@@ -24,9 +25,18 @@ extern crate test;
#[bench]
fn bench_match_exact(b: &mut Bencher) {
let flattened_keys = [
- ("type".to_string(), "m.text".to_string()),
- ("room_id".to_string(), "!room:server".to_string()),
- ("content.body".to_string(), "test message".to_string()),
+ (
+ "type".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
+ ),
+ (
+ "room_id".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
+ ),
+ (
+ "content.body".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
+ ),
]
.into_iter()
.collect();
@@ -35,7 +45,6 @@ fn bench_match_exact(b: &mut Bencher) {
flattened_keys,
false,
BTreeSet::new(),
- false,
10,
Some(0),
Default::default(),
@@ -43,6 +52,8 @@ fn bench_match_exact(b: &mut Bencher) {
true,
vec![],
false,
+ false,
+ false,
)
.unwrap();
@@ -63,9 +74,18 @@ fn bench_match_exact(b: &mut Bencher) {
#[bench]
fn bench_match_word(b: &mut Bencher) {
let flattened_keys = [
- ("type".to_string(), "m.text".to_string()),
- ("room_id".to_string(), "!room:server".to_string()),
- ("content.body".to_string(), "test message".to_string()),
+ (
+ "type".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
+ ),
+ (
+ "room_id".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
+ ),
+ (
+ "content.body".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
+ ),
]
.into_iter()
.collect();
@@ -74,7 +94,6 @@ fn bench_match_word(b: &mut Bencher) {
flattened_keys,
false,
BTreeSet::new(),
- false,
10,
Some(0),
Default::default(),
@@ -82,6 +101,8 @@ fn bench_match_word(b: &mut Bencher) {
true,
vec![],
false,
+ false,
+ false,
)
.unwrap();
@@ -102,9 +123,18 @@ fn bench_match_word(b: &mut Bencher) {
#[bench]
fn bench_match_word_miss(b: &mut Bencher) {
let flattened_keys = [
- ("type".to_string(), "m.text".to_string()),
- ("room_id".to_string(), "!room:server".to_string()),
- ("content.body".to_string(), "test message".to_string()),
+ (
+ "type".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
+ ),
+ (
+ "room_id".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
+ ),
+ (
+ "content.body".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
+ ),
]
.into_iter()
.collect();
@@ -113,7 +143,6 @@ fn bench_match_word_miss(b: &mut Bencher) {
flattened_keys,
false,
BTreeSet::new(),
- false,
10,
Some(0),
Default::default(),
@@ -121,6 +150,8 @@ fn bench_match_word_miss(b: &mut Bencher) {
true,
vec![],
false,
+ false,
+ false,
)
.unwrap();
@@ -141,9 +172,18 @@ fn bench_match_word_miss(b: &mut Bencher) {
#[bench]
fn bench_eval_message(b: &mut Bencher) {
let flattened_keys = [
- ("type".to_string(), "m.text".to_string()),
- ("room_id".to_string(), "!room:server".to_string()),
- ("content.body".to_string(), "test message".to_string()),
+ (
+ "type".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
+ ),
+ (
+ "room_id".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
+ ),
+ (
+ "content.body".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
+ ),
]
.into_iter()
.collect();
@@ -152,7 +192,6 @@ fn bench_eval_message(b: &mut Bencher) {
flattened_keys,
false,
BTreeSet::new(),
- false,
10,
Some(0),
Default::default(),
@@ -160,6 +199,8 @@ fn bench_eval_message(b: &mut Bencher) {
true,
vec![],
false,
+ false,
+ false,
)
.unwrap();
diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs
index 97d0a0a7e2..4a62b9696f 100644
--- a/rust/src/push/base_rules.rs
+++ b/rust/src/push/base_rules.rs
@@ -21,13 +21,13 @@ use lazy_static::lazy_static;
use serde_json::Value;
use super::KnownCondition;
-use crate::push::Action;
use crate::push::Condition;
use crate::push::EventMatchCondition;
use crate::push::PushRule;
use crate::push::RelatedEventMatchCondition;
use crate::push::SetTweak;
use crate::push::TweakValue;
+use crate::push::{Action, ExactEventMatchCondition, SimpleJsonValue};
const HIGHLIGHT_ACTION: Action = Action::SetTweak(SetTweak {
set_tweak: Cow::Borrowed("highlight"),
@@ -168,7 +168,10 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
rule_id: Cow::Borrowed(".org.matrix.msc3952.is_room_mention"),
priority_class: 5,
conditions: Cow::Borrowed(&[
- Condition::Known(KnownCondition::IsRoomMention),
+ Condition::Known(KnownCondition::ExactEventMatch(ExactEventMatchCondition {
+ key: Cow::Borrowed("content.org.matrix.msc3952.mentions.room"),
+ value: Cow::Borrowed(&SimpleJsonValue::Bool(true)),
+ })),
Condition::Known(KnownCondition::SenderNotificationPermission {
key: Cow::Borrowed("room"),
}),
@@ -223,7 +226,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
pattern_type: None,
},
))]),
- actions: Cow::Borrowed(&[Action::DontNotify]),
+ actions: Cow::Borrowed(&[]),
default: true,
default_enabled: true,
},
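The hunk above swaps the bespoke `IsRoomMention` condition for an MSC3758 `exact_event_match` on the room-mentions property. As a hand-written illustration (not part of the diff), the rule's first condition would serialise to roughly this JSON, shown as a Python literal:

# Illustrative only; the "kind" string comes from the serde rename in mod.rs below.
room_mention_condition = {
    "kind": "com.beeper.msc3758.exact_event_match",
    "key": "content.org.matrix.msc3952.mentions.room",
    "value": True,
}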
diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs
index ec7a8c4453..55551ecb56 100644
--- a/rust/src/push/evaluator.rs
+++ b/rust/src/push/evaluator.rs
@@ -14,6 +14,7 @@
use std::collections::{BTreeMap, BTreeSet};
+use crate::push::JsonValue;
use anyhow::{Context, Error};
use lazy_static::lazy_static;
use log::warn;
@@ -22,8 +23,8 @@ use regex::Regex;
use super::{
utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType},
- Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition,
- RelatedEventMatchCondition,
+ Action, Condition, EventMatchCondition, ExactEventMatchCondition, FilteredPushRules,
+ KnownCondition, RelatedEventMatchCondition, SimpleJsonValue,
};
lazy_static! {
@@ -61,9 +62,9 @@ impl RoomVersionFeatures {
/// Allows running a set of push rules against a particular event.
#[pyclass]
pub struct PushRuleEvaluator {
- /// A mapping of "flattened" keys to string values in the event, e.g.
+ /// A mapping of "flattened" keys to simple JSON values in the event, e.g.
/// includes things like "type" and "content.msgtype".
- flattened_keys: BTreeMap<String, String>,
+ flattened_keys: BTreeMap<String, JsonValue>,
/// The "content.body", if any.
body: String,
@@ -72,8 +73,6 @@ pub struct PushRuleEvaluator {
has_mentions: bool,
/// The user mentions that were part of the message.
user_mentions: BTreeSet<String>,
- /// True if the message is a room mention.
- room_mention: bool,
/// The number of users in the room.
room_member_count: u64,
@@ -87,7 +86,7 @@ pub struct PushRuleEvaluator {
/// The related events, indexed by relation type. Flattened in the same manner as
/// `flattened_keys`.
- related_events_flattened: BTreeMap<String, BTreeMap<String, String>>,
+ related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>,
/// If msc3664, push rules for related events, is enabled.
related_event_match_enabled: bool,
@@ -98,6 +97,12 @@ pub struct PushRuleEvaluator {
/// If MSC3931 (room version feature flags) is enabled. Usually controlled by the same
/// flag as MSC1767 (extensible events core).
msc3931_enabled: bool,
+
+ /// If MSC3758 (exact_event_match push rule condition) is enabled.
+ msc3758_exact_event_match: bool,
+
+ /// If MSC3966 (exact_event_property_contains push rule condition) is enabled.
+ msc3966_exact_event_property_contains: bool,
}
#[pymethods]
@@ -106,29 +111,29 @@ impl PushRuleEvaluator {
#[allow(clippy::too_many_arguments)]
#[new]
pub fn py_new(
- flattened_keys: BTreeMap<String, String>,
+ flattened_keys: BTreeMap<String, JsonValue>,
has_mentions: bool,
user_mentions: BTreeSet<String>,
- room_mention: bool,
room_member_count: u64,
sender_power_level: Option<i64>,
notification_power_levels: BTreeMap<String, i64>,
- related_events_flattened: BTreeMap<String, BTreeMap<String, String>>,
+ related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>,
related_event_match_enabled: bool,
room_version_feature_flags: Vec<String>,
msc3931_enabled: bool,
+ msc3758_exact_event_match: bool,
+ msc3966_exact_event_property_contains: bool,
) -> Result<Self, Error> {
- let body = flattened_keys
- .get("content.body")
- .cloned()
- .unwrap_or_default();
+ let body = match flattened_keys.get("content.body") {
+ Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone(),
+ _ => String::new(),
+ };
Ok(PushRuleEvaluator {
flattened_keys,
body,
has_mentions,
user_mentions,
- room_mention,
room_member_count,
notification_power_levels,
sender_power_level,
@@ -136,6 +141,8 @@ impl PushRuleEvaluator {
related_event_match_enabled,
room_version_feature_flags,
msc3931_enabled,
+ msc3758_exact_event_match,
+ msc3966_exact_event_property_contains,
})
}
@@ -252,9 +259,15 @@ impl PushRuleEvaluator {
KnownCondition::EventMatch(event_match) => {
self.match_event_match(event_match, user_id)?
}
+ KnownCondition::ExactEventMatch(exact_event_match) => {
+ self.match_exact_event_match(exact_event_match)?
+ }
KnownCondition::RelatedEventMatch(event_match) => {
self.match_related_event_match(event_match, user_id)?
}
+ KnownCondition::ExactEventPropertyContains(exact_event_match) => {
+ self.match_exact_event_property_contains(exact_event_match)?
+ }
KnownCondition::IsUserMention => {
if let Some(uid) = user_id {
self.user_mentions.contains(uid)
@@ -262,7 +275,6 @@ impl PushRuleEvaluator {
false
}
}
- KnownCondition::IsRoomMention => self.room_mention,
KnownCondition::ContainsDisplayName => {
if let Some(dn) = display_name {
if !dn.is_empty() {
@@ -337,7 +349,9 @@ impl PushRuleEvaluator {
return Ok(false);
};
- let haystack = if let Some(haystack) = self.flattened_keys.get(&*event_match.key) {
+ let haystack = if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) =
+ self.flattened_keys.get(&*event_match.key)
+ {
haystack
} else {
return Ok(false);
@@ -355,6 +369,29 @@ impl PushRuleEvaluator {
compiled_pattern.is_match(haystack)
}
+ /// Evaluates an `exact_event_match` condition. (MSC3758)
+ fn match_exact_event_match(
+ &self,
+ exact_event_match: &ExactEventMatchCondition,
+ ) -> Result<bool, Error> {
+ // First check if the feature is enabled.
+ if !self.msc3758_exact_event_match {
+ return Ok(false);
+ }
+
+ let value = &exact_event_match.value;
+
+ let haystack = if let Some(JsonValue::Value(haystack)) =
+ self.flattened_keys.get(&*exact_event_match.key)
+ {
+ haystack
+ } else {
+ return Ok(false);
+ };
+
+ Ok(haystack == &**value)
+ }
+
/// Evaluates a `related_event_match` condition. (MSC3664)
fn match_related_event_match(
&self,
@@ -410,11 +447,12 @@ impl PushRuleEvaluator {
return Ok(false);
};
- let haystack = if let Some(haystack) = event.get(&**key) {
- haystack
- } else {
- return Ok(false);
- };
+ let haystack =
+ if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) = event.get(&**key) {
+ haystack
+ } else {
+ return Ok(false);
+ };
// For the content.body we match against "words", but for everything
// else we match against the entire value.
@@ -428,6 +466,29 @@ impl PushRuleEvaluator {
compiled_pattern.is_match(haystack)
}
+ /// Evaluates an `exact_event_property_contains` condition. (MSC3966)
+ fn match_exact_event_property_contains(
+ &self,
+ exact_event_match: &ExactEventMatchCondition,
+ ) -> Result<bool, Error> {
+ // First check if the feature is enabled.
+ if !self.msc3966_exact_event_property_contains {
+ return Ok(false);
+ }
+
+ let value = &exact_event_match.value;
+
+ let haystack = if let Some(JsonValue::Array(haystack)) =
+ self.flattened_keys.get(&*exact_event_match.key)
+ {
+ haystack
+ } else {
+ return Ok(false);
+ };
+
+ Ok(haystack.contains(&**value))
+ }
+
/// Match the member count against an 'is' condition
/// The `is` condition can be things like '>2', '==3' or even just '4'.
fn match_member_count(&self, is: &str) -> Result<bool, Error> {
@@ -455,12 +516,14 @@ impl PushRuleEvaluator {
#[test]
fn push_rule_evaluator() {
let mut flattened_keys = BTreeMap::new();
- flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
+ flattened_keys.insert(
+ "content.body".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())),
+ );
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
false,
BTreeSet::new(),
- false,
10,
Some(0),
BTreeMap::new(),
@@ -468,6 +531,8 @@ fn push_rule_evaluator() {
true,
vec![],
true,
+ true,
+ true,
)
.unwrap();
@@ -482,13 +547,15 @@ fn test_requires_room_version_supports_condition() {
use crate::push::{PushRule, PushRules};
let mut flattened_keys = BTreeMap::new();
- flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
+ flattened_keys.insert(
+ "content.body".to_string(),
+ JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())),
+ );
let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()];
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
false,
BTreeSet::new(),
- false,
10,
Some(0),
BTreeMap::new(),
@@ -496,6 +563,8 @@ fn test_requires_room_version_supports_condition() {
false,
flags,
true,
+ true,
+ true,
)
.unwrap();
diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs
index 3c4f876cab..fdd2b2c143 100644
--- a/rust/src/push/mod.rs
+++ b/rust/src/push/mod.rs
@@ -56,7 +56,9 @@ use std::collections::{BTreeMap, HashMap, HashSet};
use anyhow::{Context, Error};
use log::warn;
+use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
+use pyo3::types::{PyBool, PyList, PyLong, PyString};
use pythonize::{depythonize, pythonize};
use serde::de::Error as _;
use serde::{Deserialize, Serialize};
@@ -248,6 +250,65 @@ impl<'de> Deserialize<'de> for Action {
}
}
+/// A simple JSON value (string, int, boolean, or null).
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
+#[serde(untagged)]
+pub enum SimpleJsonValue {
+ Str(String),
+ Int(i64),
+ Bool(bool),
+ Null,
+}
+
+impl<'source> FromPyObject<'source> for SimpleJsonValue {
+ fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ if let Ok(s) = <PyString as pyo3::PyTryFrom>::try_from(ob) {
+ Ok(SimpleJsonValue::Str(s.to_string()))
+ // A bool *is* an int, ensure we try bool first.
+ } else if let Ok(b) = <PyBool as pyo3::PyTryFrom>::try_from(ob) {
+ Ok(SimpleJsonValue::Bool(b.extract()?))
+ } else if let Ok(i) = <PyLong as pyo3::PyTryFrom>::try_from(ob) {
+ Ok(SimpleJsonValue::Int(i.extract()?))
+ } else if ob.is_none() {
+ Ok(SimpleJsonValue::Null)
+ } else {
+ Err(PyTypeError::new_err(format!(
+ "Can't convert from {} to SimpleJsonValue",
+ ob.get_type().name()?
+ )))
+ }
+ }
+}
+
+/// A JSON value (list, string, int, boolean, or null).
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
+#[serde(untagged)]
+pub enum JsonValue {
+ Array(Vec<SimpleJsonValue>),
+ Value(SimpleJsonValue),
+}
+
+impl<'source> FromPyObject<'source> for JsonValue {
+ fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ if let Ok(l) = <PyList as pyo3::PyTryFrom>::try_from(ob) {
+ match l.iter().map(SimpleJsonValue::extract).collect() {
+ Ok(a) => Ok(JsonValue::Array(a)),
+ Err(e) => Err(PyTypeError::new_err(format!(
+ "Can't convert to JsonValue::Array: {}",
+ e
+ ))),
+ }
+ } else if let Ok(v) = SimpleJsonValue::extract(ob) {
+ Ok(JsonValue::Value(v))
+ } else {
+ Err(PyTypeError::new_err(format!(
+ "Can't convert from {} to JsonValue",
+ ob.get_type().name()?
+ )))
+ }
+ }
+}
+
/// A condition used in push rules to match against an event.
///
/// We need this split as `serde` doesn't give us the ability to have a
@@ -267,12 +328,14 @@ pub enum Condition {
#[serde(tag = "kind")]
pub enum KnownCondition {
EventMatch(EventMatchCondition),
+ #[serde(rename = "com.beeper.msc3758.exact_event_match")]
+ ExactEventMatch(ExactEventMatchCondition),
#[serde(rename = "im.nheko.msc3664.related_event_match")]
RelatedEventMatch(RelatedEventMatchCondition),
+ #[serde(rename = "org.matrix.msc3966.exact_event_property_contains")]
+ ExactEventPropertyContains(ExactEventMatchCondition),
#[serde(rename = "org.matrix.msc3952.is_user_mention")]
IsUserMention,
- #[serde(rename = "org.matrix.msc3952.is_room_mention")]
- IsRoomMention,
ContainsDisplayName,
RoomMemberCount {
#[serde(skip_serializing_if = "Option::is_none")]
@@ -309,6 +372,13 @@ pub struct EventMatchCondition {
pub pattern_type: Option<Cow<'static, str>>,
}
+/// The body of a [`Condition::ExactEventMatch`]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ExactEventMatchCondition {
+ pub key: Cow<'static, str>,
+ pub value: Cow<'static, SimpleJsonValue>,
+}
+
/// The body of a [`Condition::RelatedEventMatch`]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RelatedEventMatchCondition {
@@ -542,6 +612,48 @@ fn test_deserialize_unstable_msc3931_condition() {
));
}
+#[test]
+fn test_deserialize_unstable_msc3758_condition() {
+ // A string condition should work.
+ let json =
+ r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":"foo"}"#;
+
+ let condition: Condition = serde_json::from_str(json).unwrap();
+ assert!(matches!(
+ condition,
+ Condition::Known(KnownCondition::ExactEventMatch(_))
+ ));
+
+ // A boolean condition should work.
+ let json =
+ r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":true}"#;
+
+ let condition: Condition = serde_json::from_str(json).unwrap();
+ assert!(matches!(
+ condition,
+ Condition::Known(KnownCondition::ExactEventMatch(_))
+ ));
+
+ // An integer condition should work.
+ let json = r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":1}"#;
+
+ let condition: Condition = serde_json::from_str(json).unwrap();
+ assert!(matches!(
+ condition,
+ Condition::Known(KnownCondition::ExactEventMatch(_))
+ ));
+
+ // A null condition should work
+ let json =
+ r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":null}"#;
+
+ let condition: Condition = serde_json::from_str(json).unwrap();
+ assert!(matches!(
+ condition,
+ Condition::Known(KnownCondition::ExactEventMatch(_))
+ ));
+}
+
#[test]
fn test_deserialize_unstable_msc3952_user_condition() {
let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#;
@@ -553,17 +665,6 @@ fn test_deserialize_unstable_msc3952_user_condition() {
));
}
-#[test]
-fn test_deserialize_unstable_msc3952_room_condition() {
- let json = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#;
-
- let condition: Condition = serde_json::from_str(json).unwrap();
- assert!(matches!(
- condition,
- Condition::Known(KnownCondition::IsRoomMention)
- ));
-}
-
#[test]
fn test_deserialize_custom_condition() {
let json = r#"{"kind":"custom_tag"}"#;
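To make the new types concrete: the `FromPyObject` impls above are what turn Python values into `JsonValue`s at the pyo3 boundary, and bool is tried before int because a Python `bool` is a subclass of `int`. A hand-written sketch, inferred from the impls rather than stated in the diff, of which Python values land in which variant:

# Inferred mapping from Python values to the Rust enums above.
flattened_keys = {
    "content.msgtype": "m.text",                 # SimpleJsonValue::Str
    "content.m.federate": True,                  # SimpleJsonValue::Bool (tried before Int)
    "content.depth": 12,                         # SimpleJsonValue::Int
    "content.missing": None,                     # SimpleJsonValue::Null
    "content.user_ids": ["@alice:example.com"],  # JsonValue::Array of Str
}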
diff --git a/scripts-dev/check_locked_deps_have_sdists.py b/scripts-dev/check_locked_deps_have_sdists.py
new file mode 100755
index 0000000000..63ad99280a
--- /dev/null
+++ b/scripts-dev/check_locked_deps_have_sdists.py
@@ -0,0 +1,58 @@
+#! /usr/bin/env python
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+from pathlib import Path
+from typing import Dict, List
+
+import tomli
+
+
+def main() -> None:
+ lockfile_path = Path(__file__).parent.parent.joinpath("poetry.lock")
+ with open(lockfile_path, "rb") as lockfile:
+ lockfile_content = tomli.load(lockfile)
+
+ # Poetry 1.3+ lockfile format:
+ # There's a `files` inline table in each [[package]]
+ packages_to_assets: Dict[str, List[Dict[str, str]]] = {
+ package["name"]: package["files"] for package in lockfile_content["package"]
+ }
+
+ success = True
+
+ for package_name, assets in packages_to_assets.items():
+ has_sdist = any(asset["file"].endswith(".tar.gz") for asset in assets)
+ if not has_sdist:
+ success = False
+ print(
+ f"Locked package {package_name!r} does not have a source distribution!",
+ file=sys.stderr,
+ )
+
+ if not success:
+ print(
+ "\nThere were some problems with the Poetry lockfile (poetry.lock).",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+
+ print(
+ f"Poetry lockfile OK. {len(packages_to_assets)} locked packages checked.",
+ file=sys.stderr,
+ )
+
+
+if __name__ == "__main__":
+ main()
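For reference, a minimal hypothetical lockfile fragment that the script above would accept: the package's `files` list includes a `.tar.gz` sdist alongside its wheel (the hashes are placeholders):

import tomli

lockfile_fragment = """
[[package]]
name = "example"
version = "1.0.0"
files = [
    {file = "example-1.0.0-py3-none-any.whl", hash = "sha256:0000"},
    {file = "example-1.0.0.tar.gz", hash = "sha256:1111"},
]
"""
parsed = tomli.loads(lockfile_fragment)
assert any(a["file"].endswith(".tar.gz") for a in parsed["package"][0]["files"])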
diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh
index e2bc1640bb..473f54772a 100755
--- a/scripts-dev/make_full_schema.sh
+++ b/scripts-dev/make_full_schema.sh
@@ -19,7 +19,8 @@ usage() {
echo "-c"
echo " CI mode. Prints every command that the script runs."
echo "-o "
- echo " Directory to output full schema files to."
+ echo " Directory to output full schema files to. You probably want to use"
+ echo " '-o synapse/storage/schema'"
echo "-n "
echo " Schema number for the new snapshot. Used to set the location of files within "
echo " the output directory, mimicking that of synapse/storage/schemas."
@@ -27,6 +28,11 @@ usage() {
echo "-h"
echo " Display this help text."
echo ""
+ echo ""
+ echo "You probably want to invoke this with something like"
+ echo " docker run --rm -e POSTGRES_PASSWORD=postgres -e POSTGRES_USER=postgres -e POSTGRES_DB=synapse -p 5432:5432 postgres:11-alpine"
+ echo " echo postgres | scripts-dev/make_full_schema.sh -p postgres -n MY_SCHEMA_NUMBER -o synapse/storage/schema"
+ echo ""
echo " NB: make sure to run this against the *oldest* supported version of postgres,"
echo " or else pg_dump might output non-backwards-compatible syntax."
}
@@ -189,7 +195,7 @@ python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG"
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
echo "Running db background jobs..."
-synapse/_scripts/update_synapse_database.py --database-config "$SQLITE_CONFIG" --run-background-updates
+poetry run python synapse/_scripts/update_synapse_database.py --database-config "$SQLITE_CONFIG" --run-background-updates
# Create the PostgreSQL database.
echo "Creating postgres databases..."
@@ -198,7 +204,7 @@ createdb --lc-collate=C --lc-ctype=C --template=template0 "$POSTGRES_MAIN_DB_NAM
createdb --lc-collate=C --lc-ctype=C --template=template0 "$POSTGRES_STATE_DB_NAME"
echo "Running db background jobs..."
-synapse/_scripts/update_synapse_database.py --database-config "$POSTGRES_CONFIG" --run-background-updates
+poetry run python synapse/_scripts/update_synapse_database.py --database-config "$POSTGRES_CONFIG" --run-background-updates
echo "Dropping unwanted db tables..."
@@ -293,4 +299,12 @@ pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owne
pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_STATE_DB_NAME" | cleanup_pg_schema > "$OUTPUT_DIR/state/full_schemas/$SCHEMA_NUMBER/full.sql.postgres"
pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_STATE_DB_NAME" | cleanup_pg_schema >> "$OUTPUT_DIR/state/full_schemas/$SCHEMA_NUMBER/full.sql.postgres"
+if [[ "$OUTPUT_DIR" == *synapse/storage/schema ]]; then
+ echo "Updating contrib/datagrip symlinks..."
+ ln -sf "../../synapse/storage/schema/common/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" "contrib/datagrip/common.sql"
+ ln -sf "../../synapse/storage/schema/main/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" "contrib/datagrip/main.sql"
+ ln -sf "../../synapse/storage/schema/state/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" "contrib/datagrip/state.sql"
+else
+ echo "Not updating contrib/datagrip symlinks (unknown output directory)"
+fi
echo "Done! Files dumped to: $OUTPUT_DIR"
diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi
index 754acab2f9..a8f0ed2435 100644
--- a/stubs/synapse/synapse_rust/push.pyi
+++ b/stubs/synapse/synapse_rust/push.pyi
@@ -14,7 +14,7 @@
from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonValue
class PushRule:
@property
@@ -56,17 +56,18 @@ def get_base_rule_ids() -> Collection[str]: ...
class PushRuleEvaluator:
def __init__(
self,
- flattened_keys: Mapping[str, str],
+ flattened_keys: Mapping[str, JsonValue],
has_mentions: bool,
user_mentions: Set[str],
- room_mention: bool,
room_member_count: int,
sender_power_level: Optional[int],
notification_power_levels: Mapping[str, int],
- related_events_flattened: Mapping[str, Mapping[str, str]],
+ related_events_flattened: Mapping[str, Mapping[str, JsonValue]],
related_event_match_enabled: bool,
room_version_feature_flags: Tuple[str, ...],
msc3931_enabled: bool,
+ msc3758_exact_event_match: bool,
+ msc3966_exact_event_property_contains: bool,
): ...
def run(
self,
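A sketch of constructing the evaluator against the updated stub, including the two new trailing flags; every value below is illustrative and the keyword names simply mirror the signature above:

evaluator = PushRuleEvaluator(
    flattened_keys={"type": "m.room.message", "content.body": "hello"},
    has_mentions=False,
    user_mentions=set(),
    room_member_count=10,
    sender_power_level=0,
    notification_power_levels={"room": 50},
    related_events_flattened={},
    related_event_match_enabled=True,
    room_version_feature_flags=(),
    msc3931_enabled=True,
    msc3758_exact_event_match=True,
    msc3966_exact_event_property_contains=True,
)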
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 5e137dbbf7..0d35e0af8f 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -94,61 +94,80 @@ reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("synapse_port_db")
+# SQLite doesn't have a dedicated boolean type (it stores True/False as 1/0). This means
+# portdb will read sqlite bools as integers, then try to insert them into postgres
+# boolean columns---which fails. Lacking some Python-parseable metaschema, we must
+# specify which integer columns should be inserted as booleans into postgres.
BOOLEAN_COLUMNS = {
- "events": ["processed", "outlier", "contains_url"],
- "rooms": ["is_public", "has_auth_chain_index"],
+ "access_tokens": ["used"],
+ "account_validity": ["email_sent"],
+ "device_lists_changes_in_room": ["converted_to_destinations"],
+ "device_lists_outbound_pokes": ["sent"],
+ "devices": ["hidden"],
+ "e2e_fallback_keys_json": ["used"],
+ "e2e_room_keys": ["is_verified"],
"event_edges": ["is_state"],
+ "events": ["processed", "outlier", "contains_url"],
+ "local_media_repository": ["safe_from_quarantine"],
"presence_list": ["accepted"],
"presence_stream": ["currently_active"],
"public_room_list_stream": ["visibility"],
- "devices": ["hidden"],
- "device_lists_outbound_pokes": ["sent"],
- "users_who_share_rooms": ["share_private"],
- "e2e_room_keys": ["is_verified"],
- "account_validity": ["email_sent"],
+ "pushers": ["enabled"],
"redactions": ["have_censored"],
"room_stats_state": ["is_federatable"],
- "local_media_repository": ["safe_from_quarantine"],
+ "rooms": ["is_public", "has_auth_chain_index"],
"users": ["shadow_banned", "approved"],
- "e2e_fallback_keys_json": ["used"],
- "access_tokens": ["used"],
- "device_lists_changes_in_room": ["converted_to_destinations"],
- "pushers": ["enabled"],
+ "un_partial_stated_event_stream": ["rejection_status_changed"],
+ "users_who_share_rooms": ["share_private"],
}
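A rough sketch, simplified relative to what portdb actually does, of how the mapping above can be applied while copying rows; `coerce_booleans` is a hypothetical helper, not a function from this file:

def coerce_booleans(table, headers, row):
    # Convert sqlite's 0/1 integers to bools for the listed columns, leaving
    # NULLs and all other columns untouched.
    bool_cols = set(BOOLEAN_COLUMNS.get(table, ()))
    return tuple(
        bool(value) if header in bool_cols and value is not None else value
        for header, value in zip(headers, row)
    )

coerce_booleans("rooms", ["room_id", "is_public"], ("!r:example.com", 1))
# -> ('!r:example.com', True)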
+# These tables are never deleted from in normal operation [*], so we can resume porting
+# over rows from a previous attempt rather than starting from scratch.
+#
+# [*]: We do delete from many of these tables when purging a room, and
+# presumably when purging old events. So we might e.g.
+#
+# 1. Run portdb and port half of some table.
+# 2. Stop portdb.
+# 3. Purge something, deleting some of the rows we've ported over.
+# 4. Restart portdb. The rows deleted from sqlite are still present in postgres.
+#
+# But this isn't the end of the world: we should be able to repeat the purge
+# on the postgres DB when porting completes.
APPEND_ONLY_TABLES = [
- "event_reference_hashes",
- "events",
+ "cache_invalidation_stream_by_instance",
+ "event_auth",
+ "event_edges",
"event_json",
- "state_events",
- "room_memberships",
- "topics",
- "room_names",
- "rooms",
+ "event_reference_hashes",
+ "event_search",
+ "event_to_state_groups",
+ "events",
+ "ex_outlier_stream",
"local_media_repository",
"local_media_repository_thumbnails",
+ "presence_stream",
+ "public_room_list_stream",
+ "push_rules_stream",
+ "received_transactions",
+ "redactions",
+ "rejections",
"remote_media_cache",
"remote_media_cache_thumbnails",
- "redactions",
- "event_edges",
- "event_auth",
- "received_transactions",
+ "room_memberships",
+ "room_names",
+ "rooms",
"sent_transactions",
- "transaction_id_to_pdu",
- "users",
+ "state_events",
+ "state_group_edges",
"state_groups",
"state_groups_state",
- "event_to_state_groups",
- "rejections",
- "event_search",
- "presence_stream",
- "push_rules_stream",
- "ex_outlier_stream",
- "cache_invalidation_stream_by_instance",
- "public_room_list_stream",
- "state_group_edges",
"stream_ordering_to_exterm",
+ "topics",
+ "transaction_id_to_pdu",
+ "un_partial_stated_event_stream",
+ "users",
]
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 3d7f986ac7..66e869bc2d 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -32,7 +32,6 @@ from synapse.appservice import ApplicationService
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import (
- SynapseTags,
active_span,
force_tracing,
start_active_span,
@@ -162,12 +161,6 @@ class Auth:
parent_span.set_tag(
"authenticated_entity", requester.authenticated_entity
)
- # We tag the Synapse instance name so that it's an easy jumping
- # off point into the logs. Can also be used to filter for an
- # instance that is under load.
- parent_span.set_tag(
- SynapseTags.INSTANCE_NAME, self.hs.get_instance_name()
- )
parent_span.set_tag("user_id", requester.user.to_string())
if requester.device_id is not None:
parent_span.set_tag("device_id", requester.device_id)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index c2c177fd71..e1737de59b 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -108,6 +108,10 @@ class Codes(str, Enum):
USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL"
+ # Attempt to send a second annotation with the same event type & annotation key
+ # MSC2677
+ DUPLICATE_ANNOTATION = "M_DUPLICATE_ANNOTATION"
+
class CodeMessageException(RuntimeError):
"""An exception with integer code and message string attributes.
@@ -751,3 +755,25 @@ class ModuleFailedException(Exception):
Raised when a module API callback fails, for example because it raised an
exception.
"""
+
+
+class PartialStateConflictError(SynapseError):
+ """An internal error raised when attempting to persist an event with partial state
+ after the room containing the event has been un-partial stated.
+
+ This error should be handled by recomputing the event context and trying again.
+
+ This error has an HTTP status code so that it can be transported over replication.
+ It should not be exposed to clients.
+ """
+
+ @staticmethod
+ def message() -> str:
+ return "Cannot persist partial state event in un-partial stated room"
+
+ def __init__(self) -> None:
+ super().__init__(
+ HTTPStatus.CONFLICT,
+ msg=PartialStateConflictError.message(),
+ errcode=Codes.UNKNOWN,
+ )
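The docstring above implies a recompute-and-retry loop at the call sites. A hedged sketch of that pattern; `compute_event_context` and `persist_event` stand in for the real persistence machinery and are not defined in this diff:

async def persist_with_retry(event, state_handler, persister):
    for _ in range(2):
        context = await state_handler.compute_event_context(event)
        try:
            return await persister.persist_event(event, context)
        except PartialStateConflictError:
            # The room was un-partial-stated under us; recompute the context
            # against the now-complete state and try again.
            continue
    raise RuntimeError("event still conflicts after recomputing its context")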
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 83c42fc25a..b9f432cc23 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -219,9 +219,13 @@ class FilterCollection:
self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(hs, room_filter_json.get("state", {}))
self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {}))
- self._room_account_data = Filter(hs, room_filter_json.get("account_data", {}))
+ self._room_account_data_filter = Filter(
+ hs, room_filter_json.get("account_data", {})
+ )
self._presence_filter = Filter(hs, filter_json.get("presence", {}))
- self._account_data = Filter(hs, filter_json.get("account_data", {}))
+ self._global_account_data_filter = Filter(
+ hs, filter_json.get("account_data", {})
+ )
self.include_leave = filter_json.get("room", {}).get("include_leave", False)
self.event_fields = filter_json.get("event_fields", [])
@@ -256,8 +260,10 @@ class FilterCollection:
) -> List[UserPresenceState]:
return await self._presence_filter.filter(presence_states)
- async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
- return await self._account_data.filter(events)
+ async def filter_global_account_data(
+ self, events: Iterable[JsonDict]
+ ) -> List[JsonDict]:
+ return await self._global_account_data_filter.filter(events)
async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
return await self._room_state_filter.filter(
@@ -279,7 +285,7 @@ class FilterCollection:
async def filter_room_account_data(
self, events: Iterable[JsonDict]
) -> List[JsonDict]:
- return await self._room_account_data.filter(
+ return await self._room_account_data_filter.filter(
await self._room_filter.filter(events)
)
@@ -292,6 +298,13 @@ class FilterCollection:
or self._presence_filter.filters_all_senders()
)
+ def blocks_all_global_account_data(self) -> bool:
+ """True if all global acount data will be filtered out."""
+ return (
+ self._global_account_data_filter.filters_all_types()
+ or self._global_account_data_filter.filters_all_senders()
+ )
+
def blocks_all_room_ephemeral(self) -> bool:
return (
self._room_ephemeral_filter.filters_all_types()
@@ -299,6 +312,13 @@ class FilterCollection:
or self._room_ephemeral_filter.filters_all_rooms()
)
+ def blocks_all_room_account_data(self) -> bool:
+ return (
+ self._room_account_data_filter.filters_all_types()
+ or self._room_account_data_filter.filters_all_senders()
+ or self._room_account_data_filter.filters_all_rooms()
+ )
+
def blocks_all_room_timeline(self) -> bool:
return (
self._room_timeline_filter.filters_all_types()
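To illustrate the new predicates: a client filter whose `account_data` sections allow no event types should make both `blocks_all_global_account_data()` and `blocks_all_room_account_data()` return True, letting sync skip those queries entirely. A hand-written example filter, not taken from the diff:

filter_json = {
    # An empty "types" list filters out every event type.
    "account_data": {"types": []},
    "room": {
        "account_data": {"types": []},
    },
}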
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index fe7afb9475..ad51f33165 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -17,7 +17,7 @@ import logging
import os
import sys
import tempfile
-from typing import List, Optional
+from typing import List, Mapping, Optional
from twisted.internet import defer, task
@@ -222,6 +222,19 @@ class FileExfiltrationWriter(ExfiltrationWriter):
with open(connection_file, "a") as f:
print(json.dumps(connection), file=f)
+ def write_account_data(
+ self, file_name: str, account_data: Mapping[str, JsonDict]
+ ) -> None:
+ account_data_directory = os.path.join(
+ self.base_directory, "user_data", "account_data"
+ )
+ os.makedirs(account_data_directory, exist_ok=True)
+
+ account_data_file = os.path.join(account_data_directory, file_name)
+
+ with open(account_data_file, "a") as f:
+ print(json.dumps(account_data), file=f)
+
def finished(self) -> str:
return self.base_directory
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 53db1e85b3..897dd3edac 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -15,7 +15,7 @@ import logging
import math
import resource
import sys
-from typing import TYPE_CHECKING, List, Sized, Tuple
+from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple
from prometheus_client import Gauge
@@ -194,7 +194,7 @@ def start_phone_stats_home(hs: "HomeServer") -> None:
@wrap_as_background_process("generate_monthly_active_users")
async def generate_monthly_active_users() -> None:
current_mau_count = 0
- current_mau_count_by_service = {}
+ current_mau_count_by_service: Mapping[str, int] = {}
reserved_users: Sized = ()
store = hs.get_datastores().main
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 53c0682dfd..54c91953e1 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -169,12 +169,28 @@ class ExperimentalConfig(Config):
# MSC3925: do not replace events with their edits
self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False)
- # MSC3952: Intentional mentions
- self.msc3952_intentional_mentions = experimental.get(
- "msc3952_intentional_mentions", False
+ # MSC3758: exact_event_match push rule condition
+ self.msc3758_exact_event_match = experimental.get(
+ "msc3758_exact_event_match", False
+ )
+
+ # MSC3873: Disambiguate event_match keys.
+ self.msc3873_escape_event_match_key = experimental.get(
+ "msc3873_escape_event_match_key", False
+ )
+
+ # MSC3952: Intentional mentions; this depends on MSC3758.
+ self.msc3952_intentional_mentions = (
+ experimental.get("msc3952_intentional_mentions", False)
+ and self.msc3758_exact_event_match
)
# MSC3958: Do not generate notifications for edits.
self.msc3958_supress_edit_notifs = experimental.get(
"msc3958_supress_edit_notifs", False
)
+
+ # MSC3966: exact_event_property_contains push rule condition.
+ self.msc3966_exact_event_property_contains = experimental.get(
+ "msc3966_exact_event_property_contains", False
+ )
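Because of the `and` above, MSC3952 is only treated as enabled when MSC3758 is too. A hypothetical `experimental_features` section of homeserver.yaml, written as the dict this class receives:

experimental = {
    "msc3758_exact_event_match": True,
    # Without the line above, the next setting would now be ignored.
    "msc3952_intentional_mentions": True,
    "msc3966_exact_event_property_contains": True,
}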
diff --git a/synapse/config/redis.py b/synapse/config/redis.py
index b42dd2e93a..e6a75be434 100644
--- a/synapse/config/redis.py
+++ b/synapse/config/redis.py
@@ -33,4 +33,5 @@ class RedisConfig(Config):
self.redis_host = redis_config.get("host", "localhost")
self.redis_port = redis_config.get("port", 6379)
+ self.redis_dbid = redis_config.get("dbid", None)
self.redis_password = redis_config.get("password")
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 3ed236217f..8666c22f01 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List
+from typing import Any, Collection
from matrix_common.regex import glob_to_regex
@@ -70,7 +70,7 @@ class RoomDirectoryConfig(Config):
return False
def is_publishing_room_allowed(
- self, user_id: str, room_id: str, aliases: List[str]
+ self, user_id: str, room_id: str, aliases: Collection[str]
) -> bool:
"""Checks if the given user is allowed to publish the room
@@ -122,7 +122,7 @@ class _RoomDirectoryRule:
except Exception as e:
raise ConfigError("Failed to parse glob into regex") from e
- def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool:
+ def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool:
"""Tests if this rule matches the given user_id, room_id and aliases.
Args:
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ecdaa2d9dd..d4ef9930b0 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -177,6 +177,7 @@ KNOWN_RESOURCES = {
"client",
"consent",
"federation",
+ "health",
"keys",
"media",
"metrics",
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index e0be9f88cc..4d6d1b8ebd 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -16,18 +16,7 @@
import collections.abc
import logging
import typing
-from typing import (
- Any,
- Collection,
- Dict,
- Iterable,
- List,
- Mapping,
- Optional,
- Set,
- Tuple,
- Union,
-)
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -56,7 +45,13 @@ from synapse.api.room_versions import (
RoomVersions,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import MutableStateMap, StateMap, UserID, get_domain_from_id
+from synapse.types import (
+ MutableStateMap,
+ StateMap,
+ StrCollection,
+ UserID,
+ get_domain_from_id,
+)
if typing.TYPE_CHECKING:
# conditional imports to avoid import cycle
@@ -69,7 +64,7 @@ logger = logging.getLogger(__name__)
class _EventSourceStore(Protocol):
async def get_events(
self,
- event_ids: Collection[str],
+ event_ids: StrCollection,
redact_behaviour: EventRedactBehaviour,
get_prev_content: bool = False,
allow_rejected: bool = False,
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 8aca9a3ab9..91118a8d84 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -39,7 +39,7 @@ from unpaddedbase64 import encode_base64
from synapse.api.constants import RelationTypes
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
-from synapse.types import JsonDict, RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken, StrCollection
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
from synapse.util.stringutils import strtobool
@@ -413,7 +413,7 @@ class EventBase(metaclass=abc.ABCMeta):
"""
return [e for e, _ in self._dict["prev_events"]]
- def auth_event_ids(self) -> Sequence[str]:
+ def auth_event_ids(self) -> StrCollection:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
@@ -558,7 +558,7 @@ class FrozenEventV2(EventBase):
"""
return self._dict["prev_events"]
- def auth_event_ids(self) -> Sequence[str]:
+ def auth_event_ids(self) -> StrCollection:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 94dd1298e1..c82745275f 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
import attr
from signedjson.types import SigningKey
@@ -103,7 +103,7 @@ class EventBuilder:
async def build(
self,
- prev_event_ids: List[str],
+ prev_event_ids: Collection[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
@@ -136,7 +136,7 @@ class EventBuilder:
format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
- prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
+ prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 6eaef8b57a..e0d82ad81c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple
import attr
@@ -26,8 +27,51 @@ if TYPE_CHECKING:
from synapse.types.state import StateFilter
+class UnpersistedEventContextBase(ABC):
+ """
+ This is a base class for EventContext and UnpersistedEventContext, objects which
+ hold information relevant to storing an associated event. Note that an
+ UnpersistedEventContext must be converted into an EventContext before it is
+ suitable to send to the db with its associated event.
+
+ Attributes:
+ _storage: storage controllers for interfacing with the database
+ app_service: If the associated event is being sent by a (local) application service, that
+ app service.
+ """
+
+ def __init__(self, storage_controller: "StorageControllers"):
+ self._storage: "StorageControllers" = storage_controller
+ self.app_service: Optional[ApplicationService] = None
+
+ @abstractmethod
+ async def persist(
+ self,
+ event: EventBase,
+ ) -> "EventContext":
+ """
+ A method to convert an UnpersistedEventContext to an EventContext, suitable for
+ sending to the database with the associated event.
+ """
+ pass
+
+ @abstractmethod
+ async def get_prev_state_ids(
+ self, state_filter: Optional["StateFilter"] = None
+ ) -> StateMap[str]:
+ """
+ Gets the room state at the event (ie not including the event if the event is a
+ state event).
+
+ Args:
+ state_filter: specifies the type of state event to fetch from DB, example:
+ EventTypes.JoinRules
+ """
+ pass
+
+
@attr.s(slots=True, auto_attribs=True)
-class EventContext:
+class EventContext(UnpersistedEventContextBase):
"""
Holds information relevant to persisting an event
@@ -77,9 +121,6 @@ class EventContext:
delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
and ``state_group``.
- app_service: If this event is being sent by a (local) application service, that
- app service.
-
partial_state: if True, we may be storing this event with a temporary,
incomplete state.
"""
@@ -122,6 +163,9 @@ class EventContext:
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(storage=storage)
+ async def persist(self, event: EventBase) -> "EventContext":
+ return self
+
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -254,6 +298,128 @@ class EventContext:
)
+@attr.s(slots=True, auto_attribs=True)
+class UnpersistedEventContext(UnpersistedEventContextBase):
+ """
+ The event context holds information about the state groups for an event. It is important
+ to remember that an event technically has two state groups: the state group before the
+ event, and the state group after the event. If the event is not a state event, the two
+ are the same; if it is a state event, they differ.
+ This is a version of an EventContext before the new state group (if any) has been
+ computed and stored. It contains information about the state before the event (which
+ also may be the information after the event, if the event is not a state event). The
+ UnpersistedEventContext must be converted into an EventContext by calling the method
+ 'persist' on it before it is suitable to be sent to the DB for processing.
+
+ state_group_after_event:
+ The state group after the event. This will always be None until it is persisted.
+ If the event is not a state event, this will be the same as
+ state_group_before_event.
+
+ state_group_before_event:
+ The ID of the state group representing the state of the room before this event.
+
+ state_delta_due_to_event:
+ If the event is a state event, then this is the delta of the state between
+ `state_group` and `state_group_before_event`
+
+ prev_group_for_state_group_before_event:
+ If it is known, ``state_group_before_event``'s previous state group.
+
+ delta_ids_to_state_group_before_event:
+ If ``prev_group_for_state_group_before_event`` is not None, the state delta
+ between ``prev_group_for_state_group_before_event`` and ``state_group_before_event``.
+
+ partial_state:
+ Whether the event has partial state.
+
+ state_map_before_event:
+ A map of the state before the event, i.e. the state at `state_group_before_event`
+ """
+
+ _storage: "StorageControllers"
+ state_group_before_event: Optional[int]
+ state_group_after_event: Optional[int]
+ state_delta_due_to_event: Optional[dict]
+ prev_group_for_state_group_before_event: Optional[int]
+ delta_ids_to_state_group_before_event: Optional[StateMap[str]]
+ partial_state: bool
+ state_map_before_event: Optional[StateMap[str]] = None
+
+ async def get_prev_state_ids(
+ self, state_filter: Optional["StateFilter"] = None
+ ) -> StateMap[str]:
+ """
+ Gets the room state map, excluding this event.
+
+ Args:
+ state_filter: specifies the type of state event to fetch from DB
+
+ Returns:
+ Maps a (type, state_key) to the event ID of the state event matching
+ this tuple.
+ """
+ if self.state_map_before_event:
+ return self.state_map_before_event
+
+ assert self.state_group_before_event is not None
+ return await self._storage.state.get_state_ids_for_group(
+ self.state_group_before_event, state_filter
+ )
+
+ async def persist(self, event: EventBase) -> EventContext:
+ """
+ Creates a full `EventContext` for the event, persisting any referenced state that
+ has not yet been persisted.
+
+ Args:
+ event: event that the EventContext is associated with.
+
+ Returns: An EventContext suitable for sending to the database with the event
+ for persisting.
+ """
+ assert self.partial_state is not None
+
+ # If we have a full set of state for before the event but don't have a state
+ # group for that state, we need to get one
+ if self.state_group_before_event is None:
+ assert self.state_map_before_event
+ state_group_before_event = await self._storage.state.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=self.prev_group_for_state_group_before_event,
+ delta_ids=self.delta_ids_to_state_group_before_event,
+ current_state_ids=self.state_map_before_event,
+ )
+ self.state_group_before_event = state_group_before_event
+
+ # if the event isn't a state event the state group doesn't change
+ if not self.state_delta_due_to_event:
+ state_group_after_event = self.state_group_before_event
+
+ # otherwise if it is a state event we need to get a state group for it
+ else:
+ state_group_after_event = await self._storage.state.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=self.state_group_before_event,
+ delta_ids=self.state_delta_due_to_event,
+ current_state_ids=None,
+ )
+
+ return EventContext.with_state(
+ storage=self._storage,
+ state_group=state_group_after_event,
+ state_group_before_event=self.state_group_before_event,
+ state_delta_due_to_event=self.state_delta_due_to_event,
+ partial_state=self.partial_state,
+ prev_group=self.state_group_before_event,
+ delta_ids=self.state_delta_due_to_event,
+ )
+
+
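The flow this class enables is two-phase: create the event and its unpersisted context, run any checks that need the prior state, and only then allocate state groups. A minimal usage sketch in the spirit of the call sites below (assumes `event_creation_handler` and `builder` from the surrounding handler code; not a verbatim excerpt):

    # Create the event; the returned context is not yet backed by stored
    # state groups.
    event, unpersisted_context = await event_creation_handler.create_new_client_event(
        builder=builder
    )

    # State lookups still work before anything has been written to the DB.
    prev_state_ids = await unpersisted_context.get_prev_state_ids()

    # Once the event has passed validation, store the state groups and obtain
    # a full EventContext suitable for persisting the event.
    context = await unpersisted_context.persist(event)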
def _encode_state_dict(
state_dict: Optional[StateMap[str]],
) -> Optional[List[Tuple[str, str, str]]]:
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 72ab696898..97c61cc258 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -18,7 +18,7 @@ from twisted.internet.defer import CancelledError
from synapse.api.errors import ModuleFailedException, SynapseError
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import UnpersistedEventContextBase
from synapse.storage.roommember import ProfileInfo
from synapse.types import Requester, StateMap
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
@@ -231,7 +231,9 @@ class ThirdPartyEventRules:
self._on_threepid_bind_callbacks.append(on_threepid_bind)
async def check_event_allowed(
- self, event: EventBase, context: EventContext
+ self,
+ event: EventBase,
+ context: UnpersistedEventContextBase,
) -> Tuple[bool, Optional[dict]]:
"""Check if a provided event should be allowed in the given context.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 0ac85a3be7..7d04560dca 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -884,7 +884,7 @@ class FederationClient(FederationBase):
if 500 <= e.code < 600:
failover = True
- elif e.code == 400 and synapse_error.errcode in failover_errcodes:
+ elif 400 <= e.code < 500 and synapse_error.errcode in failover_errcodes:
failover = True
elif failover_on_unknown_endpoint and self._is_unknown_endpoint(
@@ -999,14 +999,13 @@ class FederationClient(FederationBase):
return destination, ev, room_version
+ failover_errcodes = {Codes.NOT_FOUND}
# MSC3083 defines additional error codes for room joins. Unfortunately
# we do not yet know the room version, assume these will only be returned
# by valid room versions.
- failover_errcodes = (
- (Codes.UNABLE_AUTHORISE_JOIN, Codes.UNABLE_TO_GRANT_JOIN)
- if membership == Membership.JOIN
- else None
- )
+ if membership == Membership.JOIN:
+ failover_errcodes.add(Codes.UNABLE_AUTHORISE_JOIN)
+ failover_errcodes.add(Codes.UNABLE_TO_GRANT_JOIN)
return await self._try_destination_list(
"make_" + membership,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 8d36172484..6d99845de5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -23,6 +23,7 @@ from typing import (
Collection,
Dict,
List,
+ Mapping,
Optional,
Tuple,
Union,
@@ -47,6 +48,7 @@ from synapse.api.errors import (
FederationError,
IncompatibleRoomVersionError,
NotFoundError,
+ PartialStateConflictError,
SynapseError,
UnsupportedRoomVersionError,
)
@@ -80,7 +82,6 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
ReplicationGetQueryRestServlet,
)
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.lock import Lock
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.roommember import MemberSummary
@@ -1512,7 +1513,7 @@ class FederationHandlerRegistry:
def _get_event_ids_for_partial_state_join(
join_event: EventBase,
prev_state_ids: StateMap[str],
- summary: Dict[str, MemberSummary],
+ summary: Mapping[str, MemberSummary],
) -> Collection[str]:
"""Calculate state to be returned in a partial_state send_join
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 67e789eef7..797de46dbc 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -343,10 +343,12 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
}
)
- (
- account_data,
- room_account_data,
- ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+ account_data = await self.store.get_updated_global_account_data_for_user(
+ user_id, last_stream_id
+ )
+ room_account_data = await self.store.get_updated_room_account_data_for_user(
+ user_id, last_stream_id
+ )
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index b03c214b14..8b7760b2cc 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -14,7 +14,7 @@
import abc
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class AdminHandler:
def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastores().main
+ self._store = hs.get_datastores().main
self._device_handler = hs.get_device_handler()
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
@@ -38,7 +38,7 @@ class AdminHandler:
async def get_whois(self, user: UserID) -> JsonDict:
connections = []
- sessions = await self.store.get_user_ip_and_agents(user)
+ sessions = await self._store.get_user_ip_and_agents(user)
for session in sessions:
connections.append(
{
@@ -57,7 +57,7 @@ class AdminHandler:
async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
- user_info_dict = await self.store.get_user_by_id(user.to_string())
+ user_info_dict = await self._store.get_user_by_id(user.to_string())
if user_info_dict is None:
return None
@@ -89,11 +89,11 @@ class AdminHandler:
}
# Add additional user metadata
- profile = await self.store.get_profileinfo(user.localpart)
- threepids = await self.store.user_get_threepids(user.to_string())
+ profile = await self._store.get_profileinfo(user.localpart)
+ threepids = await self._store.user_get_threepids(user.to_string())
external_ids = [
({"auth_provider": auth_provider, "external_id": external_id})
- for auth_provider, external_id in await self.store.get_external_ids_by_user(
+ for auth_provider, external_id in await self._store.get_external_ids_by_user(
user.to_string()
)
]
@@ -101,7 +101,7 @@ class AdminHandler:
user_info_dict["avatar_url"] = profile.avatar_url
user_info_dict["threepids"] = threepids
user_info_dict["external_ids"] = external_ids
- user_info_dict["erased"] = await self.store.is_user_erased(user.to_string())
+ user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
return user_info_dict
@@ -117,7 +117,7 @@ class AdminHandler:
The returned value is that returned by `writer.finished()`.
"""
# Get all rooms the user is in or has been in
- rooms = await self.store.get_rooms_for_local_user_where_membership_is(
+ rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id,
membership_list=(
Membership.JOIN,
@@ -131,7 +131,7 @@ class AdminHandler:
# We only try and fetch events for rooms the user has been in. If
# they've been e.g. invited to a room without joining then we handle
# those separately.
- rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id)
+ rooms_user_has_been_in = await self._store.get_rooms_user_has_been_in(user_id)
for index, room in enumerate(rooms):
room_id = room.room_id
@@ -140,7 +140,7 @@ class AdminHandler:
"[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
)
- forgotten = await self.store.did_forget(user_id, room_id)
+ forgotten = await self._store.did_forget(user_id, room_id)
if forgotten:
logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
continue
@@ -152,14 +152,14 @@ class AdminHandler:
if room.membership == Membership.INVITE:
event_id = room.event_id
- invite = await self.store.get_event(event_id, allow_none=True)
+ invite = await self._store.get_event(event_id, allow_none=True)
if invite:
invited_state = invite.unsigned["invite_room_state"]
writer.write_invite(room_id, invite, invited_state)
if room.membership == Membership.KNOCK:
event_id = room.event_id
- knock = await self.store.get_event(event_id, allow_none=True)
+ knock = await self._store.get_event(event_id, allow_none=True)
if knock:
knock_state = knock.unsigned["knock_room_state"]
writer.write_knock(room_id, knock, knock_state)
@@ -170,7 +170,7 @@ class AdminHandler:
# were joined. We estimate that point by looking at the
# stream_ordering of the last membership if it wasn't a join.
if room.membership == Membership.JOIN:
- stream_ordering = self.store.get_room_max_stream_ordering()
+ stream_ordering = self._store.get_room_max_stream_ordering()
else:
stream_ordering = room.stream_ordering
@@ -197,7 +197,7 @@ class AdminHandler:
# events that we have and then filtering, this isn't the most
# efficient method perhaps but it does guarantee we get everything.
while True:
- events, _ = await self.store.paginate_room_events(
+ events, _ = await self._store.paginate_room_events(
room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
)
if not events:
@@ -263,6 +263,13 @@ class AdminHandler:
connections["devices"][""]["sessions"][0]["connections"]
)
+ # Get all global account data and all per-room account data for the user
+ global_data = await self._store.get_global_account_data_for_user(user_id)
+ by_room_data = await self._store.get_room_account_data_for_user(user_id)
+ writer.write_account_data("global", global_data)
+ for room_id in by_room_data:
+ writer.write_account_data(room_id, by_room_data[room_id])
+
return writer.finished()
@@ -340,6 +347,18 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
+ @abc.abstractmethod
+ def write_account_data(
+ self, file_name: str, account_data: Mapping[str, JsonDict]
+ ) -> None:
+ """Write the account data of a user.
+
+ Args:
+ file_name: file name to write the data to
+ account_data: mapping of global or room account data
+ """
+ raise NotImplementedError()
+
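Concrete writers must implement the new method. A minimal sketch of a file-based implementation, loosely modelled on Synapse's admin export writer (directory layout and naming here are illustrative, not taken from this patch):

    import json
    import os
    from typing import Mapping

    class JsonFileExfiltrationWriter(ExfiltrationWriter):
        """Hypothetical writer that dumps each account data scope to a JSON file.

        The other abstract methods of ExfiltrationWriter are omitted for brevity.
        """

        def __init__(self, base_directory: str) -> None:
            self.base_directory = base_directory

        def write_account_data(
            self, file_name: str, account_data: Mapping[str, JsonDict]
        ) -> None:
            # One file per scope: "global", or a room ID for per-room data.
            directory = os.path.join(self.base_directory, "account_data")
            os.makedirs(directory, exist_ok=True)
            with open(os.path.join(directory, file_name), "w") as f:
                json.dump(dict(account_data), f)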
@abc.abstractmethod
def finished(self) -> Any:
"""Called when all data has successfully been exported and written.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 30f2d46c3c..cf12b55d21 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -201,7 +201,7 @@ class AuthHandler:
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
- self.checkers[inst.AUTH_TYPE] = inst # type: ignore
+ self.checkers[inst.AUTH_TYPE] = inst
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
@@ -1593,9 +1593,8 @@ class AuthHandler:
if medium == "email":
address = canonicalise_email(address)
- identity_handler = self.hs.get_identity_handler()
- result = await identity_handler.try_unbind_threepid(
- user_id, {"medium": medium, "address": address, "id_server": id_server}
+ result = await self.hs.get_identity_handler().try_unbind_threepid(
+ user_id, medium, address, id_server
)
await self.store.user_delete_threepid(user_id, medium, address)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index d74d135c0c..d24f649382 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -106,12 +106,7 @@ class DeactivateAccountHandler:
for threepid in threepids:
try:
result = await self._identity_handler.try_unbind_threepid(
- user_id,
- {
- "medium": threepid["medium"],
- "address": threepid["address"],
- "id_server": id_server,
- },
+ user_id, threepid["medium"], threepid["address"], id_server
)
identity_server_supports_unbinding &= result
except Exception:
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 2ea52257cb..a5798e9483 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -14,7 +14,7 @@
import logging
import string
-from typing import TYPE_CHECKING, Iterable, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
from typing_extensions import Literal
@@ -485,7 +485,8 @@ class DirectoryHandler:
)
)
if canonical_alias:
- room_aliases.append(canonical_alias)
+ # Ensure we do not mutate room_aliases.
+ room_aliases = list(room_aliases) + [canonical_alias]
if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_aliases
@@ -528,7 +529,7 @@ class DirectoryHandler:
async def get_aliases_for_room(
self, requester: Requester, room_id: str
- ) -> List[str]:
+ ) -> Sequence[str]:
"""
Get a list of the aliases that currently point to this room on this server
"""
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d2188ca08f..43cbece21b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -159,19 +159,22 @@ class E2eKeysHandler:
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
if remote_queries:
- query_list: List[Tuple[str, Optional[str]]] = []
+ user_ids = set()
+ user_and_device_ids: List[Tuple[str, str]] = []
for user_id, device_ids in remote_queries.items():
if device_ids:
- query_list.extend(
+ user_and_device_ids.extend(
(user_id, device_id) for device_id in device_ids
)
else:
- query_list.append((user_id, None))
+ user_ids.add(user_id)
(
user_ids_not_in_cache,
remote_results,
- ) = await self.store.get_user_devices_from_cache(query_list)
+ ) = await self.store.get_user_devices_from_cache(
+ user_ids, user_and_device_ids
+ )
# Check that the homeserver still shares a room with all cached users.
# Note that this check may be slightly racy when a remote user leaves a
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a23a8ce2a1..46dd63c3f0 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -202,7 +202,7 @@ class EventAuthHandler:
state_ids: StateMap[str],
room_version: RoomVersion,
user_id: str,
- prev_member_event: Optional[EventBase],
+ prev_membership: Optional[str],
) -> None:
"""
Check whether a user can join a room without an invite due to restricted join rules.
@@ -214,15 +214,14 @@ class EventAuthHandler:
state_ids: The state of the room as it currently is.
room_version: The room version of the room being joined.
user_id: The user joining the room.
- prev_member_event: The current membership event for this user.
+ prev_membership: The current membership state for this user. `None` if the
+ user has no membership event in the room (equivalent to "leave").
Raises:
AuthError if the user cannot join the room.
"""
# If the member is invited or currently joined, then nothing to do.
- if prev_member_event and (
- prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
- ):
+ if prev_membership in (Membership.JOIN, Membership.INVITE):
return
# This is not a room with a restricted join rule, so we don't need to do the
@@ -255,13 +254,14 @@ class EventAuthHandler:
)
async def has_restricted_join_rules(
- self, state_ids: StateMap[str], room_version: RoomVersion
+ self, partial_state_ids: StateMap[str], room_version: RoomVersion
) -> bool:
"""
Return if the room has the proper join rules set for access via rooms.
Args:
- state_ids: The state of the room as it currently is.
+ partial_state_ids: The state of the room as it currently is. May be full or
+ partial state.
room_version: The room version of the room to query.
Returns:
@@ -272,7 +272,7 @@ class EventAuthHandler:
return False
# If there's no join rule, then it defaults to invite (so this doesn't apply).
- join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+ join_rules_event_id = partial_state_ids.get((EventTypes.JoinRules, ""), None)
if not join_rules_event_id:
return False
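With the parameter narrowed from a whole event to a membership string, the event fetch moves to the caller. The resulting pattern, as used by the federation event handler updated later in this patch:

    prev_membership = None
    if prev_member_event_id:
        prev_member_event = await store.get_event(prev_member_event_id)
        prev_membership = prev_member_event.membership

    await event_auth_handler.check_restricted_join_rules(
        state_ids, room_version, user_id, prev_membership
    )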
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 7f64130e0a..5f2057269d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -49,6 +49,7 @@ from synapse.api.errors import (
FederationPullAttemptBackoffError,
HttpResponseException,
NotFoundError,
+ PartialStateConflictError,
RequestSendFailed,
SynapseError,
)
@@ -56,7 +57,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
from synapse.http.servlet import assert_params_in_dict
@@ -68,7 +69,6 @@ from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import JsonDict, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
@@ -952,7 +952,20 @@ class FederationHandler:
#
# Note that this requires the /send_join request to come back to the
# same server.
+ prev_event_ids = None
if room_version.msc3083_join_rules:
+ # Note that the room's state can change out from under us and render our
+ # nice join rules-conformant event non-conformant by the time we build the
+ # event. When this happens, our validation at the end fails and we respond
+ # to the requesting server with a 403, which is misleading: it suggests
+ # that the user is not allowed to join the room, and that the joining server
+ # should not bother retrying via this homeserver or any others, when
+ # in fact we've just messed up building the event.
+ #
+ # To reduce the likelihood of this race, we capture the forward extremities
+ # of the room (prev_event_ids) just before fetching the current state, and
+ # hope that the state we fetch corresponds to the prev events we chose.
+ prev_event_ids = await self.store.get_prev_events_for_room(room_id)
state_ids = await self._state_storage_controller.get_current_state_ids(
room_id
)
@@ -990,15 +1003,21 @@ class FederationHandler:
)
try:
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(
+ builder=builder,
+ prev_event_ids=prev_event_ids,
)
except SynapseError as e:
logger.warning("Failed to create join to %s because %s", room_id, e)
raise
# Ensure the user can even join the room.
- await self._federation_event_handler.check_join_restrictions(context, event)
+ await self._federation_event_handler.check_join_restrictions(
+ unpersisted_context, event
+ )
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
@@ -1178,7 +1197,7 @@ class FederationHandler:
},
)
- event, context = await self.event_creation_handler.create_new_client_event(
+ event, _ = await self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -1228,12 +1247,13 @@ class FederationHandler:
},
)
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
- )
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(builder=builder)
event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
- event, context
+ event, unpersisted_context
)
if not event_allowed:
logger.warning("Creation of knock %s forbidden by third-party rules", event)
@@ -1406,15 +1426,20 @@ class FederationHandler:
try:
(
event,
- context,
+ unpersisted_context,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
- event, context = await self.add_display_name_to_third_party_invite(
- room_version_obj, event_dict, event, context
+ (
+ event,
+ unpersisted_context,
+ ) = await self.add_display_name_to_third_party_invite(
+ room_version_obj, event_dict, event, unpersisted_context
)
+ context = await unpersisted_context.persist(event)
+
EventValidator().validate_new(event, self.config)
# We need to tell the transaction queue to send this out, even
@@ -1483,14 +1508,19 @@ class FederationHandler:
try:
(
event,
- context,
+ unpersisted_context,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
- event, context = await self.add_display_name_to_third_party_invite(
- room_version_obj, event_dict, event, context
+ (
+ event,
+ unpersisted_context,
+ ) = await self.add_display_name_to_third_party_invite(
+ room_version_obj, event_dict, event, unpersisted_context
)
+ context = await unpersisted_context.persist(event)
+
try:
validate_event_for_room_version(event)
await self._event_auth_handler.check_auth_rules_from_context(event)
@@ -1522,8 +1552,8 @@ class FederationHandler:
room_version_obj: RoomVersion,
event_dict: JsonDict,
event: EventBase,
- context: EventContext,
- ) -> Tuple[EventBase, EventContext]:
+ context: UnpersistedEventContextBase,
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"],
@@ -1557,11 +1587,14 @@ class FederationHandler:
room_version_obj, event_dict
)
EventValidator().validate_builder(builder)
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
- )
+
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(builder=builder)
+
EventValidator().validate_new(event, self.config)
- return event, context
+ return event, unpersisted_context
async def _check_signature(self, event: EventBase, context: EventContext) -> None:
"""
@@ -1861,6 +1894,11 @@ class FederationHandler:
logger.info("Updating current state for %s", room_id)
# TODO(faster_joins): notify workers in notify_room_un_partial_stated
# https://github.com/matrix-org/synapse/issues/12994
+ #
+ # NB: there's a potential race here. If the room is purged just before we
+ # call this, we _might_ end up inserting rows into current_state_events.
+ # (The logic is hard to chase through.) We think this is fine; if it turns
+ # out not to be, the HS admin can purge the room again.
await self.state_handler.update_current_state(room_id)
logger.info("Handling any pending device list updates")
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index e037acbca2..b7136f8d1c 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -47,6 +47,7 @@ from synapse.api.errors import (
FederationError,
FederationPullAttemptBackoffError,
HttpResponseException,
+ PartialStateConflictError,
RequestSendFailed,
SynapseError,
)
@@ -58,7 +59,7 @@ from synapse.event_auth import (
validate_event_for_room_version,
)
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
from synapse.logging.context import nested_logging_context
from synapse.logging.opentracing import (
@@ -74,7 +75,6 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
)
from synapse.state import StateResolutionStore
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
PersistedEventPosition,
@@ -426,7 +426,9 @@ class FederationEventHandler:
return event, context
async def check_join_restrictions(
- self, context: EventContext, event: EventBase
+ self,
+ context: UnpersistedEventContextBase,
+ event: EventBase,
) -> None:
"""Check that restrictions in restricted join rules are matched
@@ -439,16 +441,17 @@ class FederationEventHandler:
# Check if the user is already in the room or invited to the room.
user_id = event.state_key
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- prev_member_event = None
+ prev_membership = None
if prev_member_event_id:
prev_member_event = await self._store.get_event(prev_member_event_id)
+ prev_membership = prev_member_event.membership
# Check if the member should be allowed access via membership in a space.
await self._event_auth_handler.check_restricted_join_rules(
prev_state_ids,
event.room_version,
user_id,
- prev_member_event,
+ prev_membership,
)
@trace
@@ -524,11 +527,57 @@ class FederationEventHandler:
"Peristing join-via-remote %s (partial_state: %s)", event, partial_state
)
with nested_logging_context(suffix=event.event_id):
+ if partial_state:
+ # When handling a second partial state join into a partial state room,
+ # the returned state will exclude the membership from the first join. To
+ # preserve prior memberships, we try to compute the partial state before
+ # the event ourselves if we know about any of the prev events.
+ #
+ # When we don't know about any of the prev events, it's fine to just use
+ # the returned state, since the new join will create a new forward
+ # extremity, and leave the forward extremity containing our prior
+ # memberships alone.
+ prev_event_ids = set(event.prev_event_ids())
+ seen_event_ids = await self._store.have_events_in_timeline(
+ prev_event_ids
+ )
+ missing_event_ids = prev_event_ids - seen_event_ids
+
+ state_maps_to_resolve: List[StateMap[str]] = []
+
+ # Fetch the state after the prev events that we know about.
+ state_maps_to_resolve.extend(
+ (
+ await self._state_storage_controller.get_state_groups_ids(
+ room_id, seen_event_ids, await_full_state=False
+ )
+ ).values()
+ )
+
+ # When there are prev events we do not have the state for, we state
+ # resolve with the state returned by the remote homeserver.
+ if missing_event_ids or len(state_maps_to_resolve) == 0:
+ state_maps_to_resolve.append(
+ {(e.type, e.state_key): e.event_id for e in state}
+ )
+
+ state_ids_before_event = (
+ await self._state_resolution_handler.resolve_events_with_store(
+ event.room_id,
+ room_version.identifier,
+ state_maps_to_resolve,
+ event_map=None,
+ state_res_store=StateResolutionStore(self._store),
+ )
+ )
+ else:
+ state_ids_before_event = {
+ (e.type, e.state_key): e.event_id for e in state
+ }
+
context = await self._state_handler.compute_event_context(
event,
- state_ids_before_event={
- (e.type, e.state_key): e.event_id for e in state
- },
+ state_ids_before_event=state_ids_before_event,
partial_state=partial_state,
)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 848e46eb9b..bf0f7acf80 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -219,28 +219,31 @@ class IdentityHandler:
data = json_decoder.decode(e.msg) # XXX WAT?
return data
- async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
- """Attempt to remove a 3PID from an identity server, or if one is not provided, all
- identity servers we're aware the binding is present on
+ async def try_unbind_threepid(
+ self, mxid: str, medium: str, address: str, id_server: Optional[str]
+ ) -> bool:
+ """Attempt to remove a 3PID from one or more identity servers.
Args:
mxid: Matrix user ID of binding to be removed
- threepid: Dict with medium & address of binding to be
- removed, and an optional id_server.
+ medium: The medium of the third-party ID.
+ address: The address of the third-party ID.
+ id_server: An identity server to attempt to unbind from. If None,
+ attempt to remove the association from all identity servers
+ known to potentially have it.
Raises:
- SynapseError: If we failed to contact the identity server
+ SynapseError: If we failed to contact one or more identity servers.
Returns:
- True on success, otherwise False if the identity
- server doesn't support unbinding (or no identity server found to
- contact).
+ True on success, otherwise False if the identity server doesn't
+ support unbinding (or no identity server to contact was found).
"""
- if threepid.get("id_server"):
- id_servers = [threepid["id_server"]]
+ if id_server:
+ id_servers = [id_server]
else:
id_servers = await self.store.get_id_servers_user_bound(
- user_id=mxid, medium=threepid["medium"], address=threepid["address"]
+ mxid, medium, address
)
# We don't know where to unbind, so we don't have a choice but to return
@@ -249,20 +252,21 @@ class IdentityHandler:
changed = True
for id_server in id_servers:
- changed &= await self.try_unbind_threepid_with_id_server(
- mxid, threepid, id_server
+ changed &= await self._try_unbind_threepid_with_id_server(
+ mxid, medium, address, id_server
)
return changed
- async def try_unbind_threepid_with_id_server(
- self, mxid: str, threepid: dict, id_server: str
+ async def _try_unbind_threepid_with_id_server(
+ self, mxid: str, medium: str, address: str, id_server: str
) -> bool:
"""Removes a binding from an identity server
Args:
mxid: Matrix user ID of binding to be removed
- threepid: Dict with medium & address of binding to be removed
+ medium: The medium of the third-party ID
+ address: The address of the third-party ID
id_server: Identity server to unbind from
Raises:
@@ -286,7 +290,7 @@ class IdentityHandler:
content = {
"mxid": mxid,
- "threepid": {"medium": threepid["medium"], "address": threepid["address"]},
+ "threepid": {"medium": medium, "address": address},
}
# we abuse the federation http client to sign the request, but we have to send it
@@ -319,12 +323,7 @@ class IdentityHandler:
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
- await self.store.remove_user_bound_threepid(
- user_id=mxid,
- medium=threepid["medium"],
- address=threepid["address"],
- id_server=id_server,
- )
+ await self.store.remove_user_bound_threepid(mxid, medium, address, id_server)
return changed
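A sketch of calls under the new flattened signature (user, address, and server names are illustrative):

    # Unbind from every identity server the binding is known to be on:
    changed = await identity_handler.try_unbind_threepid(
        "@alice:example.com", "email", "alice@example.com", id_server=None
    )

    # Or target one specific identity server:
    changed = await identity_handler.try_unbind_threepid(
        "@alice:example.com", "email", "alice@example.com", "id.example.com"
    )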
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 191529bd8e..1a29abde98 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -154,9 +154,8 @@ class InitialSyncHandler:
tags_by_room = await self.store.get_tags_for_user(user_id)
- account_data, account_data_by_room = await self.store.get_account_data_for_user(
- user_id
- )
+ account_data = await self.store.get_global_account_data_for_user(user_id)
+ account_data_by_room = await self.store.get_room_account_data_for_user(user_id)
public_room_ids = await self.store.get_public_room_ids()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e688e00575..aa90d0000d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
Codes,
ConsentNotGivenError,
NotFoundError,
+ PartialStateConflictError,
ShadowBanError,
SynapseError,
UnstableSpecAuthError,
@@ -48,7 +49,7 @@ from synapse.api.urls import ConsentURIBuilder
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.utils import maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
@@ -57,7 +58,6 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
MutableStateMap,
@@ -499,9 +499,9 @@ class EventCreationHandler:
self.request_ratelimiter = hs.get_request_ratelimiter()
- # We arbitrarily limit concurrent event creation for a room to 5.
- # This is to stop us from diverging history *too* much.
- self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
+ # We limit concurrent event creation for a room to 1. This prevents state resolution
+ # from occurring when sending bursts of events to a local room.
+ self.limiter = Linearizer(max_count=1, name="room_event_creation_limit")
self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
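With `max_count=1` the linearizer acts as a per-room mutex: a second local send for the same room waits for the first to complete, so its prev_events can point at the just-created event and no fork (hence no state resolution) is produced. A sketch of how the limiter guards creation (the `queue` API is the existing Linearizer interface; exact call sites may differ):

    async with self.limiter.queue(room_id):
        # Only one event per room is in flight at a time, so bursts of local
        # sends form a linear chain rather than divergent forks.
        event, unpersisted_context = await self.create_new_client_event(builder=builder)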
@@ -708,7 +708,7 @@ class EventCreationHandler:
builder.internal_metadata.historical = historical
- event, context = await self.create_new_client_event(
+ event, unpersisted_context = await self.create_new_client_event(
builder=builder,
requester=requester,
allow_no_prev_events=allow_no_prev_events,
@@ -721,6 +721,8 @@ class EventCreationHandler:
current_state_group=current_state_group,
)
+ context = await unpersisted_context.persist(event)
+
# In an ideal world we wouldn't need the second part of this condition. However,
# this behaviour isn't spec'd yet, meaning we should be able to deactivate this
# behaviour. Another reason is that this code is also evaluated each time a new
@@ -1083,13 +1085,14 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
- ) -> Tuple[EventBase, EventContext]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""Create a new event for a local client. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
the event using the parameters state_map and current_state_group, thus these parameters
must be provided in this case if for_batch is True. The subsequently created event
and context are suitable for being batched up and bulk persisted to the database
- with other similarly created events.
+ with other similarly created events. Note that this returns an UnpersistedEventContext,
+ which must be converted to an EventContext before it can be sent to the DB.
Args:
builder:
@@ -1131,7 +1134,7 @@ class EventCreationHandler:
batch persisting
Returns:
- Tuple of created event, context
+ Tuple of created event, UnpersistedEventContext
"""
# Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member that don't match event.sender
@@ -1192,9 +1195,16 @@ class EventCreationHandler:
event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
)
- context = await self.state.compute_event_context_for_batched(
- event, state_map, current_state_group
+
+ context: UnpersistedEventContextBase = (
+ await self.state.calculate_context_info(
+ event,
+ state_ids_before_event=state_map,
+ partial_state=False,
+ state_group_before_event=current_state_group,
+ )
)
+
else:
event = await builder.build(
prev_event_ids=prev_event_ids,
@@ -1244,16 +1254,17 @@ class EventCreationHandler:
state_map_for_event[(data.event_type, data.state_key)] = state_id
- context = await self.state.compute_event_context(
+ # TODO(faster_joins): check how MSC2716 works and whether we can have
+ # partial state here
+ # https://github.com/matrix-org/synapse/issues/13003
+ context = await self.state.calculate_context_info(
event,
state_ids_before_event=state_map_for_event,
- # TODO(faster_joins): check how MSC2716 works and whether we can have
- # partial state here
- # https://github.com/matrix-org/synapse/issues/13003
partial_state=False,
)
+
else:
- context = await self.state.compute_event_context(event)
+ context = await self.state.calculate_context_info(event)
if requester:
context.app_service = requester.app_service
@@ -1326,7 +1337,11 @@ class EventCreationHandler:
relation.parent_id, event.type, aggregation_key, event.sender
)
if already_exists:
- raise SynapseError(400, "Can't send same reaction twice")
+ raise SynapseError(
+ 400,
+ "Can't send same reaction twice",
+ errcode=Codes.DUPLICATE_ANNOTATION,
+ )
# Don't attempt to start a thread if the parent event is a relation.
elif relation.rel_type == RelationTypes.THREAD:
@@ -2082,9 +2097,9 @@ class EventCreationHandler:
async def _rebuild_event_after_third_party_rules(
self, third_party_result: dict, original_event: EventBase
- ) -> Tuple[EventBase, EventContext]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
# the third_party_event_rules want to replace the event.
- # we do some basic checks, and then return the replacement event and context.
+ # we do some basic checks, and then return the replacement event.
# Construct a new EventBuilder and validate it, which helps with the
# rest of these checks.
@@ -2138,5 +2153,6 @@ class EventCreationHandler:
# we rebuild the event context, to be on the safe side. If nothing else,
# delta_ids might need an update.
- context = await self.state.compute_event_context(event)
+ context = await self.state.calculate_context_info(event)
+
return event, context
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 04c61ae3dd..2bacdebfb5 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple
from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.appservice import ApplicationService
@@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
@staticmethod
def filter_out_private_receipts(
- rooms: List[JsonDict], user_id: str
+ rooms: Sequence[JsonDict], user_id: str
) -> List[JsonDict]:
"""
Filters a list of serialized receipts (as returned by /sync and /initialSync)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7ba7c4ff07..837dabb3b7 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
Codes,
LimitExceededError,
NotFoundError,
+ PartialStateConflictError,
StoreError,
SynapseError,
)
@@ -54,7 +55,6 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents
from synapse.handlers.relations import BundledAggregations
from synapse.module_api import NOT_SPAM
from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.streams import EventSource
from synapse.types import (
JsonDict,
@@ -1076,7 +1076,7 @@ class RoomCreationHandler:
state_map: MutableStateMap[str] = {}
# current_state_group of last event created. Used for computing event context of
# events to be batched
- current_state_group = None
+ current_state_group: Optional[int] = None
def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
e = {"type": etype, "content": content}
@@ -1928,6 +1928,6 @@ class RoomShutdownHandler:
return {
"kicked_users": kicked_users,
"failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
+ "local_aliases": list(aliases_for_room),
"new_room_id": new_room_id,
}
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 2db2054300..f9c240a948 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -26,7 +26,13 @@ from synapse.api.constants import (
GuestAccess,
Membership,
)
-from synapse.api.errors import AuthError, Codes, ShadowBanError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ PartialStateConflictError,
+ ShadowBanError,
+ SynapseError,
+)
from synapse.api.ratelimiting import Ratelimiter
from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
@@ -34,7 +40,6 @@ from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.logging import opentracing
from synapse.module_api import NOT_SPAM
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.types import (
JsonDict,
Requester,
@@ -56,6 +61,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class NoKnownServersError(SynapseError):
+ """No server already resident to the room was provided to the join/knock operation."""
+
+ def __init__(self, msg: str = "No known servers"):
+ super().__init__(404, msg)
+
+
class RoomMemberHandler(metaclass=abc.ABCMeta):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
@@ -185,6 +197,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id: Room that we are trying to join
user: User who is trying to join
content: A dict that should be used as the content of the join event.
+
+ Raises:
+ NoKnownServersError: if remote_room_hosts does not contain a server joined to
+ the room.
"""
raise NotImplementedError()
@@ -484,7 +500,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
user_id: The user's ID.
"""
# Retrieve user account data for predecessor room
- user_account_data, _ = await self.store.get_account_data_for_user(user_id)
+ user_account_data = await self.store.get_global_account_data_for_user(user_id)
# Copy direct message state if applicable
direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {})
@@ -837,14 +853,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
- state_before_join = await self.state_handler.compute_state_after_events(
- room_id, latest_event_ids
+ is_partial_state_room = await self.store.is_partial_state_room(room_id)
+ partial_state_before_join = await self.state_handler.compute_state_after_events(
+ room_id, latest_event_ids, await_full_state=False
)
+ # `is_partial_state_room` also indicates whether `partial_state_before_join` is
+ # partial.
# TODO: Refactor into dictionary of explicitly allowed transitions
# between old and new state, with specific error messages for some
# transitions and generic otherwise
- old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
+ old_state_id = partial_state_before_join.get(
+ (EventTypes.Member, target.to_string())
+ )
if old_state_id:
old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
@@ -895,11 +916,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
- is_host_in_room = await self._is_host_in_room(state_before_join)
+ is_host_in_room = await self._is_host_in_room(partial_state_before_join)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
- guest_can_join = await self._can_guest_join(state_before_join)
+ guest_can_join = await self._can_guest_join(partial_state_before_join)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
@@ -941,8 +962,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id,
remote_room_hosts,
content,
+ is_partial_state_room,
is_host_in_room,
- state_before_join,
+ partial_state_before_join,
)
if remote_join:
if ratelimit:
@@ -1087,8 +1109,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id: str,
remote_room_hosts: List[str],
content: JsonDict,
+ is_partial_state_room: bool,
is_host_in_room: bool,
- state_before_join: StateMap[str],
+ partial_state_before_join: StateMap[str],
) -> Tuple[bool, List[str]]:
"""
Check whether the server should do a remote join (as opposed to a local
@@ -1107,9 +1130,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
remote_room_hosts: A list of remote room hosts.
content: The content to use as the event body of the join. This may
be modified.
- is_host_in_room: True if the host is in the room.
- state_before_join: The state before the join event (i.e. the resolution of
- the states after its parent events).
+ is_partial_state_room: `True` if the server currently doesn't hold the full
+ state of the room.
+ is_host_in_room: `True` if the host is in the room.
+ partial_state_before_join: The state before the join event (i.e. the
+ resolution of the states after its parent events). May be full or
+ partial state, depending on `is_partial_state_room`.
Returns:
A tuple of:
@@ -1123,6 +1149,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if not is_host_in_room:
return True, remote_room_hosts
+ prev_member_event_id = partial_state_before_join.get(
+ (EventTypes.Member, user_id), None
+ )
+ previous_membership = None
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+ previous_membership = prev_member_event.membership
+
+ # If we are not fully joined yet, and the target is not already in the room,
+ # let's do a remote join so another server with the full state can validate,
+ # for example, that the user has not been banned.
+ # We could just accept the join and wait for state res to resolve that later
+ # on, but we would then leak room history to this person in the meantime,
+ # which is pretty bad.
+ if is_partial_state_room and previous_membership != Membership.JOIN:
+ return True, remote_room_hosts
+
# If the host is in the room, but not one of the authorised hosts
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
@@ -1130,21 +1173,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If restricted join rules are not being used, a local join can always
# be used.
if not await self.event_auth_handler.has_restricted_join_rules(
- state_before_join, room_version
+ partial_state_before_join, room_version
):
return False, []
# If the user is invited to the room or already joined, the join
# event can always be issued locally.
- prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
- prev_member_event = None
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
- if prev_member_event.membership in (
- Membership.JOIN,
- Membership.INVITE,
- ):
- return False, []
+ if previous_membership in (Membership.JOIN, Membership.INVITE):
+ return False, []
+
+ # All the partial state cases are covered above. We have been given the full
+ # state of the room.
+ assert not is_partial_state_room
+ state_before_join = partial_state_before_join
# If the local host has a user who can issue invites, then a local
# join can be done.
@@ -1168,7 +1209,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Ensure the member should be allowed access via membership in a room.
await self.event_auth_handler.check_restricted_join_rules(
- state_before_join, room_version, user_id, prev_member_event
+ state_before_join, room_version, user_id, previous_membership
)
# If this is going to be a local join, additional information must
@@ -1318,11 +1359,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)
- async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
+ async def _can_guest_join(self, partial_current_state_ids: StateMap[str]) -> bool:
"""
Returns whether a guest can join a room based on its current state.
+
+ Args:
+ partial_current_state_ids: The current state of the room. May be full or
+ partial state.
"""
- guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
+ guest_access_id = partial_current_state_ids.get(
+ (EventTypes.GuestAccess, ""), None
+ )
if not guest_access_id:
return False
@@ -1648,19 +1695,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
return event, stream_id
- async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
+ async def _is_host_in_room(self, partial_current_state_ids: StateMap[str]) -> bool:
+ """Returns whether the homeserver is in the room based on its current state.
+
+ Args:
+ partial_current_state_ids: The current state of the room. May be full or
+ partial state.
+ """
# Have we just created the room, and is this about to be the very
# first member event?
- create_event_id = current_state_ids.get(("m.room.create", ""))
- if len(current_state_ids) == 1 and create_event_id:
+ create_event_id = partial_current_state_ids.get(("m.room.create", ""))
+ if len(partial_current_state_ids) == 1 and create_event_id:
# We can only get here if we're in the process of creating the room
return True
- for etype, state_key in current_state_ids:
+ for etype, state_key in partial_current_state_ids:
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
continue
- event_id = current_state_ids[(etype, state_key)]
+ event_id = partial_current_state_ids[(etype, state_key)]
event = await self.store.get_event(event_id, allow_none=True)
if not event:
continue
@@ -1729,8 +1782,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
]
if len(remote_room_hosts) == 0:
- raise SynapseError(
- 404,
+ raise NoKnownServersError(
"Can't join remote room because no servers "
"that are in the room have been provided.",
)
@@ -1961,7 +2013,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
]
if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
+ raise NoKnownServersError()
return await self.federation_handler.do_knock(
remote_room_hosts, room_id, user.to_string(), content=content
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 221552a2a6..ba261702d4 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -15,8 +15,7 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
-from synapse.api.errors import SynapseError
-from synapse.handlers.room_member import RoomMemberHandler
+from synapse.handlers.room_member import NoKnownServersError, RoomMemberHandler
from synapse.replication.http.membership import (
ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
@@ -52,7 +51,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join"""
if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
+ raise NoKnownServersError()
ret = await self._remote_join_client(
requester=requester,
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 4472019fbc..807245160d 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -521,8 +521,8 @@ class RoomSummaryHandler:
It should return true if:
- * The requester is joined or can join the room (per MSC3173).
- * The origin server has any user that is joined or can join the room.
+ * The requesting user is joined or can join the room (per MSC3173); or
+ * The origin server has any user that is joined or can join the room; or
* The history visibility is set to world readable.
Args:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 3566537894..4e4595312c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -269,6 +269,8 @@ class SyncHandler:
self._state_storage_controller = self._storage_controllers.state
self._device_handler = hs.get_device_handler()
+ self.should_calculate_push_rules = hs.config.push.enable_push
+
# TODO: flush cache entries on subsequent sync request.
# Once we get the next /sync request (ie, one with the same access token
# that sets 'since' to 'next_batch'), we know that device won't need a
@@ -1288,6 +1290,12 @@ class SyncHandler:
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
) -> RoomNotifCounts:
+ if not self.should_calculate_push_rules:
+ # If push rules have been universally disabled then we know we won't
+ # have any unread counts in the DB, so we may as well skip asking
+ # the DB.
+ return RoomNotifCounts.empty()
+
with Measure(self.clock, "unread_notifs_for_room_id"):
return await self.store.get_unread_event_push_actions_by_room_for_user(
@@ -1391,6 +1399,11 @@ class SyncHandler:
for room_id, is_partial_state in results.items()
if is_partial_state
)
+ membership_change_events = [
+ event
+ for event in membership_change_events
+ if not results.get(event.room_id, False)
+ ]
# Incremental eager syncs should additionally include rooms that
# - we are joined to
@@ -1444,9 +1457,9 @@ class SyncHandler:
logger.debug("Fetching account data")
- account_data_by_room = await self._generate_sync_entry_for_account_data(
- sync_result_builder
- )
+ # Global account data is included if it is not filtered out.
+ if not sync_config.filter_collection.blocks_all_global_account_data():
+ await self._generate_sync_entry_for_account_data(sync_result_builder)
# Presence data is included if the server has it enabled and not filtered out.
include_presence_data = bool(
@@ -1472,9 +1485,7 @@ class SyncHandler:
(
newly_joined_rooms,
newly_left_rooms,
- ) = await self._generate_sync_entry_for_rooms(
- sync_result_builder, account_data_by_room
- )
+ ) = await self._generate_sync_entry_for_rooms(sync_result_builder)
# Work out which users have joined or left rooms we're in. We use this
# to build the presence and device_list parts of the sync response in
@@ -1521,7 +1532,7 @@ class SyncHandler:
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
- unused_fallback_key_types = (
+ unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
@@ -1717,35 +1728,29 @@ class SyncHandler:
async def _generate_sync_entry_for_account_data(
self, sync_result_builder: "SyncResultBuilder"
- ) -> Dict[str, Dict[str, JsonDict]]:
- """Generates the account data portion of the sync response.
+ ) -> None:
+ """Generates the global account data portion of the sync response.
Account data (called "Client Config" in the spec) can be set either globally
or for a specific room. Account data consists of a list of events which
accumulate state, much like a room.
- This function retrieves global and per-room account data. The former is written
- to the given `sync_result_builder`. The latter is returned directly, to be
- later written to the `sync_result_builder` on a room-by-room basis.
+ This function retrieves global account data and writes it to the given
+ `sync_result_builder`. See `_generate_sync_entry_for_rooms` for handling
+ of per-room account data.
Args:
sync_result_builder
-
- Returns:
- A dictionary whose keys (room ids) map to the per room account data for that
- room.
"""
sync_config = sync_result_builder.sync_config
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and not sync_result_builder.full_state:
- # TODO Do not fetch room account data if it will be unused.
- (
- global_account_data,
- account_data_by_room,
- ) = await self.store.get_updated_account_data_for_user(
- user_id, since_token.account_data_key
+ global_account_data = (
+ await self.store.get_updated_global_account_data_for_user(
+ user_id, since_token.account_data_key
+ )
)
push_rules_changed = await self.store.have_push_rules_changed_for_user(
@@ -1753,31 +1758,31 @@ class SyncHandler:
)
if push_rules_changed:
+ global_account_data = dict(global_account_data)
global_account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
else:
- # TODO Do not fetch room account data if it will be unused.
- (
- global_account_data,
- account_data_by_room,
- ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
+ all_global_account_data = await self.store.get_global_account_data_for_user(
+ user_id
+ )
+ global_account_data = dict(all_global_account_data)
global_account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
- account_data_for_user = await sync_config.filter_collection.filter_account_data(
- [
- {"type": account_data_type, "content": content}
- for account_data_type, content in global_account_data.items()
- ]
+ account_data_for_user = (
+ await sync_config.filter_collection.filter_global_account_data(
+ [
+ {"type": account_data_type, "content": content}
+ for account_data_type, content in global_account_data.items()
+ ]
+ )
)
sync_result_builder.account_data = account_data_for_user
- return account_data_by_room
-
async def _generate_sync_entry_for_presence(
self,
sync_result_builder: "SyncResultBuilder",
@@ -1837,9 +1842,7 @@ class SyncHandler:
sync_result_builder.presence = presence
async def _generate_sync_entry_for_rooms(
- self,
- sync_result_builder: "SyncResultBuilder",
- account_data_by_room: Dict[str, Dict[str, JsonDict]],
+ self, sync_result_builder: "SyncResultBuilder"
) -> Tuple[AbstractSet[str], AbstractSet[str]]:
"""Generates the rooms portion of the sync response. Populates the
`sync_result_builder` with the result.
@@ -1850,7 +1853,6 @@ class SyncHandler:
Args:
sync_result_builder
- account_data_by_room: Dictionary of per room account data
Returns:
Returns a 2-tuple describing rooms the user has joined or left.
@@ -1863,9 +1865,30 @@ class SyncHandler:
since_token = sync_result_builder.since_token
user_id = sync_result_builder.sync_config.user.to_string()
+ blocks_all_rooms = (
+ sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+ )
+
+ # 0. Start by fetching room account data (if required).
+ if (
+ blocks_all_rooms
+ or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data()
+ ):
+ account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {}
+ elif since_token and not sync_result_builder.full_state:
+ account_data_by_room = (
+ await self.store.get_updated_room_account_data_for_user(
+ user_id, since_token.account_data_key
+ )
+ )
+ else:
+ account_data_by_room = await self.store.get_room_account_data_for_user(
+ user_id
+ )
+
# 1. Start by fetching all ephemeral events in rooms we've joined (if required).
block_all_room_ephemeral = (
- sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+ blocks_all_rooms
or sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
)
if block_all_room_ephemeral:
@@ -2291,8 +2314,8 @@ class SyncHandler:
sync_result_builder: "SyncResultBuilder",
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
- tags: Optional[Dict[str, Dict[str, Any]]],
- account_data: Dict[str, JsonDict],
+ tags: Optional[Mapping[str, Mapping[str, Any]]],
+ account_data: Mapping[str, JsonDict],
always_include: bool = False,
) -> None:
"""Populates the `joined` and `archived` section of `sync_result_builder`
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 332edcca24..78a75bfed6 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -13,7 +13,8 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type
from twisted.web.client import PartialDownloadError
@@ -27,19 +28,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class UserInteractiveAuthChecker:
+class UserInteractiveAuthChecker(ABC):
"""Abstract base class for an interactive auth checker"""
- def __init__(self, hs: "HomeServer"):
+ # This should really be an "abstract class property", i.e. it should
+ # be an error to instantiate a subclass that doesn't specify an AUTH_TYPE.
+ # But calling this a `ClassVar` is simpler than a decorator stack of
+ # @property @abstractmethod and @classmethod (if that's even the right order).
+ AUTH_TYPE: ClassVar[str]
+
+ def __init__(self, hs: "HomeServer"): # noqa: B027
pass
+ @abstractmethod
def is_enabled(self) -> bool:
"""Check if the configuration of the homeserver allows this checker to work
Returns:
True if this login type is enabled.
"""
+ raise NotImplementedError()
+ @abstractmethod
async def check_auth(self, authdict: dict, clientip: str) -> Any:
"""Given the authentication dict from the client, attempt to check this step
@@ -304,7 +314,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
)
-INTERACTIVE_AUTH_CHECKERS = [
+INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
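
To see the effect of the ABC conversion above: subclasses must now implement both abstract methods (and supply an AUTH_TYPE) or instantiation fails eagerly. A toy example, not part of Synapse:

    # Sketch: why the ABC conversion is useful. An incomplete subclass now
    # raises TypeError at instantiation instead of inheriting no-ops.
    from abc import ABC, abstractmethod
    from typing import Any, ClassVar


    class Checker(ABC):
        AUTH_TYPE: ClassVar[str]  # subclasses must assign this

        @abstractmethod
        def is_enabled(self) -> bool:
            raise NotImplementedError()

        @abstractmethod
        async def check_auth(self, authdict: dict, clientip: str) -> Any:
            raise NotImplementedError()


    class AlwaysOnChecker(Checker):
        AUTH_TYPE = "example.always_on"

        def is_enabled(self) -> bool:
            return True

        async def check_auth(self, authdict: dict, clientip: str) -> Any:
            return authdict
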
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b92f1d3d1a..312aab4dcc 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1267,7 +1267,7 @@ class MatrixFederationHttpClient:
def _flatten_response_never_received(e: BaseException) -> str:
if hasattr(e, "reasons"):
reasons = ", ".join(
- _flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined]
+ _flatten_response_never_received(f.value) for f in e.reasons
)
return "%s:[%s]" % (type(e).__name__, reasons)
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 2563858f3c..9314454af1 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -30,7 +30,6 @@ from typing import (
Iterable,
Iterator,
List,
- NoReturn,
Optional,
Pattern,
Tuple,
@@ -340,7 +339,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
return callback_return
- return _unrecognised_request_handler(request)
+ # A request with an unknown method (for a known endpoint) was received.
+ raise UnrecognizedRequestError(code=405)
@abc.abstractmethod
def _send_response(
@@ -396,7 +396,6 @@ class DirectServeJsonResource(_AsyncResource):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _PathEntry:
- pattern: Pattern
callback: ServletCallback
servlet_classname: str
@@ -425,13 +424,14 @@ class JsonResource(DirectServeJsonResource):
):
super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
- self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
+ # Map of path regex -> method -> callback.
+ self._routes: Dict[Pattern[str], Dict[bytes, _PathEntry]] = {}
self.hs = hs
def register_paths(
self,
method: str,
- path_patterns: Iterable[Pattern],
+ path_patterns: Iterable[Pattern[str]],
callback: ServletCallback,
servlet_classname: str,
) -> None:
@@ -455,8 +455,8 @@ class JsonResource(DirectServeJsonResource):
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
- self.path_regexs.setdefault(method_bytes, []).append(
- _PathEntry(path_pattern, callback, servlet_classname)
+ self._routes.setdefault(path_pattern, {})[method_bytes] = _PathEntry(
+ callback, servlet_classname
)
def _get_handler_for_request(
@@ -478,14 +478,17 @@ class JsonResource(DirectServeJsonResource):
# Loop through all the registered callbacks to check if the method
# and path regex match
- for path_entry in self.path_regexs.get(request_method, []):
- m = path_entry.pattern.match(request_path)
+ for path_pattern, methods in self._routes.items():
+ m = path_pattern.match(request_path)
if m:
- # We found a match!
+ # We found a matching path!
+ path_entry = methods.get(request_method)
+ if not path_entry:
+ raise UnrecognizedRequestError(code=405)
return path_entry.callback, path_entry.servlet_classname, m.groupdict()
- # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
- return _unrecognised_request_handler, "unrecognised_request_handler", {}
+ # Huh. No one wanted to handle that? Fiiiiiine.
+ raise UnrecognizedRequestError(code=404)
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
@@ -567,19 +570,6 @@ class StaticResource(File):
return super().render_GET(request)
-def _unrecognised_request_handler(request: Request) -> NoReturn:
- """Request handler for unrecognised requests
-
- This is a request handler suitable for return from
- _get_handler_for_request. It actually just raises an
- UnrecognizedRequestError.
-
- Args:
- request: Unused, but passed in to match the signature of ServletCallback.
- """
- raise UnrecognizedRequestError(code=404)
-
-
class UnrecognizedRequestResource(resource.Resource):
"""
Similar to twisted.web.resource.NoResource, but returns a JSON 404 with an
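
The rerouted JsonResource above distinguishes a known path with an unsupported method (405) from a wholly unknown path (404). Here is the lookup structure in isolation, sketched with plain exceptions in place of UnrecognizedRequestError:

    # Sketch of the path -> method -> handler lookup introduced above.
    import re
    from typing import Callable, Dict, Pattern

    routes: Dict[Pattern[str], Dict[bytes, Callable[..., object]]] = {}


    def register(method: str, pattern: str, handler: Callable[..., object]) -> None:
        routes.setdefault(re.compile(pattern), {})[method.encode("ascii")] = handler


    def lookup(method: bytes, path: str) -> Callable[..., object]:
        for pattern, methods in routes.items():
            m = pattern.match(path)
            if m:
                handler = methods.get(method)
                if handler is None:
                    # Known path, unsupported method.
                    raise LookupError("405 Method Not Allowed")
                return handler
        # No registered path matched at all.
        raise LookupError("404 Not Found")


    register("GET", r"^/hello$", lambda: "hi")
    assert lookup(b"GET", "/hello")() == "hi"
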
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 8ef9a0dda8..5aed71262f 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -188,7 +188,7 @@ from typing import (
)
import attr
-from typing_extensions import ParamSpec
+from typing_extensions import Concatenate, ParamSpec
from twisted.internet import defer
from twisted.web.http import Request
@@ -445,7 +445,7 @@ def init_tracer(hs: "HomeServer") -> None:
opentracing = None # type: ignore[assignment]
return
- if not opentracing or not JaegerConfig:
+ if opentracing is None or JaegerConfig is None:
raise ConfigError(
"The server has been configured to use opentracing but opentracing is not "
"installed."
@@ -466,8 +466,16 @@ def init_tracer(hs: "HomeServer") -> None:
STRIP_INSTANCE_NUMBER_SUFFIX_REGEX, "", hs.get_instance_name()
)
+ jaeger_config = hs.config.tracing.jaeger_config
+ tags = jaeger_config.setdefault("tags", {})
+
+ # tag the Synapse instance name so that it's an easy jumping
+ # off point into the logs. Can also be used to filter for an
+ # instance that is under load.
+ tags[SynapseTags.INSTANCE_NAME] = hs.get_instance_name()
+
config = JaegerConfig(
- config=hs.config.tracing.jaeger_config,
+ config=jaeger_config,
service_name=f"{hs.config.server.server_name} {instance_name_by_type}",
scope_manager=LogContextScopeManager(),
metrics_factory=PrometheusMetricsFactory(),
@@ -864,7 +872,7 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
def _custom_sync_async_decorator(
func: Callable[P, R],
- wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
+ wrapping_logic: Callable[Concatenate[Callable[P, R], P], ContextManager[None]],
) -> Callable[P, R]:
"""
Decorates a function that is sync or async (coroutines), or that returns a Twisted
@@ -894,10 +902,14 @@ def _custom_sync_async_decorator(
"""
if inspect.iscoroutinefunction(func):
-
+ # In this branch, R = Awaitable[RInner], for some other type RInner
@wraps(func)
- async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+ async def _wrapper(
+ *args: P.args, **kwargs: P.kwargs
+ ) -> Any: # Return type is RInner
with wrapping_logic(func, *args, **kwargs):
+ # type-ignore: func() returns R, but mypy doesn't know that R is
+ # Awaitable here.
return await func(*args, **kwargs) # type: ignore[misc]
else:
@@ -964,7 +976,11 @@ def trace_with_opname(
if not opentracing:
return func
- return _custom_sync_async_decorator(func, _wrapping_logic)
+ # type-ignore: mypy seems to be confused by the ParamSpecs here.
+ # I think the problem is https://github.com/python/mypy/issues/12909
+ return _custom_sync_async_decorator(
+ func, _wrapping_logic # type: ignore[arg-type]
+ )
return _decorator
@@ -1010,7 +1026,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
set_tag(SynapseTags.FUNC_KWARGS, str(kwargs))
yield
- return _custom_sync_async_decorator(func, _wrapping_logic)
+ # type-ignore: mypy seems to be confused by the ParamSpecs here.
+ # I think the problem is https://github.com/python/mypy/issues/12909
+ return _custom_sync_async_decorator(func, _wrapping_logic) # type: ignore[arg-type]
@contextlib.contextmanager
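
The Concatenate annotation above encodes that `wrapping_logic` receives the wrapped function followed by that function's own arguments. A self-contained sketch of the pattern (the diff itself records that current mypy still needs type-ignores in spots like this):

    # Sketch: the Concatenate[Callable[P, R], P] pattern in miniature.
    # timed() must be called with the wrapped function plus the exact
    # arguments that function accepts, and the type checker can verify it.
    import contextlib
    import time
    from functools import wraps
    from typing import Callable, ContextManager, Iterator, TypeVar

    from typing_extensions import Concatenate, ParamSpec

    P = ParamSpec("P")
    R = TypeVar("R")


    @contextlib.contextmanager
    def timed(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Iterator[None]:
        start = time.monotonic()
        yield
        print(f"{func.__name__} took {time.monotonic() - start:.6f}s")


    def instrument(
        func: Callable[P, R],
        wrapping_logic: Callable[Concatenate[Callable[P, R], P], ContextManager[None]],
    ) -> Callable[P, R]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            with wrapping_logic(func, *args, **kwargs):
                return func(*args, **kwargs)

        return wrapper


    def add(a: int, b: int) -> int:
        return a + b


    assert instrument(add, timed)(1, 2) == 3
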
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index d9c0a98f44..5fc38431ba 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -22,6 +22,7 @@ from typing import (
List,
Mapping,
Optional,
+ Sequence,
Set,
Tuple,
Union,
@@ -43,6 +44,7 @@ from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
+from synapse.types import JsonValue
from synapse.types.state import StateFilter
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
@@ -148,7 +150,7 @@ class BulkPushRuleEvaluator:
# little, we can skip fetching a huge number of push rules in large rooms.
# This helps make joins and leaves faster.
if event.type == EventTypes.Member:
- local_users = []
+ local_users: Sequence[str] = []
# We never notify a user about their own actions. This is enforced in
# `_action_for_event_by_user` in the loop over `rules_by_user`, but we
# do the same check here to avoid unnecessary DB queries.
@@ -183,7 +185,6 @@ class BulkPushRuleEvaluator:
if event.type == EventTypes.Member and event.membership == Membership.INVITE:
invited = event.state_key
if invited and self.hs.is_mine_id(invited) and invited not in local_users:
- local_users = list(local_users)
local_users.append(invited)
if not local_users:
@@ -256,13 +257,15 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]:
+ async def _related_events(
+ self, event: EventBase
+ ) -> Dict[str, Dict[str, JsonValue]]:
"""Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation
Returns:
Mapping of relation type to flattened events.
"""
- related_events: Dict[str, Dict[str, str]] = {}
+ related_events: Dict[str, Dict[str, JsonValue]] = {}
if self._related_event_match_enabled:
related_event_id = event.content.get("m.relates_to", {}).get("event_id")
relation_type = event.content.get("m.relates_to", {}).get("rel_type")
@@ -271,7 +274,10 @@ class BulkPushRuleEvaluator:
related_event_id, allow_none=True
)
if related_event is not None:
- related_events[relation_type] = _flatten_dict(related_event)
+ related_events[relation_type] = _flatten_dict(
+ related_event,
+ msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+ )
reply_event_id = (
event.content.get("m.relates_to", {})
@@ -286,7 +292,10 @@ class BulkPushRuleEvaluator:
)
if related_event is not None:
- related_events["m.in_reply_to"] = _flatten_dict(related_event)
+ related_events["m.in_reply_to"] = _flatten_dict(
+ related_event,
+ msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+ )
# indicate that this is from a fallback relation.
if relation_type == "m.thread" and event.content.get(
@@ -391,7 +400,6 @@ class BulkPushRuleEvaluator:
mentions = event.content.get(EventContentFields.MSC3952_MENTIONS)
has_mentions = self._intentional_mentions_enabled and isinstance(mentions, dict)
user_mentions: Set[str] = set()
- room_mention = False
if has_mentions:
# mypy seems to have lost the type even though it must be a dict here.
assert isinstance(mentions, dict)
@@ -401,14 +409,14 @@ class BulkPushRuleEvaluator:
user_mentions = set(
filter(lambda item: isinstance(item, str), user_mentions_raw)
)
- # Room mention is only true if the value is exactly true.
- room_mention = mentions.get("room") is True
evaluator = PushRuleEvaluator(
- _flatten_dict(event),
+ _flatten_dict(
+ event,
+ msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+ ),
has_mentions,
user_mentions,
- room_mention,
room_member_count,
sender_power_level,
notification_levels,
@@ -416,6 +424,8 @@ class BulkPushRuleEvaluator:
self._related_event_match_enabled,
event.room_version.msc3931_push_features,
self.hs.config.experimental.msc1767_enabled, # MSC3931 flag
+ self.hs.config.experimental.msc3758_exact_event_match,
+ self.hs.config.experimental.msc3966_exact_event_property_contains,
)
users = rules_by_user.keys()
@@ -489,16 +499,22 @@ RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
+def _is_simple_value(value: Any) -> bool:
+ return isinstance(value, (bool, str)) or type(value) is int or value is None
+
+
def _flatten_dict(
d: Union[EventBase, Mapping[str, Any]],
prefix: Optional[List[str]] = None,
- result: Optional[Dict[str, str]] = None,
-) -> Dict[str, str]:
+ result: Optional[Dict[str, JsonValue]] = None,
+ *,
+ msc3783_escape_event_match_key: bool = False,
+) -> Dict[str, JsonValue]:
"""
Given a JSON dictionary (or event) which might contain sub dictionaries,
flatten it into a single layer dictionary by combining the keys & sub-keys.
- Any (non-dictionary), non-string value is dropped.
+ String, integer, boolean, null or lists of those values are kept. All others are dropped.
Transforms:
@@ -521,11 +537,24 @@ def _flatten_dict(
if result is None:
result = {}
for key, value in d.items():
- if isinstance(value, str):
- result[".".join(prefix + [key])] = value.lower()
+ if msc3783_escape_event_match_key:
+ # Escape periods in the key with a backslash (and backslashes with an
+ # extra backslash). This is since a period is used as a separator between
+ # nested fields.
+ key = key.replace("\\", "\\\\").replace(".", "\\.")
+
+ if _is_simple_value(value):
+ result[".".join(prefix + [key])] = value
+ elif isinstance(value, (list, tuple)):
+ result[".".join(prefix + [key])] = [v for v in value if _is_simple_value(v)]
elif isinstance(value, Mapping):
# do not set `room_version` due to recursion considerations below
- _flatten_dict(value, prefix=(prefix + [key]), result=result)
+ _flatten_dict(
+ value,
+ prefix=(prefix + [key]),
+ result=result,
+ msc3783_escape_event_match_key=msc3783_escape_event_match_key,
+ )
# `room_version` should only ever be set when looking at the top level of an event
if (
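
Concretely, the MSC3783 escaping above turns a literal key like `m.relates_to` into `m\.relates_to` so it cannot be confused with nesting. A trimmed standalone version of the helper with a worked example:

    # Standalone sketch of the flattening + MSC3783 key escaping above.
    from typing import Any, Dict, List, Mapping, Optional, Union

    JsonValue = Union[None, bool, int, str, List[Any]]


    def _is_simple_value(value: Any) -> bool:
        return isinstance(value, (bool, str)) or type(value) is int or value is None


    def flatten_dict(
        d: Mapping[str, Any],
        prefix: Optional[List[str]] = None,
        result: Optional[Dict[str, JsonValue]] = None,
        *,
        escape_keys: bool = False,
    ) -> Dict[str, JsonValue]:
        prefix = prefix or []
        if result is None:
            result = {}
        for key, value in d.items():
            if escape_keys:
                # A period separates nesting levels, so escape literal
                # periods (and backslashes) inside the key itself.
                key = key.replace("\\", "\\\\").replace(".", "\\.")
            if _is_simple_value(value):
                result[".".join(prefix + [key])] = value
            elif isinstance(value, (list, tuple)):
                result[".".join(prefix + [key])] = [
                    v for v in value if _is_simple_value(v)
                ]
            elif isinstance(value, Mapping):
                flatten_dict(value, prefix + [key], result, escape_keys=escape_keys)
        return result


    event = {"content": {"m.relates_to": {"rel_type": "m.thread"}}}
    assert flatten_dict(event, escape_keys=True) == {
        "content.m\\.relates_to.rel_type": "m.thread"
    }
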
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index cc0528bd8e..424854efbe 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -370,15 +370,23 @@ class ReplicationDataHandler:
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
logger.info(
- "Waiting for repl stream %r to reach %s (%s)",
+ "Waiting for repl stream %r to reach %s (%s); currently at: %s",
stream_name,
position,
instance_name,
+ current_position,
)
try:
await make_deferred_yieldable(deferred)
except defer.TimeoutError:
- logger.error("Timed out waiting for stream %s", stream_name)
+ logger.error(
+ "Timed out waiting for repl stream %r to reach %s (%s)"
+ "; currently at: %s",
+ stream_name,
+ position,
+ instance_name,
+ self._streams[stream_name].current_token(instance_name),
+ )
return
logger.info(
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 0d072c42a7..c134ccfb3d 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,7 +15,7 @@
import logging
from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -285,7 +285,12 @@ class DeleteMediaByDateSize(RestServlet):
timestamp and size.
"""
- PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$")
+ PATTERNS = [
+ *admin_patterns("/media/delete$"),
+ # This URL is kept around for legacy reasons; it is undesirable since it
+ # overlaps with the DeleteMediaByID servlet.
+ *admin_patterns("/media/(?P<server_name>[^/]*)/delete$"),
+ ]
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
@@ -294,7 +299,7 @@ class DeleteMediaByDateSize(RestServlet):
self.media_repository = hs.get_media_repository()
async def on_POST(
- self, request: SynapseRequest, server_name: str
+ self, request: SynapseRequest, server_name: Optional[str] = None
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
@@ -322,7 +327,8 @@ class DeleteMediaByDateSize(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- if self.server_name != server_name:
+ # This check is useless; we keep it for the legacy endpoint only.
+ if server_name is not None and self.server_name != server_name:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
logging.info(
@@ -489,6 +495,8 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer)
ProtectMediaByID(hs).register(http_server)
UnprotectMediaByID(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
- DeleteMediaByID(hs).register(http_server)
+ # XXX DeleteMediaByDateSize must be registered before DeleteMediaByID as
+ # their URL routes overlap.
DeleteMediaByDateSize(hs).register(http_server)
+ DeleteMediaByID(hs).register(http_server)
UserMediaRestServlet(hs).register(http_server)
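
The registration-order comment above exists because two admin patterns can match the same URL, and the first-registered route wins. A quick demonstration; the by-ID pattern below is an illustrative stand-in, not DeleteMediaByID's actual regex:

    # Sketch: why DeleteMediaByDateSize must be registered first.
    import re

    delete_by_date_size = re.compile(r"^/media/(?P<server_name>[^/]*)/delete$")
    delete_by_id = re.compile(r"^/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")

    url = "/media/example.org/delete"

    # Both patterns match the legacy URL...
    assert delete_by_date_size.match(url)
    assert delete_by_id.match(url)  # ...treating "delete" as a media_id!

    # So the first-registered route decides which servlet handles it.
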
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b9dca8ef3a..0c0bf540b9 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -1192,7 +1192,8 @@ class AccountDataRestServlet(RestServlet):
if not await self._store.get_user_by_id(user_id):
raise NotFoundError("User not found")
- global_data, by_room_data = await self._store.get_account_data_for_user(user_id)
+ global_data = await self._store.get_global_account_data_for_user(user_id)
+ by_room_data = await self._store.get_room_account_data_for_user(user_id)
return HTTPStatus.OK, {
"account_data": {
"global": global_data,
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 232f3a976d..662f5bf762 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -737,12 +737,7 @@ class ThreepidUnbindRestServlet(RestServlet):
# Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past
result = await self.identity_handler.try_unbind_threepid(
- requester.user.to_string(),
- {
- "address": body.address,
- "medium": body.medium,
- "id_server": body.id_server,
- },
+ requester.user.to_string(), body.medium, body.address, body.id_server
)
return 200, {"id_server_unbind_result": "success" if result else "no-support"}
diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py
index e2b410cf32..9be5860221 100644
--- a/synapse/rest/client/report_event.py
+++ b/synapse/rest/client/report_event.py
@@ -16,7 +16,7 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
@@ -39,6 +39,7 @@ class ReportEventRestServlet(RestServlet):
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
+ self._event_handler = self.hs.get_event_handler()
async def on_POST(
self, request: SynapseRequest, room_id: str, event_id: str
@@ -61,6 +62,14 @@ class ReportEventRestServlet(RestServlet):
Codes.BAD_JSON,
)
+ event = await self._event_handler.get_event(
+ requester.user, room_id, event_id, show_redacted=False
+ )
+ if event is None:
+ raise NotFoundError(
+ "Unable to report event: it does not exist or you aren't able to see it."
+ )
+
await self.store.add_event_report(
room_id=room_id,
event_id=event_id,
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index f7081f638e..4e7ffdb555 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -259,6 +259,32 @@ class RoomKeysNewVersionServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ """
+ Retrieve the version information about the most current backup version (if any)
+
+ It takes out an exclusive lock on this user's room_key backups, to ensure
+ clients only upload to the current backup.
+
+ Returns 404 if the given version does not exist.
+
+ GET /room_keys/version HTTP/1.1
+ {
+ "version": "12345",
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
+ }
+ """
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
+ user_id = requester.user.to_string()
+
+ try:
+ info = await self.e2e_room_keys_handler.get_version_info(user_id)
+ except SynapseError as e:
+ if e.code == 404:
+ raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
+ return 200, info
+
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""
Create a new backup version for this user's room_keys with the given
@@ -301,7 +327,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet):
- PATTERNS = client_patterns("/room_keys/version(/(?P[^/]+))?$")
+ PATTERNS = client_patterns("/room_keys/version/(?P[^/]+)$")
def __init__(self, hs: "HomeServer"):
super().__init__()
@@ -309,12 +335,11 @@ class RoomKeysVersionServlet(RestServlet):
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
async def on_GET(
- self, request: SynapseRequest, version: Optional[str]
+ self, request: SynapseRequest, version: str
) -> Tuple[int, JsonDict]:
"""
Retrieve the version information about a given version of the user's
- room_keys backup. If the version part is missing, returns info about the
- most current backup version (if any)
+ room_keys backup.
It takes out an exclusive lock on this user's room_key backups, to ensure
clients only upload to the current backup.
@@ -339,20 +364,16 @@ class RoomKeysVersionServlet(RestServlet):
return 200, info
async def on_DELETE(
- self, request: SynapseRequest, version: Optional[str]
+ self, request: SynapseRequest, version: str
) -> Tuple[int, JsonDict]:
"""
Delete the information about a given version of the user's
- room_keys backup. If the version part is missing, deletes the most
- current backup version (if any). Doesn't delete the actual room data.
+ room_keys backup. Doesn't delete the actual room data.
DELETE /room_keys/version/12345 HTTP/1.1
HTTP/1.1 200 OK
{}
"""
- if version is None:
- raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND)
-
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
@@ -360,7 +381,7 @@ class RoomKeysVersionServlet(RestServlet):
return 200, {}
async def on_PUT(
- self, request: SynapseRequest, version: Optional[str]
+ self, request: SynapseRequest, version: str
) -> Tuple[int, JsonDict]:
"""
Update the information about a given version of the user's room_keys backup.
@@ -386,11 +407,6 @@ class RoomKeysVersionServlet(RestServlet):
user_id = requester.user.to_string()
info = parse_json_object_from_request(request)
- if version is None:
- raise SynapseError(
- 400, "No version specified to update", Codes.MISSING_PARAM
- )
-
await self.e2e_room_keys_handler.update_version(user_id, version, info)
return 200, {}
diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index ca638755c7..dde08417a4 100644
--- a/synapse/rest/client/tags.py
+++ b/synapse/rest/client/tags.py
@@ -34,7 +34,9 @@ class TagListServlet(RestServlet):
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
"""
- PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags")
+ PATTERNS = client_patterns(
+ "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags$"
+ )
def __init__(self, hs: "HomeServer"):
super().__init__()
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index d30878f704..6e035afcce 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -16,6 +16,7 @@
import logging
import os
import urllib
+from abc import ABC, abstractmethod
from types import TracebackType
from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
@@ -284,13 +285,14 @@ async def respond_with_responder(
finish_request(request)
-class Responder:
+class Responder(ABC):
"""Represents a response that can be streamed to the requester.
Responder is a context manager which *must* be used, so that any resources
held can be cleaned up.
"""
+ @abstractmethod
def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer
@@ -300,11 +302,12 @@ class Responder:
Returns:
Resolves once the response has finished being written
"""
+ raise NotImplementedError()
- def __enter__(self) -> None:
+ def __enter__(self) -> None: # noqa: B027
pass
- def __exit__(
+ def __exit__( # noqa: B027
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a5c3de192f..db25848744 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -46,10 +46,9 @@ from ._base import FileInfo, Responder
from .filepath import MediaFilePaths
if TYPE_CHECKING:
+ from synapse.rest.media.v1.storage_provider import StorageProvider
from synapse.server import HomeServer
- from .storage_provider import StorageProviderWrapper
-
logger = logging.getLogger(__name__)
@@ -68,7 +67,7 @@ class MediaStorage:
hs: "HomeServer",
local_media_directory: str,
filepaths: MediaFilePaths,
- storage_providers: Sequence["StorageProviderWrapper"],
+ storage_providers: Sequence["StorageProvider"],
):
self.hs = hs
self.reactor = hs.get_reactor()
@@ -360,7 +359,7 @@ class ReadableFileWrapper:
clock: Clock
path: str
- async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
+ async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:
"""Reads the file in chunks and calls the callback with each chunk."""
with open(self.path, "rb") as file:
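
Widening the callback annotation from `Callable[[bytes], None]` to `Callable[[bytes], object]` lets callers pass functions like `BytesIO.write`, whose return value is simply discarded; a `None` return annotation would reject them under mypy's strict rules. A small illustration:

    # Sketch: return-type variance for callbacks. A Callable[[bytes], object]
    # parameter accepts callbacks returning anything, since every return
    # value is ignorable; Callable[[bytes], None] would reject write().
    import io
    from typing import Callable


    def feed_chunks(callback: Callable[[bytes], object]) -> None:
        for chunk in (b"abc", b"def"):
            callback(chunk)  # the return value is discarded


    buf = io.BytesIO()
    feed_chunks(buf.write)  # BytesIO.write returns int: fine with `object`
    assert buf.getvalue() == b"abcdef"
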
diff --git a/synapse/server.py b/synapse/server.py
index 9d6d268f49..e5a3475247 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -21,7 +21,7 @@
import abc
import functools
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
from twisted.internet.interfaces import IOpenSSLContextFactory
from twisted.internet.tcp import Port
@@ -144,10 +144,10 @@ if TYPE_CHECKING:
from synapse.handlers.saml import SamlHandler
-T = TypeVar("T", bound=Callable[..., Any])
+T = TypeVar("T")
-def cache_in_self(builder: T) -> T:
+def cache_in_self(builder: Callable[["HomeServer"], T]) -> Callable[["HomeServer"], T]:
"""Wraps a function called e.g. `get_foo`, checking if `self.foo` exists and
returning if so. If not, calls the given function and sets `self.foo` to it.
@@ -166,7 +166,7 @@ def cache_in_self(builder: T) -> T:
building = [False]
@functools.wraps(builder)
- def _get(self):
+ def _get(self: "HomeServer") -> T:
try:
return getattr(self, depname)
except AttributeError:
@@ -185,9 +185,7 @@ def cache_in_self(builder: T) -> T:
return dep
- # We cast here as we need to tell mypy that `_get` has the same signature as
- # `builder`.
- return cast(T, _get)
+ return _get
class HomeServer(metaclass=abc.ABCMeta):
@@ -829,6 +827,7 @@ class HomeServer(metaclass=abc.ABCMeta):
hs=self,
host=self.config.redis.redis_host,
port=self.config.redis.redis_port,
+ dbid=self.config.redis.redis_dbid,
password=self.config.redis.redis_password,
reconnect=True,
)
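
The retyped `cache_in_self` above drops the `cast` by declaring exactly what the decorator does: take a one-argument builder, return a function of the same shape. A cut-down sketch with a toy `Server` standing in for HomeServer (the re-entrancy guard is omitted):

    # Sketch: precisely-typed memoising decorator, mirroring cache_in_self.
    import functools
    import time
    from typing import Callable, TypeVar

    T = TypeVar("T")


    def cache_in_self(builder: Callable[["Server"], T]) -> Callable[["Server"], T]:
        depname = builder.__name__.replace("get_", "_")

        @functools.wraps(builder)
        def _get(self: "Server") -> T:
            try:
                return getattr(self, depname)
            except AttributeError:
                pass
            dep = builder(self)
            setattr(self, depname, dep)
            return dep

        # No cast needed: _get already has the declared return type.
        return _get


    class Server:
        @cache_in_self
        def get_clock(self) -> float:
            return time.monotonic()


    s = Server()
    assert s.get_clock() == s.get_clock()  # built once, then cached
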
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdfb46ab82..4dc25df67e 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import (
+ EventContext,
+ UnpersistedEventContext,
+ UnpersistedEventContextBase,
+)
from synapse.logging.context import ContextResourceUsage
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
@@ -222,7 +226,7 @@ class StateHandler:
return await ret.get_state(self._state_storage_controller, state_filter)
async def get_current_user_ids_in_room(
- self, room_id: str, latest_event_ids: List[str]
+ self, room_id: str, latest_event_ids: Collection[str]
) -> Set[str]:
"""
Get the users IDs who are currently in a room.
@@ -262,31 +266,31 @@ class StateHandler:
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_hosts(room_id, state, entry)
- async def compute_event_context(
+ async def calculate_context_info(
self,
event: EventBase,
state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: Optional[bool] = None,
- ) -> EventContext:
- """Build an EventContext structure for a non-outlier event.
+ state_group_before_event: Optional[int] = None,
+ ) -> UnpersistedEventContextBase:
+ """
+ Calculates the contents of an unpersisted event context, other than the current
+ state group (which is either provided or calculated when the event context is persisted).
- (for an outlier, call EventContext.for_outlier directly)
-
- This works out what the current state should be for the event, and
- generates a new state group if necessary.
-
- Args:
- event:
- state_ids_before_event: The event ids of the state before the event if
- it can't be calculated from existing events. This is normally
- only specified when receiving an event from federation where we
- don't have the prev events, e.g. when backfilling.
- partial_state:
- `True` if `state_ids_before_event` is partial and omits non-critical
- membership events.
- `False` if `state_ids_before_event` is the full state.
- `None` when `state_ids_before_event` is not provided. In this case, the
- flag will be calculated based on `event`'s prev events.
+ state_ids_before_event:
+ The event ids of the full state before the event if
+ it can't be calculated from existing events. This is normally
+ only specified when receiving an event from federation where we
+ don't have the prev events, e.g. when backfilling or when the event
+ is being created for batch persisting.
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
+ state_group_before_event:
+ the current state group at the time of event, if known
Returns:
The event context.
@@ -294,7 +298,6 @@ class StateHandler:
RuntimeError if `state_ids_before_event` is not provided and one or more
prev events are missing or outliers.
"""
-
assert not event.internal_metadata.is_outlier()
#
@@ -306,17 +309,6 @@ class StateHandler:
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
- # .. though we need to get a state group for it.
- state_group_before_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=None,
- delta_ids=None,
- current_state_ids=state_ids_before_event,
- )
- )
-
# the partial_state flag must be provided
assert partial_state is not None
else:
@@ -345,6 +337,7 @@ class StateHandler:
logger.debug("calling resolve_state_groups from compute_event_context")
# we've already taken into account partial state, so no need to wait for
# complete state here.
+
entry = await self.resolve_state_groups_for_events(
event.room_id,
event.prev_event_ids(),
@@ -383,18 +376,19 @@ class StateHandler:
#
if not event.is_state():
- return EventContext.with_state(
+ return UnpersistedEventContext(
storage=self._storage_controllers,
state_group_before_event=state_group_before_event,
- state_group=state_group_before_event,
+ state_group_after_event=state_group_before_event,
state_delta_due_to_event={},
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
+ prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+ delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
partial_state=partial_state,
+ state_map_before_event=state_ids_before_event,
)
#
- # otherwise, we'll need to create a new state group for after the event
+ # otherwise, we'll need to set up creating a new state group for after the event
#
key = (event.type, event.state_key)
@@ -412,88 +406,60 @@ class StateHandler:
delta_ids = {key: event.event_id}
- state_group_after_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- current_state_ids=None,
- )
+ return UnpersistedEventContext(
+ storage=self._storage_controllers,
+ state_group_before_event=state_group_before_event,
+ state_group_after_event=None,
+ state_delta_due_to_event=delta_ids,
+ prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+ delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
+ partial_state=partial_state,
+ state_map_before_event=state_ids_before_event,
)
- return EventContext.with_state(
- storage=self._storage_controllers,
- state_group=state_group_after_event,
- state_group_before_event=state_group_before_event,
- state_delta_due_to_event=delta_ids,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
+ async def compute_event_context(
+ self,
+ event: EventBase,
+ state_ids_before_event: Optional[StateMap[str]] = None,
+ partial_state: Optional[bool] = None,
+ ) -> EventContext:
+ """Build an EventContext structure for a non-outlier event.
+
+ (for an outlier, call EventContext.for_outlier directly)
+
+ This works out what the current state should be for the event, and
+ generates a new state group if necessary.
+
+ Args:
+ event:
+ state_ids_before_event: The event ids of the state before the event if
+ it can't be calculated from existing events. This is normally
+ only specified when receiving an event from federation where we
+ don't have the prev events, e.g. when backfilling.
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
+ Returns:
+ The event context.
+
+ Raises:
+ RuntimeError if `state_ids_before_event` is not provided and one or more
+ prev events are missing or outliers.
+ """
+
+ unpersisted_context = await self.calculate_context_info(
+ event=event,
+ state_ids_before_event=state_ids_before_event,
partial_state=partial_state,
)
- async def compute_event_context_for_batched(
- self,
- event: EventBase,
- state_ids_before_event: StateMap[str],
- current_state_group: int,
- ) -> EventContext:
- """
- Generate an event context for an event that has not yet been persisted to the
- database. Intended for use with events that are created to be persisted in a batch.
- Args:
- event: the event the context is being computed for
- state_ids_before_event: a state map consisting of the state ids of the events
- created prior to this event.
- current_state_group: the current state group before the event.
- """
- state_group_before_event_prev_group = None
- deltas_to_state_group_before_event = None
-
- state_group_before_event = current_state_group
-
- # if the event is not state, we are set
- if not event.is_state():
- return EventContext.with_state(
- storage=self._storage_controllers,
- state_group_before_event=state_group_before_event,
- state_group=state_group_before_event,
- state_delta_due_to_event={},
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
- partial_state=False,
- )
-
- # otherwise, we'll need to create a new state group for after the event
- key = (event.type, event.state_key)
-
- if state_ids_before_event is not None:
- replaces = state_ids_before_event.get(key)
-
- if replaces and replaces != event.event_id:
- event.unsigned["replaces_state"] = replaces
-
- delta_ids = {key: event.event_id}
-
- state_group_after_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- current_state_ids=None,
- )
- )
-
- return EventContext.with_state(
- storage=self._storage_controllers,
- state_group=state_group_after_event,
- state_group_before_event=state_group_before_event,
- state_delta_due_to_event=delta_ids,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- partial_state=False,
- )
+ return await unpersisted_context.persist(event)
@measure_func()
async def resolve_state_groups_for_events(
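
The state/__init__.py refactor above separates computing an event's state from allocating state groups in the database, so batched events can defer the write. In outline, with simplified stand-ins for UnpersistedEventContext and the storage controller:

    # Sketch of the calculate-then-persist split introduced above. The types
    # here are simplified stand-ins, not Synapse's real dataclasses.
    from dataclasses import dataclass
    from typing import Dict, Optional


    @dataclass
    class UnpersistedContext:
        state_group_before_event: Optional[int]
        state_delta_due_to_event: Dict[str, str]

        async def persist(self, storage) -> "PersistedContext":
            # Only now do we touch the DB to allocate a state group for the
            # post-event state. Non-state events change nothing, so they
            # reuse the pre-event group.
            if not self.state_delta_due_to_event:
                group_after = self.state_group_before_event
            else:
                group_after = await storage.store_state_group(
                    prev_group=self.state_group_before_event,
                    delta_ids=self.state_delta_due_to_event,
                )
            return PersistedContext(group_after)


    @dataclass
    class PersistedContext:
        state_group: Optional[int]
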
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 41d9111019..481fec72fe 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -37,6 +37,8 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""
+ db_pool: DatabasePool
+
def __init__(
self,
database: DatabasePool,
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 52efd4a171..9d7a8a792f 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -14,6 +14,7 @@
import logging
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Any,
Awaitable,
Callable,
@@ -23,7 +24,6 @@ from typing import (
List,
Mapping,
Optional,
- Set,
Tuple,
)
@@ -527,7 +527,7 @@ class StateStorageController:
)
return state_map.get(key)
- async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms
@@ -584,7 +584,7 @@ class StateStorageController:
async def get_users_in_room_with_profiles(
self, room_id: str
- ) -> Dict[str, ProfileInfo]:
+ ) -> Mapping[str, ProfileInfo]:
"""
Get the current users in the room with their profiles.
If the room is currently partial-stated, this will block until the room has
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e20c5c5302..feaa6cdd07 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -499,6 +499,7 @@ class DatabasePool:
"""
_TXN_ID = 0
+ engine: BaseDatabaseEngine
def __init__(
self,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 8a359d7eb8..95567826f2 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -21,6 +21,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ Mapping,
Optional,
Tuple,
cast,
@@ -122,25 +123,25 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return self._account_data_id_gen.get_current_token()
@cached()
- async def get_account_data_for_user(
+ async def get_global_account_data_for_user(
self, user_id: str
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+ ) -> Mapping[str, JsonDict]:
"""
- Get all the client account_data for a user.
+ Get all the global client account_data for a user.
If experimental MSC3391 support is enabled, any entries with an empty
content body are excluded; as this means they have been deleted.
Args:
user_id: The user to get the account_data for.
+
Returns:
- A 2-tuple of a dict of global account_data and a dict mapping from
- room_id string to per room account_data dicts.
+ The global account_data.
"""
- def get_account_data_for_user_txn(
+ def get_global_account_data_for_user(
txn: LoggingTransaction,
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+ ) -> Dict[str, JsonDict]:
# The 'content != '{}' condition below prevents us from using
# `simple_select_list_txn` here, as it doesn't support conditions
# other than 'equals'.
@@ -158,10 +159,34 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)
- global_account_data = {
+ return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
+ return await self.db_pool.runInteraction(
+ "get_global_account_data_for_user", get_global_account_data_for_user
+ )
+
+ @cached()
+ async def get_room_account_data_for_user(
+ self, user_id: str
+ ) -> Mapping[str, Mapping[str, JsonDict]]:
+ """
+ Get all of the per-room client account_data for a user.
+
+ If experimental MSC3391 support is enabled, any entries with an empty
+ content body are excluded, as this means they have been deleted.
+
+ Args:
+ user_id: The user to get the account_data for.
+
+ Returns:
+ A dict mapping from room_id string to per-room account_data dicts.
+ """
+
+ def get_room_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Dict[str, JsonDict]]:
# The 'content != '{}' condition below prevents us from using
# `simple_select_list_txn` here, as it doesn't support conditions
# other than 'equals'.
@@ -185,10 +210,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
room_data[row["account_data_type"]] = db_to_json(row["content"])
- return global_account_data, by_room
+ return by_room
return await self.db_pool.runInteraction(
- "get_account_data_for_user", get_account_data_for_user_txn
+ "get_room_account_data_for_user_txn", get_room_account_data_for_user_txn
)
@cached(num_args=2, max_entries=5000, tree=True)
@@ -215,7 +240,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
- ) -> Dict[str, JsonDict]:
+ ) -> Mapping[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
@@ -342,36 +367,61 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"get_updated_room_account_data", get_updated_room_account_data_txn
)
- async def get_updated_account_data_for_user(
+ async def get_updated_global_account_data_for_user(
self, user_id: str, stream_id: int
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- """Get all the client account_data for a that's changed for a user
+ ) -> Dict[str, JsonDict]:
+ """Get all the global account_data that's changed for a user.
Args:
user_id: The user to get the account_data for.
stream_id: The point in the stream since which to get updates
+
Returns:
- A deferred pair of a dict of global account_data and a dict
- mapping from room_id string to per room account_data dicts.
+ A dict of global account_data.
"""
- def get_updated_account_data_for_user_txn(
+ def get_updated_global_account_data_for_user(
txn: LoggingTransaction,
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- sql = (
- "SELECT account_data_type, content FROM account_data"
- " WHERE user_id = ? AND stream_id > ?"
- )
-
+ ) -> Dict[str, JsonDict]:
+ sql = """
+ SELECT account_data_type, content FROM account_data
+ WHERE user_id = ? AND stream_id > ?
+ """
txn.execute(sql, (user_id, stream_id))
- global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
+ return {row[0]: db_to_json(row[1]) for row in txn}
- sql = (
- "SELECT room_id, account_data_type, content FROM room_account_data"
- " WHERE user_id = ? AND stream_id > ?"
- )
+ changed = self._account_data_stream_cache.has_entity_changed(
+ user_id, int(stream_id)
+ )
+ if not changed:
+ return {}
+ return await self.db_pool.runInteraction(
+ "get_updated_global_account_data_for_user",
+ get_updated_global_account_data_for_user,
+ )
+
+ async def get_updated_room_account_data_for_user(
+ self, user_id: str, stream_id: int
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ """Get all the room account_data that's changed for a user.
+
+ Args:
+ user_id: The user to get the account_data for.
+ stream_id: The point in the stream since which to get updates
+
+ Returns:
+ A dict mapping from room_id string to per room account_data dicts.
+ """
+
+ def get_updated_room_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ sql = """
+ SELECT room_id, account_data_type, content FROM room_account_data
+ WHERE user_id = ? AND stream_id > ?
+ """
txn.execute(sql, (user_id, stream_id))
account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
@@ -379,16 +429,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = db_to_json(row[2])
- return global_account_data, account_data_by_room
+ return account_data_by_room
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
- return {}, {}
+ return {}
return await self.db_pool.runInteraction(
- "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
+ "get_updated_room_account_data_for_user",
+ get_updated_room_account_data_for_user_txn,
)
@cached(max_entries=5000, iterable=True)
@@ -444,7 +495,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self.get_global_account_data_by_type_for_user.invalidate(
(row.user_id, row.data_type)
)
- self.get_account_data_for_user.invalidate((row.user_id,))
+ self.get_global_account_data_for_user.invalidate((row.user_id,))
+ self.get_room_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type)
@@ -492,7 +544,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_room_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id))
self.get_account_data_for_room_and_type.prefill(
(user_id, room_id, account_data_type), content
@@ -558,7 +610,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return None
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_room_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id))
self.get_account_data_for_room_and_type.prefill(
(user_id, room_id, account_data_type), {}
@@ -593,7 +645,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
(user_id, account_data_type)
)
@@ -761,7 +813,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return None
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.prefill(
(user_id, account_data_type), {}
)
@@ -822,7 +874,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn, self.get_account_data_for_room_and_type, (user_id,)
)
self._invalidate_cache_and_stream(
- txn, self.get_account_data_for_user, (user_id,)
+ txn, self.get_global_account_data_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_room_account_data_for_user, (user_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_global_account_data_by_type_for_user, (user_id,)
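
Note the ordering in the incremental getters above: the in-memory stream cache is consulted before `runInteraction`, so an unchanged user costs no database round-trip. The shape of that fast path, sketched with a stand-in cache (a real stream cache must answer True whenever it cannot be sure):

    # Sketch: stream-cache fast path used by the incremental getters above.
    # StreamChangeCache here is a toy stand-in with the one method we need.
    from typing import Dict, Set, Tuple


    class StreamChangeCache:
        def __init__(self) -> None:
            self._changes: Set[Tuple[str, int]] = set()

        def entity_has_changed(self, entity: str, stream_id: int) -> None:
            self._changes.add((entity, stream_id))

        def has_entity_changed(self, entity: str, stream_id: int) -> bool:
            # NB: a production stream cache must return True when it cannot
            # be sure (e.g. stream_id predates its horizon); this toy sees all.
            return any(e == entity and s > stream_id for e, s in self._changes)


    async def get_updated_global_account_data_for_user(
        cache: StreamChangeCache, db, user_id: str, stream_id: int
    ) -> Dict[str, dict]:
        # Fast path: nothing for this user since `stream_id`, skip the DB hit.
        if not cache.has_entity_changed(user_id, stream_id):
            return {}
        # `db.fetch_global_account_data_since` is a hypothetical stand-in.
        return await db.fetch_global_account_data_since(user_id, stream_id)
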
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5fb152c4ff..484db175d0 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
room_id: str,
app_service: "ApplicationService",
cache_context: _CacheContext,
- ) -> List[str]:
+ ) -> Sequence[str]:
"""
Get all users in a room that the appservice controls.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e8b6cc6b80..1ca66d57d4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -21,6 +21,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -100,6 +101,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
("device_lists_remote_pending", "stream_id"),
+ ("device_lists_changes_converted_stream_position", "stream_id"),
],
is_writer=hs.config.worker.worker_app is None,
)
@@ -201,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
- async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+ async def count_devices_by_users(
+ self, user_ids: Optional[Collection[str]] = None
+ ) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.
@@ -212,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""
def count_devices_by_users_txn(
- txn: LoggingTransaction, user_ids: List[str]
+ txn: LoggingTransaction, user_ids: Collection[str]
) -> int:
sql = """
SELECT count(*)
@@ -745,42 +749,47 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@trace
@cancellable
async def get_user_devices_from_cache(
- self, query_list: List[Tuple[str, Optional[str]]]
- ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
+ self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
+ ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
- query_list: List of (user_id, device_ids), if device_ids is
- falsey then return all device ids for that user.
+ user_ids: users which should have all device IDs returned
+ user_and_device_ids: List of (user_id, device_ids)
Returns:
A tuple of (user_ids_not_in_cache, results_map), where
user_ids_not_in_cache is a set of user_ids and results_map is a
mapping of user_id -> device_id -> device_info.
"""
- user_ids = {user_id for user_id, _ in query_list}
- user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
+ user_map = await self.get_device_list_last_stream_id_for_remotes(
+ list(unique_user_ids)
+ )
# We go and check if any of the users need to have their device lists
# resynced. If they do then we remove them from the cached list.
users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
- user_ids
+ unique_user_ids
)
user_ids_in_cache = {
user_id for user_id, stream_id in user_map.items() if stream_id
} - users_needing_resync
- user_ids_not_in_cache = user_ids - user_ids_in_cache
+ user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
- results: Dict[str, Dict[str, JsonDict]] = {}
- for user_id, device_id in query_list:
- if user_id not in user_ids_in_cache:
- continue
-
- if device_id:
- device = await self._get_cached_user_device(user_id, device_id)
- results.setdefault(user_id, {})[device_id] = device
- else:
+ # First fetch all the users for which all devices are to be returned.
+ results: Dict[str, Mapping[str, JsonDict]] = {}
+ for user_id in user_ids:
+ if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id)
+ # Then fetch all device-specific requests, but skip users we've already
+ # fetched all devices for.
+ device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
+ for user_id, device_id in user_and_device_ids:
+ if user_id in user_ids_in_cache and user_id not in user_ids:
+ device = await self._get_cached_user_device(user_id, device_id)
+ device_specific_results.setdefault(user_id, {})[device_id] = device
+ results.update(device_specific_results)
set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))
@@ -798,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return db_to_json(content)
@cached()
- async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+ async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 5903fdaf00..44aa181174 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Sequence, Tuple
import attr
@@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
)
@cached(max_entries=5000)
- async def get_aliases_for_room(self, room_id: str) -> List[str]:
+ async def get_aliases_for_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c4ac6c33ba..2c2d145666 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -20,7 +20,9 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
+ Sequence,
Tuple,
Union,
cast,
@@ -260,7 +262,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
- "get_e2e_cross_signing_signatures",
+ "get_e2e_cross_signing_signatures_for_devices",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
)
@@ -691,7 +693,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
self, user_id: str, device_id: str
- ) -> List[str]:
+ ) -> Sequence[str]:
"""Returns the fallback key types that have an unused key.
Args:
@@ -731,7 +733,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return user_keys.get(key_type)
@cached(num_args=1)
- def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
+ def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
@@ -744,7 +746,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str]
- ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+ ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
@@ -765,7 +767,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
# The `Optional` comes from the `@cachedList` decorator.
- return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+ return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
@@ -924,7 +926,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+ ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -940,11 +942,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id:
- result = await self.db_pool.runInteraction(
- "get_e2e_cross_signing_signatures",
- self._get_e2e_cross_signing_signatures_txn,
- result,
- from_user_id,
+ result = cast(
+ Dict[str, Optional[Mapping[str, JsonDict]]],
+ await self.db_pool.runInteraction(
+ "get_e2e_cross_signing_signatures",
+ self._get_e2e_cross_signing_signatures_txn,
+ result,
+ from_user_id,
+ ),
)
return result
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index bbee02ab18..ca780cca36 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -22,6 +22,7 @@ from typing import (
Iterable,
List,
Optional,
+ Sequence,
Set,
Tuple,
cast,
@@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
- async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
+ async def get_max_depth_of(
+ self, event_ids: Collection[str]
+ ) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args:
@@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@cached(max_entries=5000, iterable=True)
- async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+ async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
@@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cancellable
async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int
- ) -> List[str]:
+ ) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward
extremities of the room at that point in "time".
@@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cached(max_entries=5000, num_args=2)
async def _get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int
- ) -> List[str]:
+ ) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward
extremities of the room at that point in "time".
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3a0c370fde..eeccf5db24 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -203,11 +203,18 @@ class RoomNotifCounts:
# Map of thread ID to the notification counts.
threads: Dict[str, NotifCounts]
+ @staticmethod
+ def empty() -> "RoomNotifCounts":
+ return _EMPTY_ROOM_NOTIF_COUNTS
+
def __len__(self) -> int:
# To properly account for the amount of space in any caches.
return len(self.threads) + 1
+_EMPTY_ROOM_NOTIF_COUNTS = RoomNotifCounts(NotifCounts(), {})
+
+
def _serialize_action(
actions: Collection[Union[Mapping, str]], is_highlight: bool
) -> str:
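
The empty() helper hands out one module-level instance rather than allocating a fresh empty object per room. A simplified sketch of the pattern with stand-in class names; callers must treat the shared value as immutable:

    import attr

    @attr.s(slots=True, auto_attribs=True)
    class NotifCountsSketch:
        notify_count: int = 0

    @attr.s(slots=True, auto_attribs=True)
    class RoomNotifCountsSketch:
        main_timeline: NotifCountsSketch
        threads: dict

        @staticmethod
        def empty() -> "RoomNotifCountsSketch":
            # Rooms with no notifications are common, so returning the
            # shared singleton saves many small allocations.
            return _EMPTY

    _EMPTY = RoomNotifCountsSketch(NotifCountsSketch(), {})

    assert RoomNotifCountsSketch.empty() is RoomNotifCountsSketch.empty()
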
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1536937b67..7996cbb557 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -16,7 +16,6 @@
import itertools
import logging
from collections import OrderedDict
-from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
@@ -26,7 +25,6 @@ from typing import (
Iterable,
List,
Optional,
- Sequence,
Set,
Tuple,
)
@@ -36,7 +34,7 @@ from prometheus_client import Counter
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import PartialStateConflictError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
@@ -52,7 +50,7 @@ from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
from synapse.util.stringutils import non_null_str_or_none
@@ -72,24 +70,6 @@ event_counter = Counter(
)
-class PartialStateConflictError(SynapseError):
- """An internal error raised when attempting to persist an event with partial state
- after the room containing the event has been un-partial stated.
-
- This error should be handled by recomputing the event context and trying again.
-
- This error has an HTTP status code so that it can be transported over replication.
- It should not be exposed to clients.
- """
-
- def __init__(self) -> None:
- super().__init__(
- HTTPStatus.CONFLICT,
- msg="Cannot persist partial state event in un-partial stated room",
- errcode=Codes.UNKNOWN,
- )
-
-
@attr.s(slots=True, auto_attribs=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -306,7 +286,7 @@ class PersistEventsStore:
# The set of event_ids to return. This includes all soft-failed events
# and their prev events.
- existing_prevs = set()
+ existing_prevs: Set[str] = set()
def _get_prevs_before_rejected_txn(
txn: LoggingTransaction, batch: Collection[str]
@@ -571,7 +551,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, Sequence[str]],
+ event_to_auth_chain: Dict[str, StrCollection],
) -> None:
"""Calculate the chain cover index for the given events.
@@ -865,7 +845,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, Sequence[str]],
+ event_to_auth_chain: Dict[str, StrCollection],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index b9d3c36d60..584536111d 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast
import attr
@@ -29,7 +29,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.types import Cursor
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -1061,7 +1061,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self.event_chain_id_gen, # type: ignore[attr-defined]
event_to_room_id,
event_to_types,
- cast(Dict[str, Sequence[str]], event_to_auth_chain),
+ cast(Dict[str, StrCollection], event_to_auth_chain),
)
return _CalculateChainCover(
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index db9a24db5e..4b1061e6d7 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
@@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
return await self.db_pool.runInteraction("count_users", _count_users)
@cached(num_args=0)
- async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
+ async def get_monthly_active_count_by_service(self) -> Mapping[str, int]:
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 9213ce0b5a..9c41d01e13 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -420,12 +420,14 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_push_actions",
"event_search",
"event_failed_pull_attempts",
+ # Note: the partial state tables have foreign keys between each other, and to
+ # `events` and `rooms`. We need to delete from them in the right order.
"partial_state_events",
+ "partial_state_rooms_servers",
+ "partial_state_rooms",
"events",
"federation_inbound_events_staging",
"local_current_membership",
- "partial_state_rooms_servers",
- "partial_state_rooms",
"receipts_graph",
"receipts_linearized",
"room_aliases",
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 29972d5204..dddf49c2d5 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -21,7 +21,9 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
+ Sequence,
Tuple,
cast,
)
@@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> List[dict]:
+ ) -> Sequence[JsonDict]:
"""Get receipts for a single room for sending to clients.
Args:
@@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> List[JsonDict]:
+ ) -> Sequence[JsonDict]:
"""See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def _get_linearized_receipts_for_rooms(
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
- ) -> Dict[str, List[JsonDict]]:
+ ) -> Dict[str, Sequence[JsonDict]]:
if not room_ids:
return {}
@@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
- ) -> Dict[str, JsonDict]:
+ ) -> Mapping[str, JsonDict]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 31f0f2bd3d..9a55e17624 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
import logging
import random
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
import attr
@@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
@cached()
- async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+ async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0018d6f7ab..fa3266c081 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -22,6 +22,7 @@ from typing import (
List,
Mapping,
Optional,
+ Sequence,
Set,
Tuple,
Union,
@@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
- ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
+ ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@@ -397,7 +398,9 @@ class RelationsWorkerStore(SQLBaseStore):
return result is not None
@cached()
- async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
+ async def get_aggregation_groups_for_event(
+ self, event_id: str
+ ) -> Sequence[JsonDict]:
raise NotImplementedError()
@cachedList(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ea6a5e2f34..694a5b802c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -24,6 +24,7 @@ from typing import (
List,
Mapping,
Optional,
+ Sequence,
Set,
Tuple,
Union,
@@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self._known_servers_count
@cached(max_entries=100000, iterable=True)
- async def get_users_in_room(self, room_id: str) -> List[str]:
+ async def get_users_in_room(self, room_id: str) -> Sequence[str]:
"""Returns a list of users in the room.
Will return inaccurate results for rooms with partial state, since the state for
@@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached()
- def get_user_in_room_with_profile(
- self, room_id: str, user_id: str
- ) -> Dict[str, ProfileInfo]:
+ def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo:
raise NotImplementedError()
@cachedList(
@@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000, iterable=True)
async def get_users_in_room_with_profiles(
self, room_id: str
- ) -> Dict[str, ProfileInfo]:
+ ) -> Mapping[str, ProfileInfo]:
"""Get a mapping from user ID to profile information for all users in a given room.
The profile information comes directly from this room's `m.room.member`
@@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000)
- async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
+ async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
"""Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
@@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached()
async def get_invited_rooms_for_local_user(
self, user_id: str
- ) -> List[RoomsForUser]:
+ ) -> Sequence[RoomsForUser]:
"""Get all the rooms the *local* user is invited to.
Args:
@@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(iterable=True)
- async def get_local_users_in_room(self, room_id: str) -> List[str]:
+ async def get_local_users_in_room(self, room_id: str) -> Sequence[str]:
"""
Retrieves a list of the current room members who are local to the server.
"""
@@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns the set of users who share a room with `user_id`"""
room_ids = await self.get_rooms_for_user(user_id)
- user_who_share_room = set()
+ user_who_share_room: Set[str] = set()
for room_id in room_ids:
user_ids = await self.get_users_in_room(room_id)
user_who_share_room.update(user_ids)
@@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
@cached(iterable=True, max_entries=10000)
- async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state."""
# First we check if we already have `get_users_in_room` in the cache, as
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 05da15074a..5dcb1fc0b5 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Collection, Dict, List, Tuple
+from typing import Collection, Dict, List, Mapping, Tuple
from unpaddedbase64 import encode_base64
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureWorkerStore(EventsWorkerStore):
@cached()
- def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
+ def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]:
# This is a dummy function to allow get_event_reference_hashes
# to use its cache
raise NotImplementedError()
@@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore):
)
async def get_event_reference_hashes(
self, event_ids: Collection[str]
- ) -> Dict[str, Dict[str, bytes]]:
+ ) -> Mapping[str, Mapping[str, bytes]]:
"""Get all hashes for given events.
Args:
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index d5500cdd47..c149a9eacb 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, Iterable, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast
from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream
@@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
class TagsWorkerStore(AccountDataWorkerStore):
@cached()
- async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
+ async def get_tags_for_user(
+ self, user_id: str
+ ) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for a user.
@@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags(
self, user_id: str, stream_id: int
- ) -> Dict[str, Dict[str, JsonDict]]:
+ ) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 14ef5b040d..30af4b3b6c 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -16,9 +16,9 @@ import logging
import re
from typing import (
TYPE_CHECKING,
- Dict,
Iterable,
List,
+ Mapping,
Optional,
Sequence,
Set,
@@ -586,7 +586,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
- async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
+ async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
@@ -918,11 +918,19 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
We use this so that we can add prefix matching, which isn't something
that is supported by default.
"""
- results = _parse_words(search_term)
+ escaped_words = []
+ for word in _parse_words(search_term):
+ # Postgres tsvector and tsquery quoting rules:
+ # words potentially containing punctuation should be quoted
+ # and then existing quotes and backslashes should be doubled
+ # See: https://www.postgresql.org/docs/current/datatype-textsearch.html#DATATYPE-TSQUERY
- both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
- exact = " & ".join("%s" % (result,) for result in results)
- prefix = " & ".join("%s:*" % (result,) for result in results)
+ quoted_word = word.replace("'", "''").replace("\\", "\\\\")
+ escaped_words.append(f"'{quoted_word}'")
+
+ both = " & ".join("(%s:* | %s)" % (word, word) for word in escaped_words)
+ exact = " & ".join("%s" % (word,) for word in escaped_words)
+ prefix = " & ".join("%s:*" % (word,) for word in escaped_words)
return both, exact, prefix
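
A standalone illustration of the escaping rule above; pure string manipulation with made-up sample words, no database needed:

    def escape_tsquery_word(word: str) -> str:
        # Double any quotes and backslashes, then wrap the word in
        # quotes so punctuation cannot be parsed as tsquery syntax.
        return "'" + word.replace("'", "''").replace("\\", "\\\\") + "'"

    words = [escape_tsquery_word(w) for w in ("o'neil", "foo")]
    print(" & ".join("(%s:* | %s)" % (w, w) for w in words))
    # ('o''neil':* | 'o''neil') & ('foo':* | 'foo')
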
@@ -944,6 +952,14 @@ def _parse_words(search_term: str) -> List[str]:
if USE_ICU:
return _parse_words_with_icu(search_term)
+ return _parse_words_with_regex(search_term)
+
+
+def _parse_words_with_regex(search_term: str) -> List[str]:
+ """
+ Break down the search term into words when ICU is not available.
+ See: `_parse_words`
+ """
return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
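
For reference, a quick standalone check of what the regex fallback produces; \w is Unicode-aware in Python 3, so accented words survive and hyphenated terms stay whole:

    import re
    from typing import List

    def parse_words_with_regex(search_term: str) -> List[str]:
        return re.findall(r"([\w\-]+)", search_term, re.UNICODE)

    print(parse_words_with_regex("Jean-Luc Picard, café"))
    # ['Jean-Luc', 'Picard', 'café']
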
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index a182e8a098..d1ccb7390a 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -25,7 +25,7 @@ try:
except ImportError:
class PostgresEngine(BaseDatabaseEngine): # type: ignore[no-redef]
- def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc]
+ def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
raise RuntimeError(
f"Cannot create {cls.__name__} -- psycopg2 module is not installed"
)
@@ -36,7 +36,7 @@ try:
except ImportError:
class Sqlite3Engine(BaseDatabaseEngine): # type: ignore[no-redef]
- def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc]
+ def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
raise RuntimeError(
f"Cannot create {cls.__name__} -- sqlite3 module is not installed"
)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 3acdb39da7..6c335a9315 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -23,7 +23,7 @@ from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION
from synapse.storage.types import Cursor
@@ -108,9 +108,14 @@ def prepare_database(
# so we start one before running anything. This ensures that any upgrades
# are either applied completely, or not at all.
#
- # (psycopg2 automatically starts a transaction as soon as we run any statements
- # at all, so this is redundant but harmless there.)
- cur.execute("BEGIN TRANSACTION")
+ # psycopg2 does not automatically start transactions when in autocommit mode.
+ # While it is technically harmless to nest transactions in postgres, doing so
+ # results in a per-query warning in Postgres' logs, which we'd rather
+ # avoid.
+ if isinstance(database_engine, Sqlite3Engine) or (
+ isinstance(database_engine, PostgresEngine) and db_conn.autocommit
+ ):
+ cur.execute("BEGIN TRANSACTION")
logger.info("%r: Checking existing schema version", databases)
version_info = _get_or_create_schema_state(cur, database_engine)
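
A condensed sketch of the branching above, with stubbed-in engine classes: SQLite always needs the explicit BEGIN, while Postgres only needs it under autocommit, since psycopg2 otherwise opens a transaction implicitly on the first statement:

    import sqlite3
    from typing import Any

    class Sqlite3EngineSketch:
        pass

    class PostgresEngineSketch:
        pass

    def maybe_begin(cur: Any, engine: Any, autocommit: bool) -> None:
        # With autocommit off, psycopg2 has already begun a transaction
        # implicitly; a second BEGIN makes Postgres log a warning per
        # query. SQLite needs the explicit BEGIN regardless.
        if isinstance(engine, Sqlite3EngineSketch) or (
            isinstance(engine, PostgresEngineSketch) and autocommit
        ):
            cur.execute("BEGIN TRANSACTION")

    maybe_begin(sqlite3.connect(":memory:").cursor(), Sqlite3EngineSketch(), False)
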
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 0031df1e06..56a0048539 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from types import TracebackType
-from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import (
+ Any,
+ Callable,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
from typing_extensions import Protocol
@@ -112,15 +123,35 @@ class DBAPI2Module(Protocol):
# extends from this hierarchy. See
# https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions
# https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE
- Warning: Type[Exception]
- Error: Type[Exception]
+ #
+ # Note: rather than
+ # x: T
+ # we write
+ # @property
+ # def x(self) -> T: ...
+ # which expresses that the protocol attribute `x` is read-only. The mypy docs
+ # https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected
+ # explain why this is necessary for safety. TL;DR: we shouldn't be able to write
+ # to `x`, only read from it. See also https://github.com/python/mypy/issues/6002 .
+ @property
+ def Warning(self) -> Type[Exception]:
+ ...
+
+ @property
+ def Error(self) -> Type[Exception]:
+ ...
# Errors are divided into `InterfaceError`s (something went wrong in the database
# driver) and `DatabaseError`s (something went wrong in the database). These are
# both subclasses of `Error`, but we can't currently express this in type
# annotations due to https://github.com/python/mypy/issues/8397
- InterfaceError: Type[Exception]
- DatabaseError: Type[Exception]
+ @property
+ def InterfaceError(self) -> Type[Exception]:
+ ...
+
+ @property
+ def DatabaseError(self) -> Type[Exception]:
+ ...
# Everything below is a subclass of `DatabaseError`.
@@ -128,7 +159,9 @@ class DBAPI2Module(Protocol):
# - An integer was too big for its data type.
# - An invalid date time was provided.
# - A string contained a null code point.
- DataError: Type[Exception]
+ @property
+ def DataError(self) -> Type[Exception]:
+ ...
# Roughly: something went wrong in the database, but it's not within the application
# programmer's control. Examples:
@@ -138,28 +171,45 @@ class DBAPI2Module(Protocol):
# - A serialisation failure occurred.
# - The database ran out of resources, such as storage, memory, connections, etc.
# - The database encountered an error from the operating system.
- OperationalError: Type[Exception]
+ @property
+ def OperationalError(self) -> Type[Exception]:
+ ...
# Roughly: we've given the database data which breaks a rule we asked it to enforce.
# Examples:
# - Stop, criminal scum! You violated the foreign key constraint
# - Also check constraints, non-null constraints, etc.
- IntegrityError: Type[Exception]
+ @property
+ def IntegrityError(self) -> Type[Exception]:
+ ...
# Roughly: something went wrong within the database server itself.
- InternalError: Type[Exception]
+ @property
+ def InternalError(self) -> Type[Exception]:
+ ...
# Roughly: the application did something silly that needs to be fixed. Examples:
# - We don't have permissions to do something.
# - We tried to create a table with duplicate column names.
# - We tried to use a reserved name.
# - We referred to a column that doesn't exist.
- ProgrammingError: Type[Exception]
+ @property
+ def ProgrammingError(self) -> Type[Exception]:
+ ...
# Roughly: we've tried to do something that this database doesn't support.
- NotSupportedError: Type[Exception]
+ @property
+ def NotSupportedError(self) -> Type[Exception]:
+ ...
- def connect(self, **parameters: object) -> Connection:
+ # We originally wrote
+ # def connect(self, *args, **kwargs) -> Connection: ...
+ # But mypy doesn't seem to like that because sqlite3.connect takes a mandatory
+ # positional argument. We can't make that part of the signature though, because
+ # psycopg2.connect doesn't have a mandatory positional argument. Instead, we use
+ # the following slightly unusual workaround.
+ @property
+ def connect(self) -> Callable[..., Connection]:
...
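
A minimal sketch of the mypy point made in that comment, using a hypothetical driver protocol: a read-only property member accepts an implementation exposing a narrower exception type, where a mutable attribute would be rejected:

    from typing import Protocol, Type

    class HasError(Protocol):
        @property
        def Error(self) -> Type[Exception]:
            # Read-only: implementations may expose a narrower type.
            ...

    class FakeDriver:
        # Type[ValueError] is a subtype of Type[Exception]; mypy accepts
        # this against the property, but would reject it against a
        # mutable member declared as `Error: Type[Exception]`.
        Error = ValueError

    driver: HasError = FakeDriver()
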
diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py
index c6c8a0315c..8a48ffc48d 100644
--- a/synapse/streams/__init__.py
+++ b/synapse/streams/__init__.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from abc import ABC, abstractmethod
from typing import Generic, List, Optional, Tuple, TypeVar
from synapse.types import StrCollection, UserID
@@ -22,7 +22,8 @@ K = TypeVar("K")
R = TypeVar("R")
-class EventSource(Generic[K, R]):
+class EventSource(ABC, Generic[K, R]):
+ @abstractmethod
async def get_new_events(
self,
user: UserID,
@@ -32,4 +33,4 @@ class EventSource(Generic[K, R]):
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[R], K]:
- ...
+ raise NotImplementedError()
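
With the class now abstract, a subclass that forgets to implement get_new_events fails at construction instead of inheriting a silent no-op body. A minimal demonstration:

    from abc import ABC, abstractmethod

    class EventSourceSketch(ABC):
        @abstractmethod
        async def get_new_events(self) -> list:
            raise NotImplementedError()

    class Forgotten(EventSourceSketch):
        pass

    try:
        Forgotten()
    except TypeError as e:
        print(e)  # Can't instantiate abstract class Forgotten ...
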
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index f82d1cfc29..33363867c4 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -69,6 +69,9 @@ StateMap = Mapping[StateKey, T]
MutableStateMap = MutableMapping[StateKey, T]
# JSON types. These could be made stronger, but will do for now.
+# A "simple" (canonical) JSON value.
+SimpleJsonValue = Optional[Union[str, int, bool]]
+JsonValue = Union[List[SimpleJsonValue], Tuple[SimpleJsonValue, ...], SimpleJsonValue]
# A JSON-serialisable dict.
JsonDict = Dict[str, Any]
# A JSON-serialisable mapping; roughly speaking an immutable JSONDict.
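
To make the new aliases concrete, a few values they do and do not admit (illustrative only):

    from typing import List, Optional, Tuple, Union

    SimpleJsonValue = Optional[Union[str, int, bool]]
    JsonValue = Union[
        List[SimpleJsonValue], Tuple[SimpleJsonValue, ...], SimpleJsonValue
    ]

    scalar: JsonValue = "ok"           # plain scalars and None are admitted
    flat: JsonValue = [1, None, True]  # flat lists/tuples of scalars too
    # nested: JsonValue = [[1]]        # rejected: no nested containers
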
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 5d89ba94ad..2ee343d8a4 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -67,7 +67,9 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
- self.hs._listen_http(parse_listener_def(0, config))
+ hs = self.hs
+ assert isinstance(hs, GenericWorkerServer)
+ hs._listen_http(parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
@@ -115,7 +117,9 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
- self.hs._listener_http(self.hs.config, parse_listener_def(0, config))
+ hs = self.hs
+ assert isinstance(hs, SynapseHomeServer)
+ hs._listener_http(self.hs.config, parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index febcc1499d..e2a3bad065 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast
+from typing import List, Optional, Sequence, Tuple, cast
from unittest.mock import Mock
from typing_extensions import TypeAlias
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.appservice import (
ApplicationService,
@@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock
from ..utils import MockClock
-if TYPE_CHECKING:
- from twisted.internet.testing import MemoryReactor
-
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self) -> None:
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 0e8af2da54..1b9696748f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
"server9",
- time.time() * 1000,
+ int(time.time() * 1000),
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
)
self.get_success(r)
@@ -287,7 +287,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
"server9",
- time.time() * 1000,
+ int(time.time() * 1000),
# None is not a valid value in FetchKeyResult, but we're abusing this
# API to insert null values into the database. The nulls get converted
# to 0 when fetched in KeyStore.get_server_verify_keys.
@@ -466,9 +466,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
- res = key_json[lookup_triplet]
- self.assertEqual(len(res), 1)
- res = res[0]
+ res_keys = key_json[lookup_triplet]
+ self.assertEqual(len(res_keys), 1)
+ res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], SERVER_NAME)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@@ -584,9 +584,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
- res = key_json[lookup_triplet]
- self.assertEqual(len(res), 1)
- res = res[0]
+ res_keys = key_json[lookup_triplet]
+ self.assertEqual(len(res_keys), 1)
+ res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@@ -705,9 +705,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
- res = key_json[lookup_triplet]
- self.assertEqual(len(res), 1)
- res = res[0]
+ res_keys = key_json[lookup_triplet]
+ self.assertEqual(len(res_keys), 1)
+ res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index a9893def74..6fb1f1bd6e 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -31,7 +31,11 @@ from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config
from tests.test_utils import simple_async_mock
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import (
+ FederatingHomeserverTestCase,
+ HomeserverTestCase,
+ override_config,
+)
@attr.s
@@ -152,11 +156,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
- fed_transport_client = Mock(spec=["send_transaction"])
- fed_transport_client.send_transaction = simple_async_mock({})
+ self.fed_transport_client = Mock(spec=["send_transaction"])
+ self.fed_transport_client.send_transaction = simple_async_mock({})
hs = self.setup_test_homeserver(
- federation_transport_client=fed_transport_client,
+ federation_transport_client=self.fed_transport_client,
)
load_legacy_presence_router(hs)
@@ -418,7 +422,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
#
# Thus we reset the mock, and try sending all online local user
# presence again
- self.hs.get_federation_transport_client().send_transaction.reset_mock()
+ self.fed_transport_client.send_transaction.reset_mock()
# Broadcast local user online presence
self.get_success(
@@ -443,9 +447,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
}
found_users = set()
- calls = (
- self.hs.get_federation_transport_client().send_transaction.call_args_list
- )
+ calls = self.fed_transport_client.send_transaction.call_args_list
for call in calls:
call_args = call[0]
federation_transaction: Transaction = call_args[0]
@@ -470,7 +472,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def send_presence_update(
- testcase: FederatingHomeserverTestCase,
+ testcase: HomeserverTestCase,
user_id: str,
access_token: str,
presence_state: str,
@@ -491,7 +493,7 @@ def send_presence_update(
def sync_presence(
- testcase: FederatingHomeserverTestCase,
+ testcase: HomeserverTestCase,
user_id: str,
since_token: Optional[StreamToken] = None,
) -> Tuple[List[UserPresenceState], StreamToken]:
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index d667dd27bf..35dd9a20df 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
from synapse.rest.client import login, room
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
@@ -56,7 +56,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Artificially raise the complexity
store = self.hs.get_datastores().main
- store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
+
+ async def get_current_state_event_counts(room_id: str) -> int:
+ return int(500 * 1.23)
+
+ store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
# Get the room complexity again -- make sure it's our artificial value
channel = self.make_signed_federation_request(
@@ -75,12 +79,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -106,12 +110,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -144,17 +148,18 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
# Artificially raise the complexity
- self.hs.get_datastores().main.get_current_state_event_counts = (
- lambda x: make_awaitable(600)
- )
+ async def get_current_state_event_counts(room_id: str) -> int:
+ return 600
+
+ self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
room_1,
UserID.from_string(u1),
@@ -200,12 +205,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -230,12 +235,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
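
A reduced sketch (hypothetical store class) of the monkeypatching pattern used throughout this file: a typed async function replaces a lambda wrapping make_awaitable, so mypy can check the replacement's signature, and the residual mismatch with the bound method is silenced explicitly:

    class FakeStore:
        async def get_current_state_event_counts(self, room_id: str) -> int:
            return 0

    async def fake_counts(room_id: str) -> int:
        # A typed coroutine function: arguments and return type are
        # checked, unlike a bare lambda returning an awaitable.
        return 600

    store = FakeStore()
    store.get_current_state_event_counts = fake_counts  # type: ignore[assignment]
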
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index a986b15f0a..6381583c24 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -5,7 +5,11 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.federation.sender import PerDestinationQueue, TransactionManager
+from synapse.federation.sender import (
+ FederationSender,
+ PerDestinationQueue,
+ TransactionManager,
+)
from synapse.federation.units import Edu, Transaction
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -33,8 +37,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_transport_client = Mock(spec=["send_transaction"])
return self.setup_test_homeserver(
- federation_transport_client=Mock(spec=["send_transaction"]),
+ federation_transport_client=self.federation_transport_client,
)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -52,10 +57,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.pdus: List[JsonDict] = []
self.failed_pdus: List[JsonDict] = []
self.is_online = True
- self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.federation_transport_client.send_transaction.side_effect = (
self.record_transaction
)
+ federation_sender = hs.get_federation_sender()
+ assert isinstance(federation_sender, FederationSender)
+ self.federation_sender = federation_sender
+
def default_config(self) -> JsonDict:
config = super().default_config()
config["federation_sender_instances"] = None
@@ -229,11 +238,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# let's delete the federation transmission queue
# (this pretends we are starting up fresh.)
self.assertFalse(
- self.hs.get_federation_sender()
- ._per_destination_queues["host2"]
- .transmission_loop_running
+ self.federation_sender._per_destination_queues[
+ "host2"
+ ].transmission_loop_running
)
- del self.hs.get_federation_sender()._per_destination_queues["host2"]
+ del self.federation_sender._per_destination_queues["host2"]
# let's also clear any backoffs
self.get_success(
@@ -322,6 +331,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# also fetch event 5 so we know its last_successful_stream_ordering later
event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5))
+ assert event_2.internal_metadata.stream_ordering is not None
self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_2.internal_metadata.stream_ordering
@@ -425,15 +435,16 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
def wake_destination_track(destination: str) -> None:
woken.append(destination)
- self.hs.get_federation_sender().wake_destination = wake_destination_track
+ self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment]
# cancel the pre-existing timer for _wake_destinations_needing_catchup
# this is because we are calling it manually rather than waiting for it
# to be called automatically
- self.hs.get_federation_sender()._catchup_after_startup_timer.cancel()
+ assert self.federation_sender._catchup_after_startup_timer is not None
+ self.federation_sender._catchup_after_startup_timer.cancel()
self.get_success(
- self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0
+ self.federation_sender._wake_destinations_needing_catchup(), by=5.0
)
# ASSERT (_wake_destinations_needing_catchup):
@@ -475,6 +486,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
)
)
+ assert event_1.internal_metadata.stream_ordering is not None
self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_1.internal_metadata.stream_ordering
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 86e1236501..91694e4fca 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -178,7 +178,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
RoomVersions.V9,
)
)
- self.assertIsNotNone(pulled_pdu_info2)
+ assert pulled_pdu_info2 is not None
remote_pdu2 = pulled_pdu_info2.pdu
# Sanity check that we are working against the same event
@@ -226,7 +226,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
RoomVersions.V9,
)
)
- self.assertIsNotNone(pulled_pdu_info)
+ assert pulled_pdu_info is not None
remote_pdu = pulled_pdu_info.pdu
# check the right call got made to the agent
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index ddeffe1ad5..9e104fd96a 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
from synapse.federation.units import Transaction
+from synapse.handlers.device import DeviceHandler
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer
@@ -41,8 +42,9 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
"""
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_transport_client = Mock(spec=["send_transaction"])
hs = self.setup_test_homeserver(
- federation_transport_client=Mock(spec=["send_transaction"]),
+ federation_transport_client=self.federation_transport_client,
)
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
@@ -61,9 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
return config
def test_send_receipts(self) -> None:
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
@@ -103,9 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
)
def test_send_receipts_thread(self) -> None:
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
# Create receipts for:
@@ -181,9 +179,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_with_backoff(self) -> None:
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
@@ -277,10 +273,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_transport_client = Mock(
+ spec=["send_transaction", "query_user_devices"]
+ )
return self.setup_test_homeserver(
- federation_transport_client=Mock(
- spec=["send_transaction", "query_user_devices"]
- ),
+ federation_transport_client=self.federation_transport_client,
)
def default_config(self) -> JsonDict:
@@ -310,9 +307,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment]
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self.device_handler = device_handler
+
# whenever send_transaction is called, record the edu data
self.edus: List[JsonDict] = []
- self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.federation_transport_client.send_transaction.side_effect = (
self.record_transaction
)
@@ -353,7 +354,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
- self.hs.get_federation_transport_client().query_user_devices.return_value = (
+ self.federation_transport_client.query_user_devices.return_value = (
make_awaitable(
{
"stream_id": "1",
@@ -364,7 +365,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
)
self.get_success(
- self.hs.get_device_handler().device_list_updater.incoming_device_list_update(
+ self.device_handler.device_list_updater.incoming_device_list_update(
"host2",
{
"user_id": "@user2:host2",
@@ -507,9 +508,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -533,7 +532,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
"""If the destination server is unreachable, all the updates should get sent on
recovery
"""
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
@@ -543,9 +542,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3")
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -580,7 +577,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
This case tests the behaviour when the server has never been reachable.
"""
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
@@ -590,9 +587,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3")
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -640,7 +635,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
# now the server goes offline
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
self.login("user", "pass", device_id="D2")
@@ -651,9 +646,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.reactor.advance(1)
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
self.assertGreaterEqual(mock_send_txn.call_count, 3)
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 6f300b8e11..1b97aaeed1 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -296,3 +296,30 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[0][0]["user_agent"], "user_agent")
self.assertGreater(args[0][0]["last_seen"], 0)
self.assertNotIn("access_token", args[0][0])
+
+ def test_account_data(self) -> None:
+ """Tests that user account data get exported."""
+ # add account data
+ self.get_success(
+ self._store.add_account_data_for_user(self.user2, "m.global", {"a": 1})
+ )
+ self.get_success(
+ self._store.add_account_data_to_room(
+ self.user2, "test_room", "m.per_room", {"b": 2}
+ )
+ )
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ # two calls: one for global account data and one for room account data
+ writer.write_account_data.assert_called()
+
+ args = writer.write_account_data.call_args_list[0][0]
+ self.assertEqual(args[0], "global")
+ self.assertEqual(args[1]["m.global"]["a"], 1)
+
+ args = writer.write_account_data.call_args_list[1][0]
+ self.assertEqual(args[0], "test_room")
+ self.assertEqual(args[1]["m.per_room"]["b"], 2)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index a7495ab21a..9014e60577 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock(
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
return_value=self._services
)
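
The `# type: ignore[assignment]` comment here is the first of many in this diff: wherever a test monkeypatches a real attribute or method with a Mock, mypy rejects the assignment because a `Mock` does not match the declared type. The error-code-scoped ignore suppresses only that diagnostic on that line. A minimal reproduction (the class is illustrative, not from the codebase):

    from unittest.mock import Mock

    class Store:
        def get_app_services(self) -> list:
            return []

    store = Store()
    # Without the comment, mypy flags this assignment (a Mock is not a
    # compatible replacement for the typed method); the scoped ignore
    # silences exactly this error code and nothing else on the line.
    store.get_app_services = Mock(return_value=[])  # type: ignore[assignment]
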
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 2733719d82..63aad0d10c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("test_user", {})
request = _mock_request()
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO.
cas_response = CasResponse("test_user", {})
@@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("föö", {})
request = _mock_request()
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {})
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 95698bc275..6b4cba65d0 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# we should now have an unused alg1 key
- res = self.get_success(
+ fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, ["alg1"])
+ self.assertEqual(fallback_res, ["alg1"])
# claiming an OTK when no OTKs are available should return the fallback
# key
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# we shouldn't have any unused fallback keys again
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, [])
+ self.assertEqual(unused_res, [])
# claiming an OTK again should return the same fallback key
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
@@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, [])
+ self.assertEqual(unused_res, [])
# uploading a new fallback key should result in an unused fallback key
self.get_success(
@@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, ["alg1"])
+ self.assertEqual(unused_res, ["alg1"])
# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
@@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)
@@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
@@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
# upload two device keys, which will be signed later by the self-signing key
- device_key_1 = {
+ device_key_1: JsonDict = {
"user_id": local_user,
"device_id": "abc",
"algorithms": [
@@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
"signatures": {local_user: {"ed25519:abc": "base64+signature"}},
}
- device_key_2 = {
+ device_key_2: JsonDict = {
"user_id": local_user,
"device_id": "def",
"algorithms": [
@@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
+ device_handler = self.hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
e = self.get_failure(
- self.hs.get_device_handler().check_device_registered(
+ device_handler.check_device_registered(
user_id=local_user,
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
initial_device_display_name="new display name",
@@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_id = "xyz"
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
- device_key = {
+ device_key: JsonDict = {
"user_id": local_user,
"device_id": device_id,
"algorithms": [
@@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
- master_key = {
+ master_key: JsonDict = {
"user_id": local_user,
"usage": ["master"],
"keys": {"ed25519:" + master_pubkey: master_pubkey},
@@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# the first user
other_user = "@otherboris:" + self.hs.hostname
other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
- other_master_key = {
+ other_master_key: JsonDict = {
# private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
"user_id": other_user,
"usage": ["master"],
@@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
- self.hs.get_federation_client().query_client_keys = mock.Mock(
+ self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
@@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
- self.hs.get_federation_client().query_user_devices = mock.Mock(
+ self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"user_id": remote_user_id,
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 57675fa407..bf0862ed54 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# We mock out the FederationClient.backfill method, to pretend that a remote
# server has returned our fake event.
federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
- self.hs.get_federation_client().backfill = federation_client_backfill_mock
+ self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]
# We also mock the persist method with a side effect of itself. This allows us
# to track when it has been called while preserving its function.
persist_events_and_notify_mock = Mock(
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
)
- self.hs.get_federation_event_handler().persist_events_and_notify = (
+ self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment]
persist_events_and_notify_mock
)
@@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_client = fed_handler.federation_client
room_id = "!room:example.com"
- membership_event = make_event_from_dict(
- {
- "room_id": room_id,
- "type": "m.room.member",
- "sender": "@alice:test",
- "state_key": "@alice:test",
- "content": {"membership": "join"},
- },
- RoomVersions.V10,
- )
-
- mock_make_membership_event = Mock(
- return_value=make_awaitable(
- (
- "example.com",
- membership_event,
- RoomVersions.V10,
- )
- )
- )
EVENT_CREATE = make_event_from_dict(
{
@@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
},
room_version=RoomVersions.V10,
)
+ membership_event = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@alice:test",
+ "state_key": "@alice:test",
+ "content": {"membership": "join"},
+ "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id],
+ },
+ RoomVersions.V10,
+ )
+ mock_make_membership_event = Mock(
+ return_value=make_awaitable(
+ (
+ "example.com",
+ membership_event,
+ RoomVersions.V10,
+ )
+ )
+ )
mock_send_join = Mock(
return_value=make_awaitable(
SendJoinResult(
@@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Try to start another partial state sync.
# Nothing should happen.
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# End the partial state sync
@@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
# The next attempt to start the partial state sync should work.
is_partial_state = True
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
def test_partial_state_room_sync_restart(self) -> None:
@@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Fail the partial state sync.
@@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Start the partial state sync again.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Deduplicate another partial state sync.
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Fail the partial state sync.
@@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 3)
mock_sync_partial_state_room.assert_called_with(
initial_destination="hs3",
- other_destinations=["hs2"],
+ other_destinations={"hs2"},
room_id="room_id",
)
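
Note that the `["hs2"]` → `{"hs2"}` change has to reach the `assert_called_with` expectation as well, because mock assertions compare recorded and expected arguments with `==`, and a list never equals a set. A quick stdlib-only check of that behaviour:

    from unittest.mock import Mock

    m = Mock()
    m(initial_destination="hs3", other_destinations={"hs2"}, room_id="room_id")

    # Argument matching is plain equality: ["hs2"] == {"hs2"} is False, so
    # the expected value must be a set now that the call site passes one.
    m.assert_called_with(
        initial_destination="hs3", other_destinations={"hs2"}, room_id="room_id"
    )
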
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 70ea4d15d4..c067e5bfe3 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
+from synapse.state import StateResolutionStore
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict
from synapse.util import Clock
@@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True
persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
self.get_success(
persistence.persist_event(
prev_event,
@@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
- state_res_store=main_store,
+ state_res_store=StateResolutionStore(main_store),
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
@@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
rejected_power_levels_event.event_id,
],
event_map={},
- state_res_store=main_store,
+ state_res_store=StateResolutionStore(main_store),
full_conflicted_set=set(),
)
),
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index c4727ab917..69d384442f 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_event_creation_handler()
- self._persist_event_storage_controller = (
- self.hs.get_storage_controllers().persistence
- )
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persist_event_storage_controller = persistence
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
- self.info = self.get_success(
+ info = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(
self.access_token,
)
)
- self.token_id = self.info.token_id
+ assert info is not None
+ self.token_id = info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
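
Both hunks above use the same narrowing idiom: bind the `Optional` result to a local, `assert ... is not None`, then store it on `self`. Assigning the `Optional` value to the attribute directly would make the attribute itself `Optional` and force `None` handling at every later use. A minimal sketch (the lookup function is a stand-in, not a real API):

    from typing import Optional

    def lookup() -> Optional[str]:
        # stand-in for e.g. get_storage_controllers().persistence
        return "persistence"

    class Sketch:
        def prepare(self) -> None:
            value = lookup()
            # The assert narrows Optional[str] to str, so the attribute is
            # inferred as plain str from here on.
            assert value is not None
            self._value = value
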
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index adddbd002f..951caaa6b3 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver()
self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
- self.hs_patcher.start()
+ self.hs_patcher.start() # type: ignore[attr-defined]
self.handler = hs.get_oidc_handler()
self.provider = self.handler._providers["oidc"]
@@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def tearDown(self) -> None:
- self.hs_patcher.stop()
+ self.hs_patcher.stop() # type: ignore[attr-defined]
return super().tearDown()
def reset_mocks(self) -> None:
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 0916de64f5..aa91bc0a3d 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
username: The username to use for the test.
registration: Whether to test with registration URLs.
"""
- self.hs.get_identity_handler().send_threepid_validation = Mock(
+ self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment]
return_value=make_awaitable(0),
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index b9332d97dc..1db99b3c00 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -62,7 +62,7 @@ class TestSpamChecker:
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class DenyAll(TestSpamChecker):
@@ -111,7 +111,7 @@ class TestLegacyRegistrationSpamChecker:
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
@@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self) -> None:
- self.store.count_monthly_users = Mock(
+ self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
@@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test"
- self.store.count_real_users = Mock(return_value=make_awaitable(1))
+ self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
self,
) -> None:
- self.store.count_real_users = Mock(return_value=make_awaitable(2))
+ self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly not federated.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["federatable"])
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "public")
@@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a public room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertEqual(room["join_rules"], "public")
# Both users should be in the room.
@@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join")
@@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join")
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 9b1b8b9f13..b5c772a7ae 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO.
saml_response = FakeAuthnResponse(
@@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# mock out the error renderer too
sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None)
+ sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
request = _mock_request()
@@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None)
+ sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
# register a user to occupy the first-choice MXID
store = self.hs.get_datastores().main
@@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1fe9563c98..94518a7196 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
# we mock out the federation client too
- mock_federation_client = Mock(spec=["put_json"])
- mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
+ self.mock_federation_client = Mock(spec=["put_json"])
+ self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
@@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.mock_hs_notifier = Mock()
hs = self.setup_test_homeserver(
notifier=self.mock_hs_notifier,
- federation_http_client=mock_federation_client,
+ federation_http_client=self.mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)
@@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- put_json = self.hs.get_federation_http_client().put_json
- put_json.assert_called_once_with(
+ self.mock_federation_client.put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
@@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
- put_json = self.hs.get_federation_http_client().put_json
- put_json.assert_called_once_with(
+ self.mock_federation_client.put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
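
These typing tests, like several other files in this diff, lean on the `make_awaitable` helper from tests.test_utils to give a Mock a pre-resolved awaitable return value. Its definition is outside this diff; an implementation consistent with the call sites would be roughly:

    from asyncio import Future
    from typing import Awaitable, TypeVar

    T = TypeVar("T")

    def make_awaitable(result: T) -> Awaitable[T]:
        # A Future that is already resolved: awaiting it yields `result`
        # immediately, and (unlike a coroutine) it can be awaited more than
        # once, which suits a mocked return value.
        future: "Future[T]" = Future()
        future.set_result(result)
        return future
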
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 75fc5a17a4..a02c1c6227 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple
+from typing import Any, Tuple
from unittest.mock import Mock, patch
from urllib.parse import quote
@@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService
from synapse.rest.client import login, register, room, user_directory
from synapse.server import HomeServer
from synapse.storage.roommember import ProfileInfo
-from synapse.types import create_requester
+from synapse.types import UserProfile, create_requester
from synapse.util import Clock
from tests import unittest
@@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
+# A spam checker which doesn't implement any callbacks; effectively a bare object.
+class UselessSpamChecker:
+ def __init__(self, config: Any):
+ pass
+
+
class UserDirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the UserDirectoryHandler.
@@ -186,6 +192,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, self.appservice.sender, tok=self.appservice.token)
self._check_only_one_user_in_directory(user, room)
+ def test_search_term_with_colon_in_it_does_not_raise(self) -> None:
+ """
+ Regression test: searching with a term containing a colon should not raise.
+ """
+ u1 = self.register_user("user1", "pass")
+ self.get_success(self.handler.search_users(u1, "haha:paamayim-nekudotayim", 10))
+
def test_user_not_in_users_table(self) -> None:
"""Unclear how it happens, but on matrix.org we've seen join events
for users who aren't in the users table. Test that we don't fall over
@@ -773,7 +786,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
- async def allow_all(user_profile: ProfileInfo) -> bool:
+ async def allow_all(user_profile: UserProfile) -> bool:
# Allow all users.
return False
@@ -787,7 +800,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- async def block_all(user_profile: ProfileInfo) -> bool:
+ async def block_all(user_profile: UserProfile) -> bool:
# All users are spammy.
return True
@@ -797,6 +810,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
+ @override_config(
+ {
+ "spam_checker": {
+ "module": "tests.handlers.test_user_directory.UselessSpamChecker"
+ }
+ }
+ )
def test_legacy_spam_checker(self) -> None:
"""
A spam checker without the expected method should be ignored.
@@ -825,11 +845,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
self.assertEqual(public_users, set())
- # Configure a spam checker.
- spam_checker = self.hs.get_spam_checker()
- # The spam checker doesn't need any methods, so create a bare object.
- spam_checker.spam_checker = object()
-
# We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
@@ -949,13 +964,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(
- self.hs.get_storage_controllers().persistence.persist_event(event, context)
- )
+ context = self.get_success(unpersisted_context.persist(event))
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.get_success(persistence.persist_event(event, context))
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
"""We've chosen to simplify the user directory's implementation by
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index acfdcd3bca..eb7f53fee5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -30,7 +30,7 @@ from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator,
IProtocolFactory,
)
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
@@ -63,7 +63,7 @@ from tests.http import (
get_test_ca_cert_file,
)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
-from tests.utils import default_config
+from tests.utils import checked_cast, default_config
logger = logging.getLogger(__name__)
@@ -146,8 +146,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
- client_protocol = client_factory.buildProtocol(dummy_address)
- assert isinstance(client_protocol, _WrappingProtocol)
+ # NB: we use a checked_cast here to work around https://github.com/Shoobx/mypy-zope/issues/91
+ client_protocol = checked_cast(
+ _WrappingProtocol, client_factory.buildProtocol(dummy_address)
+ )
client_protocol.makeConnection(
FakeTransport(server_protocol, self.reactor, client_protocol)
)
@@ -446,7 +448,6 @@ class MatrixFederationAgentTests(unittest.TestCase):
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
).buildProtocol(dummy_address)
- assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
@@ -465,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
- c2s_transport = client_protocol.transport
+ assert isinstance(client_protocol, Protocol)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
@@ -1529,7 +1531,7 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None:
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
-) -> IProtocolFactory:
+) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
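
`checked_cast` is imported from tests.utils, whose definition is not part of this diff. Given the call sites and the runtime-checking intent described in the comments, a minimal implementation consistent with its use would be:

    from typing import Type, TypeVar

    T = TypeVar("T")

    def checked_cast(type_: Type[T], x: object) -> T:
        """Like typing.cast, but also verified at runtime."""
        assert isinstance(x, type_)
        return x

Unlike a plain `typing.cast`, a wrong cast of this form fails the test immediately instead of silently misleading the type checker.
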
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index a817940730..cc175052ac 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -28,7 +28,7 @@ from twisted.internet.endpoints import (
_WrappingProtocol,
)
from twisted.internet.interfaces import IProtocol, IProtocolFactory
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
@@ -43,6 +43,7 @@ from tests.http import (
)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
+from tests.utils import checked_cast
logger = logging.getLogger(__name__)
@@ -620,7 +621,6 @@ class MatrixFederationAgentTests(TestCase):
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
).buildProtocol(dummy_address)
- assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
@@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
- c2s_transport = client_protocol.transport
+ assert isinstance(client_protocol, Protocol)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
@@ -757,12 +758,14 @@ class MatrixFederationAgentTests(TestCase):
assert isinstance(proxy_server, HTTPChannel)
# fish the transports back out so that we can do the old switcheroo
- s2c_transport = proxy_server.transport
- assert isinstance(s2c_transport, FakeTransport)
- client_protocol = s2c_transport.other
- assert isinstance(client_protocol, _WrappingProtocol)
- c2s_transport = client_protocol.transport
- assert isinstance(c2s_transport, FakeTransport)
+ # To help mypy out with the various Protocols and wrappers and mocks, we do
+ # some explicit casting. Without the casts, we hit the bug I reported at
+ # https://github.com/Shoobx/mypy-zope/issues/91 .
+ # We also double-check these casts at runtime (test-time) because I found it
+ # quite confusing to deduce these types in the first place!
+ s2c_transport = checked_cast(FakeTransport, proxy_server.transport)
+ client_protocol = checked_cast(_WrappingProtocol, s2c_transport.other)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@@ -822,9 +825,9 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
def test_proxy_with_no_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+ proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
def test_proxy_with_unsupported_scheme(self) -> None:
@@ -834,25 +837,21 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
def test_proxy_with_http_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+ proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
def test_proxy_with_https_scheme(self) -> None:
https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
- self.assertEqual(
- https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com"
- )
- self.assertEqual(
- https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888
- )
+ proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._wrappedEndpoint._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
-) -> IProtocolFactory:
+) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
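
Tightening `_wrap_server_factory_for_tls` to return the concrete `TLSMemoryBIOFactory` (instead of the broad `IProtocolFactory` interface) is what allows the `assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)` lines at the call sites to be deleted: callers now receive the precise type directly. A generic illustration of the trade-off:

    class Base:
        pass

    class Concrete(Base):
        def extra(self) -> int:
            return 1

    def make_broad() -> Base:
        return Concrete()

    def make_precise() -> Concrete:
        return Concrete()

    # With the broad signature every caller must narrow before touching
    # Concrete-only members; the precise signature removes that step.
    obj = make_broad()
    assert isinstance(obj, Concrete)
    obj.extra()
    make_precise().extra()
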
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index c08954d887..5191e31a8a 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -21,6 +21,7 @@ from synapse.logging import RemoteHandler
from tests.logging import LoggerCleanupMixin
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
+from tests.utils import checked_cast
def connect_logging_client(
@@ -56,8 +57,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
client, server = connect_logging_client(self.reactor, 0)
# Trigger data being sent
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# One log message, with a single trailing newline
logs = server.data.decode("utf8").splitlines()
@@ -89,8 +90,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# Only the 7 infos made it through, the debugs were elided
logs = server.data.splitlines()
@@ -123,8 +124,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# The 10 warnings made it through, the debugs and infos were elided
logs = server.data.splitlines()
@@ -148,8 +149,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# The first five and last five warnings made it through, the debugs and
# infos were elided
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 8f88c0117d..3a1929691e 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import NotFoundError
@@ -21,9 +23,12 @@ from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.handlers.push_rules import InvalidRuleException
+from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, notifications, presence, profile, room
-from synapse.types import create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, create_requester
+from synapse.util import Clock
from tests.events.test_presence_router import send_presence_update, sync_presence
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -32,7 +37,19 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
-class ModuleApiTestCase(HomeserverTestCase):
+class BaseModuleApiTestCase(HomeserverTestCase):
+ """Common properties of the two test case classes."""
+
+ module_api: ModuleApi
+
+ # These are all written by _test_sending_local_online_presence_to_local_user.
+ presence_receiver_id: str
+ presence_receiver_tok: str
+ presence_sender_id: str
+ presence_sender_tok: str
+
+
+class ModuleApiTestCase(BaseModuleApiTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -42,23 +59,23 @@ class ModuleApiTestCase(HomeserverTestCase):
notifications.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
- self.store = homeserver.get_datastores().main
- self.module_api = homeserver.get_module_api()
- self.event_creation_handler = homeserver.get_event_creation_handler()
- self.sync_handler = homeserver.get_sync_handler()
- self.auth_handler = homeserver.get_auth_handler()
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.module_api = hs.get_module_api()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.sync_handler = hs.get_sync_handler()
+ self.auth_handler = hs.get_auth_handler()
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
- fed_transport_client = Mock(spec=["send_transaction"])
- fed_transport_client.send_transaction = simple_async_mock({})
+ self.fed_transport_client = Mock(spec=["send_transaction"])
+ self.fed_transport_client.send_transaction = simple_async_mock({})
return self.setup_test_homeserver(
- federation_transport_client=fed_transport_client,
+ federation_transport_client=self.fed_transport_client,
)
- def test_can_register_user(self):
+ def test_can_register_user(self) -> None:
"""Tests that an external module can register a user"""
# Register a new user
user_id, access_token = self.get_success(
@@ -88,16 +105,17 @@ class ModuleApiTestCase(HomeserverTestCase):
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
- def test_can_register_admin_user(self):
+ def test_can_register_admin_user(self) -> None:
user_id = self.register_user(
"bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
)
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
- def test_can_set_admin(self):
+ def test_can_set_admin(self) -> None:
user_id = self.register_user(
"alice_wants_admin",
"1234",
@@ -107,16 +125,17 @@ class ModuleApiTestCase(HomeserverTestCase):
self.get_success(self.module_api.set_user_admin(user_id, True))
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
- def test_can_set_displayname(self):
+ def test_can_set_displayname(self) -> None:
localpart = "alice_wants_a_new_displayname"
user_id = self.register_user(
localpart, "1234", displayname="Alice", admin=False
)
found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id))
-
+ assert found_userinfo is not None
self.get_success(
self.module_api.set_displayname(
found_userinfo.user_id, "Bob", deactivation=False
@@ -128,17 +147,18 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(found_profile.display_name, "Bob")
- def test_get_userinfo_by_id(self):
+ def test_get_userinfo_by_id(self) -> None:
user_id = self.register_user("alice", "1234")
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, False)
- def test_get_userinfo_by_id__no_user_found(self):
+ def test_get_userinfo_by_id__no_user_found(self) -> None:
found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
self.assertIsNone(found_user)
- def test_get_user_ip_and_agents(self):
+ def test_get_user_ip_and_agents(self) -> None:
user_id = self.register_user("test_get_user_ip_and_agents_user", "1234")
# Initially, we should have no ip/agent for our user.
@@ -185,7 +205,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# we should only find the second ip, agent.
info = self.get_success(
self.module_api.get_user_ip_and_agents(
- user_id, (last_seen_1 + last_seen_2) / 2
+ user_id, (last_seen_1 + last_seen_2) // 2
)
)
self.assertEqual(len(info), 1)
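
The `/` → `//` change in the hunk above is a small correctness fix surfaced by typing: true division always yields a float, while the timestamp midpoint is presumably expected as an int. Floor division keeps it integral:

    last_seen_1, last_seen_2 = 10_000, 20_000
    assert (last_seen_1 + last_seen_2) / 2 == 15000.0   # float
    assert (last_seen_1 + last_seen_2) // 2 == 15000    # int
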
@@ -200,7 +220,7 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertEqual(info, [])
- def test_get_user_ip_and_agents__no_user_found(self):
+ def test_get_user_ip_and_agents__no_user_found(self) -> None:
info = self.get_success(
self.module_api.get_user_ip_and_agents(
"@test_get_user_ip_and_agents_user_nonexistent:example.com"
@@ -208,10 +228,10 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertEqual(info, [])
- def test_sending_events_into_room(self):
+ def test_sending_events_into_room(self) -> None:
"""Tests that a module can send events into a room"""
# Mock out create_and_send_nonmember_event to check whether events are being sent
- self.event_creation_handler.create_and_send_nonmember_event = Mock(
+ self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[assignment]
spec=[],
side_effect=self.event_creation_handler.create_and_send_nonmember_event,
)
@@ -222,7 +242,7 @@ class ModuleApiTestCase(HomeserverTestCase):
room_id = self.helper.create_room_as(user_id, tok=tok)
# Create and send a non-state event
- content = {"body": "I am a puppet", "msgtype": "m.text"}
+ content: JsonDict = {"body": "I am a puppet", "msgtype": "m.text"}
event_dict = {
"room_id": room_id,
"type": "m.room.message",
@@ -265,7 +285,7 @@ class ModuleApiTestCase(HomeserverTestCase):
"sender": user_id,
"state_key": "",
}
- event: EventBase = self.get_success(
+ event = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict)
)
self.assertEqual(event.sender, user_id)
@@ -303,7 +323,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.create_and_send_event_into_room(event_dict), Exception
)
- def test_public_rooms(self):
+ def test_public_rooms(self) -> None:
"""Tests that a room can be added and removed from the public rooms list,
as well as have its public rooms directory state queried.
"""
@@ -350,13 +370,13 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertFalse(is_in_public_rooms)
- def test_send_local_online_presence_to(self):
+ def test_send_local_online_presence_to(self) -> None:
# Test sending local online presence to users from the main process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
# Enable federation sending on the main process.
@override_config({"federation_sender_instances": None})
- def test_send_local_online_presence_to_federation(self):
+ def test_send_local_online_presence_to_federation(self) -> None:
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
# Create a user who will send presence updates
self.presence_sender_id = self.register_user("presence_sender1", "monkey")
@@ -397,7 +417,7 @@ class ModuleApiTestCase(HomeserverTestCase):
#
# Thus we reset the mock, and try sending online local user
# presence again
- self.hs.get_federation_transport_client().send_transaction.reset_mock()
+ self.fed_transport_client.send_transaction.reset_mock()
# Broadcast local user online presence
self.get_success(
@@ -409,9 +429,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that a presence update was sent as part of a federation transaction
found_update = False
- calls = (
- self.hs.get_federation_transport_client().send_transaction.call_args_list
- )
+ calls = self.fed_transport_client.send_transaction.call_args_list
for call in calls:
call_args = call[0]
federation_transaction: Transaction = call_args[0]
@@ -431,7 +449,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertTrue(found_update)
- def test_update_membership(self):
+ def test_update_membership(self) -> None:
"""Tests that the module API can update the membership of a user in a room."""
peter = self.register_user("peter", "hackme")
lesley = self.register_user("lesley", "hackme")
@@ -554,14 +572,14 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(res["displayname"], "simone")
self.assertIsNone(res["avatar_url"])
- def test_update_room_membership_remote_join(self):
+ def test_update_room_membership_remote_join(self) -> None:
"""Test that the module API can join a remote room."""
# Necessary to fake a remote join.
fake_stream_id = 1
mocked_remote_join = simple_async_mock(
return_value=("fake-event-id", fake_stream_id)
)
- self.hs.get_room_member_handler()._remote_join = mocked_remote_join
+ self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment]
fake_remote_host = f"{self.module_api.server_name}-remote"
# Given that the join is to be faked, we expect the relevant join event not to
@@ -582,7 +600,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that a remote join was attempted.
self.assertEqual(mocked_remote_join.call_count, 1)
- def test_get_room_state(self):
+ def test_get_room_state(self) -> None:
"""Tests that a module can retrieve the state of a room through the module API."""
user_id = self.register_user("peter", "hackme")
tok = self.login("peter", "hackme")
@@ -677,7 +695,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.check_push_rule_actions(["foo"])
with self.assertRaises(InvalidRuleException):
- self.module_api.check_push_rule_actions({"foo": "bar"})
+ self.module_api.check_push_rule_actions([{"foo": "bar"}])
self.module_api.check_push_rule_actions(["notify"])
@@ -756,7 +774,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertIsNone(room_alias)
-class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
+class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
servlets = [
@@ -766,7 +784,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
presence.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
conf = super().default_config()
conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = {
@@ -774,18 +792,18 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
}
return conf
- def prepare(self, reactor, clock, homeserver):
- self.module_api = homeserver.get_module_api()
- self.sync_handler = homeserver.get_sync_handler()
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.module_api = hs.get_module_api()
+ self.sync_handler = hs.get_sync_handler()
- def test_send_local_online_presence_to_workers(self):
+ def test_send_local_online_presence_to_workers(self) -> None:
# Test sending local online presence to users from a worker process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
def _test_sending_local_online_presence_to_local_user(
- test_case: HomeserverTestCase, test_with_workers: bool = False
-):
+ test_case: BaseModuleApiTestCase, test_with_workers: bool = False
+) -> None:
"""Tests that send_local_presence_to_users sends local online presence to local users.
This simultaneously tests two different usecases:
@@ -852,6 +870,7 @@ def _test_sending_local_online_presence_to_local_user(
# Replicate the current sync presence token from the main process to the worker process.
# We need to do this so that the worker process knows the current presence stream ID to
# insert into the database when we call ModuleApi.send_local_online_presence_to.
+ assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
test_case.replicate()
# Syncing again should result in no presence updates
@@ -868,6 +887,7 @@ def _test_sending_local_online_presence_to_local_user(
# Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to
if test_with_workers:
+ assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
module_api_to_use = worker_hs.get_module_api()
else:
module_api_to_use = test_case.module_api
@@ -875,12 +895,11 @@ def _test_sending_local_online_presence_to_local_user(
# Trigger sending local online presence. We expect this information
# to be saved to the database where all processes can access it.
# Note that we're syncing via the master.
- d = module_api_to_use.send_local_online_presence_to(
- [
- test_case.presence_receiver_id,
- ]
+ d = defer.ensureDeferred(
+ module_api_to_use.send_local_online_presence_to(
+ [test_case.presence_receiver_id],
+ )
)
- d = defer.ensureDeferred(d)
if test_with_workers:
# In order for the required presence_set_state replication request to occur between the
@@ -897,7 +916,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
- presence_update: UserPresenceState = presence_updates[0]
+ presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@@ -908,7 +927,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
- presence_update: UserPresenceState = presence_updates[0]
+ presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@@ -936,12 +955,13 @@ def _test_sending_local_online_presence_to_local_user(
test_case.assertEqual(len(presence_updates), 1)
# Now trigger sending local online presence.
- d = module_api_to_use.send_local_online_presence_to(
- [
- test_case.presence_receiver_id,
- ]
+ d = defer.ensureDeferred(
+ module_api_to_use.send_local_online_presence_to(
+ [
+ test_case.presence_receiver_id,
+ ]
+ )
)
- d = defer.ensureDeferred(d)
if test_with_workers:
# In order for the required presence_set_state replication request to occur between the
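
The reshuffled presence-sending blocks above wrap the coroutine in `defer.ensureDeferred(...)` at the call site instead of converting it afterwards. The behaviour is the same, but `d` is a `Deferred` from its first binding, which keeps its inferred type consistent. A small Twisted sketch:

    from twisted.internet import defer

    async def send_presence() -> str:
        return "ok"

    # ensureDeferred starts the coroutine and returns a Deferred at once;
    # d therefore has a single, accurate type on every line that uses it.
    d: "defer.Deferred[str]" = defer.ensureDeferred(send_presence())
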
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 7567756135..199e3d7b70 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -227,7 +227,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
return len(result) > 0
- @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3758_exact_event_match": True,
+ "msc3952_intentional_mentions": True,
+ }
+ }
+ )
def test_user_mentions(self) -> None:
"""Test the behavior of an event which includes invalid user mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -323,7 +330,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
)
- @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3758_exact_event_match": True,
+ "msc3952_intentional_mentions": True,
+ }
+ }
+ )
def test_room_mentions(self) -> None:
"""Test the behavior of an event which includes invalid room mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
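Both tests now enable `msc3758_exact_event_match` alongside `msc3952_intentional_mentions`, presumably because the intentional-mentions push rules are expressed in terms of the MSC3758 exact-match condition. Illustratively, a room-mention condition of that shape might look like the following (key and kind follow the MSCs; this is not quoted from the repo):

```python
# Hypothetical MSC3952 room-mention condition built on MSC3758 exact matching:
condition = {
    "kind": "com.beeper.msc3758.exact_event_match",
    "key": "content.org.matrix.msc3952.mentions.room",
    "value": True,
}
```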
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index ab8bb417e7..7563f33fdc 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
+from synapse.push.emailpusher import EmailPusher
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -105,6 +106,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(self.access_token)
)
+ assert user_tuple is not None
self.token_id = user_tuple.token_id
# We need to add email to account before we can create a pusher.
@@ -114,7 +116,7 @@ class EmailPusherTests(HomeserverTestCase):
)
)
- self.pusher = self.get_success(
+ pusher = self.get_success(
self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id,
access_token=self.token_id,
@@ -127,6 +129,8 @@ class EmailPusherTests(HomeserverTestCase):
data={},
)
)
+ assert isinstance(pusher, EmailPusher)
+ self.pusher = pusher
self.auth_handler = hs.get_auth_handler()
self.store = hs.get_datastores().main
@@ -375,10 +379,13 @@ class EmailPusherTests(HomeserverTestCase):
)
# check that the pusher for that email address has been deleted
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 0)
def test_remove_unlinked_pushers_background_job(self) -> None:
@@ -413,10 +420,13 @@ class EmailPusherTests(HomeserverTestCase):
self.wait_for_background_updates()
# Check that all pushers with unlinked addresses were deleted
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 0)
def _check_for_mail(self) -> Tuple[Sequence, Dict]:
@@ -428,10 +438,13 @@ class EmailPusherTests(HomeserverTestCase):
that notification.
"""
# Get the stream ordering before it gets sent
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -439,10 +452,13 @@ class EmailPusherTests(HomeserverTestCase):
self.pump(10)
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
@@ -458,10 +474,13 @@ class EmailPusherTests(HomeserverTestCase):
self.assertEqual(len(self.email_attempts), 1)
# The stream ordering has increased
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
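The repeated `pushers = list(self.get_success(...))` rewrites suggest `get_pushers_by` is now typed as returning a lazy iterable or read-only `Sequence` rather than a `list`, so each test materializes the result once before calling `len()` and indexing. A minimal sketch of the hazard, with a hypothetical generator-backed accessor:

```python
from typing import Iterator

def get_pushers() -> Iterator[str]:
    # An iterator is consumed by the first traversal.
    yield from ("pusher1",)

pushers = list(get_pushers())   # materialize once
assert len(pushers) == 1        # safe to measure...
assert pushers[0] == "pusher1"  # ...and index afterwards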
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 23447cc310..c280ddcdf6 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Tuple
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -22,7 +22,6 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer
-from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import JsonDict
from synapse.util import Clock
@@ -67,9 +66,10 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
- def test_data(data: Optional[JsonDict]) -> None:
+ def test_data(data: Any) -> None:
self.get_failure(
self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
@@ -113,6 +113,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -140,10 +141,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.helper.send(room, body="There!", tok=other_access_token)
# Get the stream ordering before it gets sent
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -151,10 +153,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
@@ -172,10 +175,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# The stream ordering has increased
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -194,10 +198,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# The stream ordering has increased, again
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
@@ -229,6 +234,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -349,6 +355,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -435,6 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -512,6 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -618,6 +627,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -753,6 +763,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -895,6 +906,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
device_id = user_tuple.device_id
@@ -941,9 +953,10 @@ class HTTPPusherTests(HomeserverTestCase):
)
# Look up the user info for the access token so we can compare the device ID.
- lookup_result: TokenLookupResult = self.get_success(
+ lookup_result = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert lookup_result is not None
# Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token)
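The sprinkled `assert user_tuple is not None` / `assert lookup_result is not None` lines exist for the type checker: `get_user_by_access_token` evidently returns an `Optional` record, and the assert narrows the type before `.token_id` or `.device_id` is accessed. The narrowing pattern in isolation:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TokenRecord:
    token_id: int

def get_user_by_access_token(token: str) -> Optional[TokenRecord]:
    return TokenRecord(token_id=1) if token else None

user_tuple = get_user_by_access_token("abc")
assert user_tuple is not None  # mypy narrows Optional[TokenRecord] to TokenRecord
token_id = user_tuple.token_id
```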
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index da33423871..d320a12f96 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -32,6 +32,7 @@ from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.synapse_rust.push import PushRuleEvaluator
from synapse.types import JsonDict, JsonMapping, UserID
from synapse.util import Clock
+from synapse.util.frozenutils import freeze
from tests import unittest
from tests.test_utils.event_injection import create_event, inject_member_event
@@ -48,17 +49,34 @@ class FlattenDictTestCase(unittest.TestCase):
input = {"foo": {"bar": "abc"}}
self.assertEqual({"foo.bar": "abc"}, _flatten_dict(input))
+ # If a field has a dot in it, escape it.
+ input = {"m.foo": {"b\\ar": "abc"}}
+ self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input))
+ self.assertEqual(
+ {"m\\.foo.b\\\\ar": "abc"},
+ _flatten_dict(input, msc3783_escape_event_match_key=True),
+ )
+
def test_non_string(self) -> None:
- """Non-string items are dropped."""
+        """Strings, booleans, ints, nulls and lists of those are kept while other items are dropped."""
input: Dict[str, Any] = {
"woo": "woo",
"foo": True,
"bar": 1,
"baz": None,
- "fuzz": [],
+ "fuzz": ["woo", True, 1, None, [], {}],
"boo": {},
}
- self.assertEqual({"woo": "woo"}, _flatten_dict(input))
+ self.assertEqual(
+ {
+ "woo": "woo",
+ "foo": True,
+ "bar": 1,
+ "baz": None,
+ "fuzz": ["woo", True, 1, None],
+ },
+ _flatten_dict(input),
+ )
def test_event(self) -> None:
"""Events can also be flattened."""
@@ -78,9 +96,9 @@ class FlattenDictTestCase(unittest.TestCase):
)
expected = {
"content.msgtype": "m.text",
- "content.body": "hello world!",
+ "content.body": "Hello world!",
"content.format": "org.matrix.custom.html",
- "content.formatted_body": "hello world!
",
+ "content.formatted_body": "Hello world!
",
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
@@ -107,6 +125,7 @@ class FlattenDictTestCase(unittest.TestCase):
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
+ "content.org.matrix.msc1767.markup": [],
}
self.assertEqual(expected, _flatten_dict(event))
@@ -118,6 +137,7 @@ class FlattenDictTestCase(unittest.TestCase):
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
+ "content.org.matrix.msc1767.markup": [],
}
self.assertEqual(expected, _flatten_dict(event))
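Taken together, the updated expectations describe `_flatten_dict` as joining nested keys with dots, keeping JSON scalars (including scalars inside lists), and dropping nested containers. A rough re-implementation of that observable behaviour (a sketch matching the tests, not the module's actual code):

```python
from typing import Any, Dict

def flatten(d: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for key, value in d.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            out.update(flatten(value, prefix=f"{path}."))  # recurse, dot-joining keys
        elif isinstance(value, list):
            # keep scalar items, drop nested lists/dicts
            out[path] = [v for v in value if not isinstance(v, (list, dict))]
        else:
            out[path] = value  # str, bool, int and None survive
    return out

assert flatten({"foo": {"bar": "abc"}}) == {"foo.bar": "abc"}
assert flatten({"fuzz": ["woo", True, 1, None, [], {}]}) == {"fuzz": ["woo", True, 1, None]}
```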
@@ -129,7 +149,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
*,
has_mentions: bool = False,
user_mentions: Optional[Set[str]] = None,
- room_mention: bool = False,
related_events: Optional[JsonDict] = None,
) -> PushRuleEvaluator:
event = FrozenEvent(
@@ -150,7 +169,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
_flatten_dict(event),
has_mentions,
user_mentions or set(),
- room_mention,
room_member_count,
sender_power_level,
cast(Dict[str, int], power_levels.get("notifications", {})),
@@ -158,6 +176,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
related_event_match_enabled=True,
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
+ msc3758_exact_event_match=True,
+ msc3966_exact_event_property_contains=True,
)
def test_display_name(self) -> None:
@@ -210,27 +230,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
# Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
# since the BulkPushRuleEvaluator is what handles data sanitisation.
- def test_room_mentions(self) -> None:
- """Check for room mentions."""
- condition = {"kind": "org.matrix.msc3952.is_room_mention"}
-
- # No room mention shouldn't match.
- evaluator = self._get_evaluator({}, has_mentions=True)
- self.assertFalse(evaluator.matches(condition, None, None))
-
- # Room mention should match.
- evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True)
- self.assertTrue(evaluator.matches(condition, None, None))
-
- # A room mention and user mention is valid.
- evaluator = self._get_evaluator(
- {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True
- )
- self.assertTrue(evaluator.matches(condition, None, None))
-
- # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
- # since the BulkPushRuleEvaluator is what handles data sanitisation.
-
def _assert_matches(
self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
) -> None:
@@ -402,6 +401,178 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"pattern should not match before a newline",
)
+ def test_exact_event_match_string(self) -> None:
+ """Check that exact_event_match conditions work as expected for strings."""
+
+ # Test against a string value.
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": "foobaz",
+ }
+ self._assert_matches(
+ condition,
+ {"value": "foobaz"},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": "FoobaZ"},
+ "values should match and be case-sensitive",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": "test foobaz test"},
+ "values must exactly match",
+ )
+ value: Any
+ for value in (True, False, 1, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ # it should work on frozendicts too
+ self._assert_matches(
+ condition,
+ frozendict.frozendict({"value": "foobaz"}),
+ "values should match on frozendicts",
+ )
+
+ def test_exact_event_match_boolean(self) -> None:
+ """Check that exact_event_match conditions work as expected for booleans."""
+
+ # Test against a True boolean value.
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": True,
+ }
+ self._assert_matches(
+ condition,
+ {"value": True},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": False},
+ "incorrect values should not match",
+ )
+ for value in ("foobaz", 1, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ # Test against a False boolean value.
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": False,
+ }
+ self._assert_matches(
+ condition,
+ {"value": False},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": True},
+ "incorrect values should not match",
+ )
+ # Choose false-y values to ensure there's no type coercion.
+ for value in ("", 0, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ def test_exact_event_match_null(self) -> None:
+ """Check that exact_event_match conditions work as expected for null."""
+
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": None,
+ }
+ self._assert_matches(
+ condition,
+ {"value": None},
+ "exact value should match",
+ )
+ for value in ("foobaz", True, False, 1, 1.1, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ def test_exact_event_match_integer(self) -> None:
+ """Check that exact_event_match conditions work as expected for integers."""
+
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": 1,
+ }
+ self._assert_matches(
+ condition,
+ {"value": 1},
+ "exact value should match",
+ )
+ value: Any
+ for value in (1.1, -1, 0):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect values should not match",
+ )
+ for value in ("1", True, False, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ def test_exact_event_property_contains(self) -> None:
+ """Check that exact_event_property_contains conditions work as expected."""
+
+ condition = {
+ "kind": "org.matrix.msc3966.exact_event_property_contains",
+ "key": "content.value",
+ "value": "foobaz",
+ }
+ self._assert_matches(
+ condition,
+ {"value": ["foobaz"]},
+ "exact value should match",
+ )
+ self._assert_matches(
+ condition,
+ {"value": ["foobaz", "bugz"]},
+ "extra values should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": ["FoobaZ"]},
+ "values should match and be case-sensitive",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": "foobaz"},
+ "does not search in a string",
+ )
+
+ # it should work on frozendicts too
+ self._assert_matches(
+ condition,
+ freeze({"value": ["foobaz"]}),
+ "values should match on frozendicts",
+ )
+
def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({})
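The new `exact_event_match` tests all enforce comparison without type coercion: `True` must not match `1`, `False` must not match `""` or `0`, and `1` must not match `"1"` or `1.1`. The semantics are roughly an equality check guarded by a type check (a Python sketch of the behaviour the tests pin down, not the Rust evaluator's code):

```python
def exact_match(expected: object, actual: object) -> bool:
    # bool subclasses int in Python, so True == 1; comparing types first
    # rules out that kind of cross-type coercion.
    if type(expected) is not type(actual):
        return False
    return expected == actual

assert exact_match("foobaz", "foobaz")
assert not exact_match(True, 1)
assert not exact_match(False, 0)
assert not exact_match(1, "1")
assert not exact_match(1, 1.1)
```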
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 043dbe76af..65ef4bb160 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Sequence
from twisted.test.proto_helpers import MemoryReactor
@@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point: List[str] = self.get_success(
+ fork_point: Sequence[str] = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)
@@ -168,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
pl_event = self.get_success(
inject_event(
self.hs,
- prev_event_ids=prev_events,
+ prev_event_ids=list(prev_events),
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
@@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point: List[str] = self.get_success(
+ fork_point: Sequence[str] = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)
@@ -323,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
e = self.get_success(
inject_event(
self.hs,
- prev_event_ids=prev_events,
+ prev_event_ids=list(prev_events),
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
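`fork_point` is now annotated `Sequence[str]`, matching what `get_latest_event_ids_in_room` apparently returns, and the `list(prev_events)` copies exist because a read-only `Sequence[str]` cannot be passed where a mutable `List[str]` is expected. The variance point in isolation:

```python
from typing import List, Sequence

def takes_list(ids: List[str]) -> None:
    ids.append("$new")  # a List parameter is allowed to mutate its argument

fork_point: Sequence[str] = ("$a", "$b")  # read-only view
# takes_list(fork_point)      # rejected by mypy: Sequence[str] is not List[str]
takes_list(list(fork_point))  # an explicit copy satisfies the checker
```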
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index 38b5020ce0..452ac85069 100644
--- a/tests/replication/tcp/streams/test_partial_state.py
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
room_id = self.helper.create_room_as("@bob:test")
# Mark the room as partial-stated.
self.get_success(
- self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1")
+ self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1")
)
worker = self.make_worker_hs("synapse.app.generic_worker")
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 68de5d1cc2..5a38ac831f 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -13,7 +13,7 @@
# limitations under the License.
from unittest.mock import Mock
-from synapse.handlers.typing import RoomMember
+from synapse.handlers.typing import RoomMember, TypingWriterHandler
from synapse.replication.tcp.streams import TypingStream
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -33,6 +33,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
def test_typing(self) -> None:
typing = self.hs.get_typing_handler()
+ assert isinstance(typing, TypingWriterHandler)
self.reconnect()
@@ -88,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
sends the proper position and RDATA).
"""
typing = self.hs.get_typing_handler()
+ assert isinstance(typing, TypingWriterHandler)
self.reconnect()
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index 6e4055cc21..bf927beb6a 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -127,6 +127,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
# ... updating the cache ID gen on the master still shouldn't cause the
# deferred to wake up.
+ assert store._cache_id_gen is not None
ctx = store._cache_id_gen.get_next()
self.get_success(ctx.__aenter__())
self.get_success(ctx.__aexit__(None, None, None))
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 89380e25b5..08703206a9 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -16,6 +16,7 @@ from unittest.mock import Mock
from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
+from synapse.handlers.typing import TypingWriterHandler
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client import login, room
from synapse.types import UserID, create_requester
@@ -174,6 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
token = self.login("user3", "pass")
typing_handler = self.hs.get_typing_handler()
+ assert isinstance(typing_handler, TypingWriterHandler)
sent_on_1 = False
sent_on_2 = False
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 9345cfbeb2..0798b021c3 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -50,6 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_dict is not None
token_id = user_dict.token_id
self.get_success(
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index aadb31ca83..db77a45ae3 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -213,7 +213,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.admin_user_tok = self.login("admin", "pass")
self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
- self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
+ self.url = "/_synapse/admin/v1/media/delete"
+ self.legacy_url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
# Move clock up to somewhat realistic time
self.reactor.advance(1000000000)
@@ -332,11 +333,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
)
- def test_delete_media_never_accessed(self) -> None:
+ @parameterized.expand([(True,), (False,)])
+ def test_delete_media_never_accessed(self, use_legacy_url: bool) -> None:
"""
Tests that media deleted if it is older than `before_ts` and never accessed
`last_access_ts` is `NULL` and `created_ts` < `before_ts`
"""
+ url = self.legacy_url if use_legacy_url else self.url
# upload and do not access
server_and_media_id = self._create_media()
@@ -351,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
now_ms = self.clock.time_msec()
channel = self.make_request(
"POST",
- self.url + "?before_ts=" + str(now_ms),
+ url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
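`@parameterized.expand([(True,), (False,)])` runs the test body once per tuple, so both the new `/media/delete` endpoint and the legacy per-server URL are exercised by a single test. A minimal standalone sketch using the same `parameterized` library:

```python
import unittest
from parameterized import parameterized

class DeleteUrlTest(unittest.TestCase):
    @parameterized.expand([(True,), (False,)])
    def test_url_choice(self, use_legacy_url: bool) -> None:
        # Mirrors the pattern above: pick the URL per parameterized run.
        url = "/legacy/delete" if use_legacy_url else "/delete"
        self.assertTrue(url.endswith("delete"))

if __name__ == "__main__":
    unittest.main()
```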
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index a2f347f666..f71ff46d87 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import List, Sequence
from twisted.test.proto_helpers import MemoryReactor
@@ -558,7 +558,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
def _check_invite_and_join_status(
self, user_id: str, expected_invites: int, expected_memberships: int
- ) -> List[RoomsForUser]:
+ ) -> Sequence[RoomsForUser]:
"""Check invite and room membership status of a user.
Args
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..f5b213219f 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
- storage_controllers = self.hs.get_storage_controllers()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
# Create two rooms, one with a local user only and one with both a local
# and remote user.
@@ -2934,11 +2935,13 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_creation_handler.create_new_client_event(builder)
)
- self.get_success(storage_controllers.persistence.persist_event(event, context))
+ context = self.get_success(unpersisted_context.persist(event))
+
+ self.get_success(persistence.persist_event(event, context))
# Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 30f12f1bff..6c04e6c56c 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- async def check_username(username: str) -> bool:
- if username == "allowed":
- return True
+ async def check_username(
+ localpart: str,
+ guest_access_token: Optional[str] = None,
+ assigned_user_id: Optional[str] = None,
+ inhibit_user_in_use_error: bool = False,
+ ) -> None:
+ if localpart == "allowed":
+ return
raise SynapseError(
400,
"User ID already taken.",
@@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
)
handler = self.hs.get_registration_handler()
- handler.check_username = check_username
+ handler.check_username = check_username # type: ignore[assignment]
def test_username_available(self) -> None:
"""
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 88f255c9ee..e2ee1a1766 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -1193,7 +1193,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
return {}
# Register a mock that will return the expected result depending on the remote.
- self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
+ self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment]
# Check that we've got the correct response from the client-side endpoint.
self._test_status(
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 208ec44829..a144610078 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
-from tests.server import FakeChannel, make_request
+from tests.server import FakeChannel
from tests.unittest import override_config, skip_unless
@@ -43,6 +43,9 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
super().__init__(hs)
self.recaptcha_attempts: List[Tuple[dict, str]] = []
+ def is_enabled(self) -> bool:
+ return True
+
def check_auth(self, authdict: dict, clientip: str) -> Any:
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)
@@ -1319,16 +1322,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
- # Now try to exchange the login token
- channel = make_request(
- self.hs.get_reactor(),
- self.site,
- "POST",
- "/login",
- content={"type": "m.login.token", "token": login_token},
- )
- # It should have failed
- self.assertEqual(channel.code, 403)
+ # Now try to exchange the login token, it should fail.
+ self.helper.login_via_token(login_token, 403)
@override_config(
{
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index afc8d641be..830762fd53 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -63,14 +63,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_add_filter_non_local_user(self) -> None:
_is_mine = self.hs.is_mine
- self.hs.is_mine = lambda target_user: False
+ self.hs.is_mine = lambda target_user: False # type: ignore[assignment]
channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- self.hs.is_mine = _is_mine
+ self.hs.is_mine = _is_mine # type: ignore[assignment]
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index b3738a0304..67e16880e6 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -36,14 +36,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- presence_handler = Mock(spec=PresenceHandler)
- presence_handler.set_state.return_value = make_awaitable(None)
+ self.presence_handler = Mock(spec=PresenceHandler)
+ self.presence_handler.set_state.return_value = make_awaitable(None)
hs = self.setup_test_homeserver(
"red",
federation_http_client=None,
federation_client=Mock(),
- presence_handler=presence_handler,
+ presence_handler=self.presence_handler,
)
return hs
@@ -61,7 +61,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
+ self.assertEqual(self.presence_handler.set_state.call_count, 1)
@unittest.override_config({"use_presence": False})
def test_put_presence_disabled(self) -> None:
@@ -76,4 +76,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
+ self.assertEqual(self.presence_handler.set_state.call_count, 0)
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 11cf3939d8..4c561f9525 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -151,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self) -> None:
- self.hs.config.key.macaroon_secret_key = "test"
+ self.hs.config.key.macaroon_secret_key = b"test"
self.hs.config.registration.allow_guest_access = True
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
@@ -1166,12 +1166,15 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
"""
user_id = self.register_user("kermit_delta", "user")
- self.hs.config.account_validity.startup_job_max_delta = self.max_delta
+ self.hs.config.account_validity.account_validity_startup_job_max_delta = (
+ self.max_delta
+ )
now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+ assert res is not None
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 7cb1017a4a..1250685d39 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -73,6 +73,18 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
data = {"reason": None, "score": None}
self._assert_status(400, data)
+ def test_cannot_report_nonexistent_event(self) -> None:
+ """
+ Tests that we don't accept event reports for events which do not exist.
+ """
+ channel = self.make_request(
+ "POST",
+ f"rooms/{self.room_id}/report/$nonsenseeventid:test",
+ {"reason": "i am very sad"},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(404, channel.code, msg=channel.result["body"])
+
def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request(
"POST", self.report_path, data, access_token=self.other_user_tok
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 9c8c1889d3..d3e06bf6b3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -136,6 +136,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send a first event, which should be filtered out at the end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
first_event_id = resp.get("event_id")
+ assert isinstance(first_event_id, str)
# Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid.
@@ -144,6 +145,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send another event, which shouldn't get filtered out.
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
valid_event_id = resp.get("event_id")
+ assert isinstance(valid_event_id, str)
# Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
@@ -229,7 +231,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Check that we can still access state events that were sent before the event that
# has been purged.
- self.get_event(room_id, create_event.event_id)
+ self.get_event(room_id, bool(create_event))
def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict:
event = self.get_success(self.store.get_event(event_id, allow_none=True))
@@ -238,7 +240,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self.assertIsNone(event)
return {}
- self.assertIsNotNone(event)
+ assert event is not None
time_now = self.clock.time_msec()
serialized = self.serializer.serialize_event(event, time_now)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 9222cab198..cfad182b2f 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -3382,8 +3382,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
- self.hs.get_identity_handler().lookup_3pid = Mock(
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
+ self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
)
@@ -3443,8 +3443,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
- self.hs.get_identity_handler().lookup_3pid = Mock(
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
+ self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
)
@@ -3563,8 +3563,10 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
)
event.internal_metadata.outlier = True
+ persistence = self._storage_controllers.persistence
+ assert persistence is not None
self.get_success(
- self._storage_controllers.persistence.persist_event(
+ persistence.persist_event(
event, EventContext.for_outlier(self._storage_controllers)
)
)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index c807a37bc2..8d2cdf8751 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase):
def test_invite_3pid(self) -> None:
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_identity_handler()
- identity_handler.lookup_3pid = Mock(
+ identity_handler.lookup_3pid = Mock( # type: ignore[assignment]
side_effect=AssertionError("This should not get called")
)
@@ -222,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase):
event_source.get_new_events(
user=UserID.from_string(self.other_user_id),
from_key=0,
- limit=None,
+ limit=10,
room_ids=[room_id],
is_guest=False,
)
@@ -286,6 +286,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.banned_user_id,
)
)
+ assert event is not None
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
@@ -321,6 +322,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.banned_user_id,
)
)
+ assert event is not None
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3325d43a2f..5fa3440691 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -425,7 +425,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
async def test_fn(
event: EventBase, state_events: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
- if event.is_state and event.type == EventTypes.PowerLevels:
+ if event.is_state() and event.type == EventTypes.PowerLevels:
await api.create_and_send_event_into_room(
{
"room_id": event.room_id,
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 5ec343dd7f..0b4c691318 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -84,7 +84,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.room_id, EventTypes.Tombstone, ""
)
)
- self.assertIsNotNone(tombstone_event)
+ assert tombstone_event is not None
self.assertEqual(new_room_id, tombstone_event.content["replacement_room"])
# Check that the new room exists.
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode
import attr
from typing_extensions import Literal
+from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.resource import Resource
from twisted.web.server import Site
@@ -67,6 +68,7 @@ class RestHelper:
"""
hs: HomeServer
+ reactor: MemoryReactorClock
site: Site
auth_user_id: Optional[str]
@@ -142,7 +144,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -216,7 +218,7 @@ class RestHelper:
data["reason"] = reason
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -313,7 +315,7 @@ class RestHelper:
data.update(extra_data or {})
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -394,7 +396,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -433,7 +435,7 @@ class RestHelper:
path = path + f"?access_token={tok}"
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
path,
@@ -488,7 +490,7 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
- channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
+ channel = make_request(self.reactor, self.site, method, path, content)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
@@ -573,8 +575,8 @@ class RestHelper:
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
- self.hs.get_reactor(),
- FakeSite(resource, self.hs.get_reactor()),
+ self.reactor,
+ FakeSite(resource, self.reactor),
"POST",
path,
content=image_data,
@@ -603,7 +605,7 @@ class RestHelper:
expect_code: The return code to expect from attempting the whoami request
"""
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
"account/whoami",
@@ -642,7 +644,7 @@ class RestHelper:
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC
- Returns the result of the final token login.
+ Returns the result of the final token login and the fake authorization grant.
Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
@@ -672,10 +674,28 @@ class RestHelper:
assert m, channel.text_body
login_token = m.group(1)
- # finally, submit the matrix login token to the login API, which gives us our
- # matrix access token and device id.
+ return self.login_via_token(login_token, expected_status), grant
+
+ def login_via_token(
+ self,
+ login_token: str,
+ expected_status: int = 200,
+ ) -> JsonDict:
+ """Submit the matrix login token to the login API, which gives us our
+    matrix access token and device id.
+
+ Returns the result of the token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
"/login",
@@ -684,7 +704,7 @@ class RestHelper:
assert (
channel.code == expected_status
), f"unexpected status in response: {channel.code}"
- return channel.json_body, grant
+ return channel.json_body
def auth_via_oidc(
self,
@@ -805,7 +825,7 @@ class RestHelper:
with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
callback_uri,
@@ -849,7 +869,7 @@ class RestHelper:
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
uri,
@@ -867,7 +887,7 @@ class RestHelper:
location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
urllib.parse.urlunsplit(("", "") + parts[2:]),
@@ -900,9 +920,7 @@ class RestHelper:
+ urllib.parse.urlencode({"session": ui_auth_session_id})
)
# hit the redirect url (which will issue a cookie and state)
- channel = make_request(
- self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
- )
+ channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
# that should serve a confirmation page
assert channel.code == HTTPStatus.OK, channel.text_body
channel.extract_cookies(cookies)
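The rest of this file diff threads `self.reactor` through the `make_request` calls and extracts the `login_via_token` helper, which the earlier `test_auth.py` hunk already uses to assert a rejected token exchange in one line:

```python
# Usage from the test_auth.py hunk above: expect the login token to be
# refused (HTTP 403) after the backchannel logout.
self.helper.login_via_token(login_token, 403)
```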
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index d18fc13c21..17a3b06a8e 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
import tempfile
from binascii import unhexlify
from io import BytesIO
-from typing import Any, BinaryIO, Dict, List, Optional, Union
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib import parse
@@ -32,6 +32,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.rest import admin
@@ -41,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from synapse.server import HomeServer
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock
from tests import unittest
@@ -201,36 +202,46 @@ class _TestImage:
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
-
+ test_image: ClassVar[_TestImage]
hijack_auth = True
user_id = "@test:user"
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.fetches = []
+ self.fetches: List[
+ Tuple[
+ "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]",
+ str,
+ str,
+ Optional[QueryParams],
+ ]
+ ] = []
def get_file(
destination: str,
path: str,
output_stream: BinaryIO,
- args: Optional[Dict[str, Union[str, List[str]]]] = None,
+ args: Optional[QueryParams] = None,
+ retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
- ) -> Deferred:
- """
- Returns tuple[int,dict,str,int] of file length, response headers,
- absolute URI, and response code.
- """
+ ignore_backoff: bool = False,
+ ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
+ """A mock for MatrixFederationHttpClient.get_file."""
- def write_to(r):
+ def write_to(
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+ ) -> Tuple[int, Dict[bytes, List[bytes]]]:
data, response = r
output_stream.write(data)
return response
- d = Deferred()
- d.addCallback(write_to)
+ d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
self.fetches.append((d, destination, path, args))
- return make_deferred_yieldable(d)
+ # Note that this callback changes the value held by d.
+ d_after_callback = d.addCallback(write_to)
+ return make_deferred_yieldable(d_after_callback)
+ # Mock out the homeserver's MatrixFederationHttpClient
client = Mock()
client.get_file = get_file
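The mock keeps `d_after_callback` separate because, in Twisted's type annotations, `addCallback` returns a `Deferred` re-parameterized over the callback's result (at runtime it is the same object, mutated in place, as the "changes the value held by d" comment notes). A small standalone illustration:

```python
from twisted.internet.defer import Deferred

d: "Deferred[int]" = Deferred()
# addCallback returns the same Deferred, but typed by the callback's result,
# so keep the returned reference when the result type changes.
d_str: "Deferred[str]" = d.addCallback(lambda n: str(n))
d.callback(41)
d_str.addCallback(print)  # prints "41"
```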
@@ -461,6 +472,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
# Synapse should regenerate missing thumbnails.
origin, media_id = self.media_id.split("/")
info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
+ assert info is not None
file_id = info["filesystem_id"]
thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
@@ -581,7 +593,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"thumbnail_method": method,
"thumbnail_type": self.test_image.content_type,
"thumbnail_length": 256,
- "filesystem_id": f"thumbnail1{self.test_image.extension}",
+ "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}",
},
{
"thumbnail_width": 32,
@@ -589,10 +601,10 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"thumbnail_method": method,
"thumbnail_type": self.test_image.content_type,
"thumbnail_length": 256,
- "filesystem_id": f"thumbnail2{self.test_image.extension}",
+ "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}",
},
],
- file_id=f"image{self.test_image.extension}",
+ file_id=f"image{self.test_image.extension.decode()}",
url_cache=None,
server_name=None,
)
@@ -637,6 +649,7 @@ class TestSpamCheckerLegacy:
self.config = config
self.api = api
+ @staticmethod
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config
@@ -748,7 +761,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
- ) -> Union[Codes, Literal["NOT_SPAM"]]:
+ ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write)
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 22f99c6ab1..3285f2433c 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -12,29 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import List, Optional
from unittest.mock import Mock, patch
from synapse._scripts.register_new_matrix_user import request_registration
+from synapse.types import JsonDict
from tests.unittest import TestCase
class RegisterTestCase(TestCase):
- def test_success(self):
+ def test_success(self) -> None:
"""
The script will fetch a nonce, and then generate a MAC with it, and then
post that MAC.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 200
r.json = lambda: {"nonce": "a"}
return r
- def post(url, json=None, verify=None):
+ def post(
+ url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+ ) -> Mock:
# Make sure we are sent the correct info
+ assert json is not None
self.assertEqual(json["username"], "user")
self.assertEqual(json["password"], "pass")
self.assertEqual(json["nonce"], "a")
@@ -70,12 +74,12 @@ class RegisterTestCase(TestCase):
# sys.exit shouldn't have been called.
self.assertEqual(err_code, [])
- def test_failure_nonce(self):
+ def test_failure_nonce(self) -> None:
"""
If the script fails to fetch a nonce, it throws an error and quits.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 404
r.reason = "Not Found"
@@ -107,20 +111,23 @@ class RegisterTestCase(TestCase):
self.assertIn("ERROR! Received 404 Not Found", out)
self.assertNotIn("Success!", out)
- def test_failure_post(self):
+ def test_failure_post(self) -> None:
"""
The script will fetch a nonce, and then if the final POST fails, will
report an error and quit.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 200
r.json = lambda: {"nonce": "a"}
return r
- def post(url, json=None, verify=None):
+ def post(
+ url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+ ) -> Mock:
# Make sure we are sent the correct info
+ assert json is not None
self.assertEqual(json["username"], "user")
self.assertEqual(json["password"], "pass")
self.assertEqual(json["nonce"], "a")
diff --git a/tests/server.py b/tests/server.py
index 237bcad8ba..5de9722766 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -22,20 +22,25 @@ import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
+ Any,
+ Awaitable,
Callable,
Dict,
Iterable,
List,
MutableMapping,
Optional,
+ Sequence,
Tuple,
Type,
+ TypeVar,
Union,
+ cast,
)
from unittest.mock import Mock
import attr
-from typing_extensions import Deque
+from typing_extensions import Deque, ParamSpec
from zope.interface import implementer
from twisted.internet import address, threads, udp
@@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IAddress,
+ IConnector,
IConsumer,
IHostnameResolver,
+ IProducer,
IProtocol,
IPullProducer,
IPushProducer,
@@ -54,6 +61,8 @@ from twisted.internet.interfaces import (
IResolverSimple,
ITransport,
)
+from twisted.internet.protocol import ClientFactory, DatagramProtocol
+from twisted.python import threadpool
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers
@@ -61,6 +70,7 @@ from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
+from synapse.config.homeserver import HomeServerConfig
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -88,6 +98,9 @@ from tests.utils import (
logger = logging.getLogger(__name__)
+R = TypeVar("R")
+P = ParamSpec("P")
+
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
@@ -98,12 +111,14 @@ class TimedOutException(Exception):
"""
-@implementer(IConsumer)
+@implementer(ITransport, IPushProducer, IConsumer)
@attr.s(auto_attribs=True)
class FakeChannel:
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
+
+ See twisted.web.http.HTTPChannel.
"""
site: Union[Site, "FakeSite"]
@@ -142,7 +157,7 @@ class FakeChannel:
Raises an exception if the request has not yet completed.
"""
- if not self.is_finished:
+ if not self.is_finished():
raise Exception("Request not yet completed")
return self.result["body"].decode("utf8")
@@ -165,27 +180,36 @@ class FakeChannel:
h.addRawHeader(*i)
return h
- def writeHeaders(self, version, code, reason, headers):
+ def writeHeaders(
+ self, version: bytes, code: bytes, reason: bytes, headers: Headers
+ ) -> None:
self.result["version"] = version
self.result["code"] = code
self.result["reason"] = reason
self.result["headers"] = headers
- def write(self, content: bytes) -> None:
- assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+ def write(self, data: bytes) -> None:
+ assert isinstance(data, bytes), "Should be bytes! " + repr(data)
if "body" not in self.result:
self.result["body"] = b""
- self.result["body"] += content
+ self.result["body"] += data
+
+ def writeSequence(self, data: Iterable[bytes]) -> None:
+ for x in data:
+ self.write(x)
+
+ def loseConnection(self) -> None:
+ self.unregisterProducer()
+ self.transport.loseConnection()
# Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
- def registerProducer( # type: ignore[override]
- self,
- producer: Union[IPullProducer, IPushProducer],
- streaming: bool,
- ) -> None:
- self._producer = producer
+ def registerProducer(self, producer: IProducer, streaming: bool) -> None:
+ # TODO This should ensure that the IProducer is an IPushProducer or
+ # IPullProducer, unfortunately twisted.protocols.basic.FileSender does
+ # implement those, but doesn't declare it.
+ self._producer = cast(Union[IPushProducer, IPullProducer], producer)
self.producerStreaming = streaming
def _produce() -> None:
@@ -202,6 +226,16 @@ class FakeChannel:
self._producer = None
+ def stopProducing(self) -> None:
+ if self._producer is not None:
+ self._producer.stopProducing()
+
+ def pauseProducing(self) -> None:
+ raise NotImplementedError()
+
+ def resumeProducing(self) -> None:
+ raise NotImplementedError()
+
def requestDone(self, _self: Request) -> None:
self.result["done"] = True
if isinstance(_self, SynapseRequest):
@@ -281,12 +315,12 @@ class FakeSite:
self.reactor = reactor
self.experimental_cors_msc3886 = experimental_cors_msc3886
- def getResourceFor(self, request):
+ def getResourceFor(self, request: Request) -> IResource:
return self._resource
def make_request(
- reactor,
+ reactor: MemoryReactorClock,
site: Union[Site, FakeSite],
method: Union[bytes, str],
path: Union[bytes, str],
@@ -409,19 +443,21 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
A MemoryReactorClock that supports callFromThread.
"""
- def __init__(self):
+ def __init__(self) -> None:
self.threadpool = ThreadPool(self)
self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
- self._udp = []
+ self._udp: List[udp.Port] = []
self.lookups: Dict[str, str] = {}
- self._thread_callbacks: Deque[Callable[[], None]] = deque()
+ self._thread_callbacks: Deque[Callable[..., R]] = deque()
lookups = self.lookups
@implementer(IResolverSimple)
class FakeResolver:
- def getHostByName(self, name, timeout=None):
+ def getHostByName(
+ self, name: str, timeout: Optional[Sequence[int]] = None
+ ) -> "Deferred[str]":
if name not in lookups:
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])
@@ -432,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError()
- def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
+ def listenUDP(
+ self,
+ port: int,
+ protocol: DatagramProtocol,
+ interface: str = "",
+ maxPacketSize: int = 8196,
+ ) -> udp.Port:
p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening()
self._udp.append(p)
return p
- def callFromThread(self, callback, *args, **kwargs):
+ def callFromThread(
+ self, callable: Callable[..., Any], *args: object, **kwargs: object
+ ) -> None:
"""
Make the callback fire in the next reactor iteration.
"""
- cb = lambda: callback(*args, **kwargs)
+ cb = lambda: callable(*args, **kwargs)
# it's not safe to call callLater() here, so we append the callback to a
# separate queue.
self._thread_callbacks.append(cb)
- def getThreadPool(self):
- return self.threadpool
+ def callInThread(
+ self, callable: Callable[..., Any], *args: object, **kwargs: object
+ ) -> None:
+ raise NotImplementedError()
- def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
+ def suggestThreadPoolSize(self, size: int) -> None:
+ raise NotImplementedError()
+
+ def getThreadPool(self) -> "threadpool.ThreadPool":
+ # Cast to match super-class.
+ return cast(threadpool.ThreadPool, self.threadpool)
+
+ def add_tcp_client_callback(
+ self, host: str, port: int, callback: Callable[[], None]
+ ) -> None:
"""Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`.
@@ -459,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
self._tcp_callbacks[(host, port)] = callback
- def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
+ def connectTCP(
+ self,
+ host: str,
+ port: int,
+ factory: ClientFactory,
+ timeout: float = 30,
+ bindAddress: Optional[Tuple[str, int]] = None,
+ ) -> IConnector:
"""Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP(
@@ -472,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return conn
- def advance(self, amount):
+ def advance(self, amount: float) -> None:
# first advance our reactor's time, and run any "callLater" callbacks that
# this makes ready
super().advance(amount)
@@ -500,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
class ThreadPool:
"""
Threadless thread pool.
+
+ See twisted.python.threadpool.ThreadPool
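+
+ A sketch of the expected behaviour (names assumed): work submitted via
+ callInThreadWithCallback is scheduled with reactor.callLater(0, ...),
+ so tests drive it by advancing the clock rather than via real threads:
+
+ pool = ThreadPool(reactor)
+ pool.callInThreadWithCallback(on_result, work_fn)
+ reactor.advance(0) # runs work_fn, then on_result(True, result)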
"""
- def __init__(self, reactor):
+ def __init__(self, reactor: IReactorTime):
self._reactor = reactor
- def start(self):
+ def start(self) -> None:
pass
- def stop(self):
+ def stop(self) -> None:
pass
- def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
- def _(res):
+ def callInThreadWithCallback(
+ self,
+ onResult: Callable[[bool, Union[Failure, R]], None],
+ function: Callable[P, R],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> "Deferred[None]":
+ def _(res: Any) -> None:
if isinstance(res, Failure):
onResult(False, res)
else:
onResult(True, res)
- d = Deferred()
+ d: "Deferred[None]" = Deferred()
d.addCallback(lambda x: function(*args, **kwargs))
d.addBoth(_)
self._reactor.callLater(0, d.callback, True)
@@ -535,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
for database in server.get_datastores().databases:
pool = database._db_pool
- def runWithConnection(func, *args, **kwargs):
+ def runWithConnection(
+ func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
@@ -545,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
**kwargs,
)
- def runInteraction(interaction, *args, **kwargs):
+ def runInteraction(
+ desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
pool._runInteraction,
- interaction,
+ desc,
+ func,
*args,
**kwargs,
)
- pool.runWithConnection = runWithConnection
- pool.runInteraction = runInteraction
+ pool.runWithConnection = runWithConnection # type: ignore[assignment]
+ pool.runInteraction = runInteraction # type: ignore[assignment]
# Replace the thread pool with a threadless 'thread' pool
- pool.threadpool = ThreadPool(clock._reactor)
+ pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment]
pool.running = True
# We've just changed the Databases to run DB transactions on the same
@@ -573,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
@implementer(ITransport)
-@attr.s(cmp=False)
+@attr.s(cmp=False, auto_attribs=True)
class FakeTransport:
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -588,48 +663,50 @@ class FakeTransport:
If you want bidirectional communication, you'll need two instances.
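+
+ A minimal sketch of bidirectional wiring (client and server being
+ hypothetical IProtocol instances, not part of this module):
+
+ client.makeConnection(FakeTransport(server, reactor))
+ server.makeConnection(FakeTransport(client, reactor))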
"""
- other = attr.ib()
+ other: IProtocol
"""The Protocol object which will receive any data written to this transport.
-
- :type: twisted.internet.interfaces.IProtocol
"""
- _reactor = attr.ib()
+ _reactor: IReactorTime
"""Test reactor
-
- :type: twisted.internet.interfaces.IReactorTime
"""
- _protocol = attr.ib(default=None)
+ _protocol: Optional[IProtocol] = None
"""The Protocol which is producing data for this transport. Optional, but if set
will get called back for connectionLost() notifications etc.
"""
- _peer_address: Optional[IAddress] = attr.ib(default=None)
+ _peer_address: IAddress = attr.Factory(
+ lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
+ )
"""The value to be returned by getPeer"""
- _host_address: Optional[IAddress] = attr.ib(default=None)
+ _host_address: IAddress = attr.Factory(
+ lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
+ )
"""The value to be returned by getHost"""
disconnecting = False
disconnected = False
connected = True
- buffer = attr.ib(default=b"")
- producer = attr.ib(default=None)
- autoflush = attr.ib(default=True)
+ buffer: bytes = b""
+ producer: Optional[IPushProducer] = None
+ autoflush: bool = True
- def getPeer(self) -> Optional[IAddress]:
+ def getPeer(self) -> IAddress:
return self._peer_address
- def getHost(self) -> Optional[IAddress]:
+ def getHost(self) -> IAddress:
return self._host_address
- def loseConnection(self, reason=None):
+ def loseConnection(self) -> None:
if not self.disconnecting:
- logger.info("FakeTransport: loseConnection(%s)", reason)
+ logger.info("FakeTransport: loseConnection()")
self.disconnecting = True
if self._protocol:
- self._protocol.connectionLost(reason)
+ self._protocol.connectionLost(
+ Failure(RuntimeError("FakeTransport.loseConnection()"))
+ )
# if we still have data to write, delay until that is done
if self.buffer:
@@ -640,38 +717,38 @@ class FakeTransport:
self.connected = False
self.disconnected = True
- def abortConnection(self):
+ def abortConnection(self) -> None:
logger.info("FakeTransport: abortConnection()")
if not self.disconnecting:
self.disconnecting = True
if self._protocol:
- self._protocol.connectionLost(None)
+ self._protocol.connectionLost(None) # type: ignore[arg-type]
self.disconnected = True
- def pauseProducing(self):
+ def pauseProducing(self) -> None:
if not self.producer:
return
self.producer.pauseProducing()
- def resumeProducing(self):
+ def resumeProducing(self) -> None:
if not self.producer:
return
self.producer.resumeProducing()
- def unregisterProducer(self):
+ def unregisterProducer(self) -> None:
if not self.producer:
return
self.producer = None
- def registerProducer(self, producer, streaming):
+ def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
self.producer = producer
self.producerStreaming = streaming
- def _produce():
+ def _produce() -> None:
if not self.producer:
# we've been unregistered
return
@@ -683,7 +760,7 @@ class FakeTransport:
if not streaming:
self._reactor.callLater(0.0, _produce)
- def write(self, byt):
+ def write(self, byt: bytes) -> None:
if self.disconnecting:
raise Exception("Writing to disconnecting FakeTransport")
@@ -695,11 +772,11 @@ class FakeTransport:
if self.autoflush:
self._reactor.callLater(0.0, self.flush)
- def writeSequence(self, seq):
+ def writeSequence(self, seq: Iterable[bytes]) -> None:
for x in seq:
self.write(x)
- def flush(self, maxbytes=None):
+ def flush(self, maxbytes: Optional[int] = None) -> None:
if not self.buffer:
# nothing to do. Don't write empty buffers: it upsets the
# TLSMemoryBIOProtocol
@@ -750,17 +827,17 @@ def connect_client(
class TestHomeServer(HomeServer):
- DATASTORE_CLASS = DataStore
+ DATASTORE_CLASS = DataStore # type: ignore[assignment]
def setup_test_homeserver(
- cleanup_func,
- name="test",
- config=None,
- reactor=None,
+ cleanup_func: Callable[[Callable[[], None]], None],
+ name: str = "test",
+ config: Optional[HomeServerConfig] = None,
+ reactor: Optional[ISynapseReactor] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
- **kwargs,
-):
+ **kwargs: Any,
+) -> HomeServer:
"""
Set up a homeserver suitable for running tests against. Keyword arguments
are passed to the HomeServer constructor.
@@ -775,13 +852,14 @@ def setup_test_homeserver(
HomeserverTestCase.
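+
+ A rough sketch of direct use (names assumed; tests normally go through
+ HomeserverTestCase.setup_test_homeserver, which wraps this):
+
+ hs = setup_test_homeserver(self.addCleanup, name="test", reactor=reactor)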
"""
if reactor is None:
- from twisted.internet import reactor
+ from twisted.internet import reactor as _reactor
+
+ reactor = cast(ISynapseReactor, _reactor)
if config is None:
config = default_config(name, parse=True)
config.caches.resize_all_caches()
- config.ldap_enabled = False
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
@@ -832,6 +910,8 @@ def setup_test_homeserver(
# Create the database before we actually try to connect to it, based on
# the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine):
+ import psycopg2.extensions
+
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
@@ -839,6 +919,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
+ assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
@@ -867,14 +948,15 @@ def setup_test_homeserver(
hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
- database = hs.get_datastores().databases[0]
+ database_pool = hs.get_datastores().databases[0]
# We need to do cleanup on PostgreSQL
- def cleanup():
+ def cleanup() -> None:
import psycopg2
+ import psycopg2.extensions
# Close all the db pools
- database._db_pool.close()
+ database_pool._db_pool.close()
dropped = False
@@ -886,6 +968,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
+ assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True
cur = db_conn.cursor()
@@ -918,23 +1001,23 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- async def hash(p):
+ async def hash(p: str) -> str:
return hashlib.md5(p.encode("utf8")).hexdigest()
- hs.get_auth_handler().hash = hash
+ hs.get_auth_handler().hash = hash # type: ignore[assignment]
- async def validate_hash(p, h):
+ async def validate_hash(p: str, h: str) -> bool:
return hashlib.md5(p.encode("utf8")).hexdigest() == h
- hs.get_auth_handler().validate_hash = validate_hash
+ hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)
# Load any configured modules into the homeserver
module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
+ for module, module_config in hs.config.modules.loaded_modules:
+ module(config=module_config, api=module_api)
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 58b399a043..6540ed53f1 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -14,8 +14,12 @@
import os
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -29,7 +33,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
tmpdir = self.mktemp()
os.mkdir(tmpdir)
@@ -53,15 +57,13 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
"room_name": "Server Notices",
}
- hs = self.setup_test_homeserver(config=config)
+ return self.setup_test_homeserver(config=config)
- return hs
-
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("bob", "abc123")
self.access_token = self.login("bob", "abc123")
- def test_get_sync_message(self):
+ def test_get_sync_message(self) -> None:
"""
When user consent server notices are enabled, a sync will cause a notice
to fire (in a room which the user is invited to). The notice contains
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index dadc6efcbf..d2bfa53eda 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -24,6 +24,8 @@ from synapse.server import HomeServer
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
+from synapse.server_notices.server_notices_sender import ServerNoticesSender
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -33,7 +35,7 @@ from tests.utils import default_config
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = default_config("test")
config.update(
@@ -57,14 +59,15 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.server_notices_sender = self.hs.get_server_notices_sender()
+ server_notices_sender = self.hs.get_server_notices_sender()
+ assert isinstance(server_notices_sender, ServerNoticesSender)
# Relying on [1] is far from ideal, but this test is the only case where
# the ResourceLimitsServerNotices class needs to be isolated;
# general code should never have a reason to do so.
- self._rlsn = self.server_notices_sender._server_notices[1]
- if not isinstance(self._rlsn, ResourceLimitsServerNotices):
- raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+ rlsn = list(server_notices_sender._server_notices)[1]
+ assert isinstance(rlsn, ResourceLimitsServerNotices)
+ self._rlsn = rlsn
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(1000)
@@ -86,39 +89,43 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment]
@override_config({"hs_disabled": True})
- def test_maybe_send_server_notice_disabled_hs(self):
+ def test_maybe_send_server_notice_disabled_hs(self) -> None:
"""If the HS is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
@override_config({"limit_usage_by_mau": False})
- def test_maybe_send_server_notice_to_user_flag_off(self):
+ def test_maybe_send_server_notice_to_user_flag_off(self) -> None:
"""If mau limiting is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
"""Test when user has blocked notice, but should have it removed"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
- self._rlsn._store.get_events = Mock(
+ self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
- self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
+ maybe_get_notice_room_for_user = (
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user
+ )
+ assert isinstance(maybe_get_notice_room_for_user, Mock)
+ maybe_get_notice_room_for_user.assert_called_once()
self._send_notice.assert_called_once()
- def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None:
"""
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
@@ -126,7 +133,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
- self._rlsn._store.get_events = Mock(
+ self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
)
@@ -134,11 +141,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
+ def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None:
"""
Test when user does not have blocked notice, but should have one
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
@@ -147,11 +154,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
# Would be better to check contents, but 2 calls == set blocking event
self.assertEqual(self._send_notice.call_count, 2)
- def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
+ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None:
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
@@ -159,12 +166,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
+ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None:
"""
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
self._rlsn._store.user_last_seen_monthly_active = Mock(
@@ -175,12 +182,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
@override_config({"mau_limit_alerting": False})
- def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
+ def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(
+ self,
+ ) -> None:
"""
Test that when the server is over the MAU limit and alerting is
suppressed, an alert message is not sent into the room
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
@@ -191,11 +200,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 0)
@override_config({"mau_limit_alerting": False})
- def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
+ def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None:
"""
Test that when a server is disabled, MAU limit alerting is ignored.
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
@@ -207,26 +216,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 2)
@override_config({"mau_limit_alerting": False})
- def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
+ def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(
+ self,
+ ) -> None:
"""
When the room is already in a blocked state, test that, when alerting
is suppressed, the room is returned to an unblocked state.
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
)
- self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
+ self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment]
return_value=make_awaitable((True, []))
)
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
- self._rlsn._store.get_events = Mock(
+ self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -242,7 +253,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> JsonDict:
c = super().default_config()
c["server_notices"] = {
"system_mxid_localpart": "server",
@@ -257,20 +268,22 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
- self.server_notices_sender = self.hs.get_server_notices_sender()
self.server_notices_manager = self.hs.get_server_notices_manager()
self.event_source = self.hs.get_event_sources()
+ server_notices_sender = self.hs.get_server_notices_sender()
+ assert isinstance(server_notices_sender, ServerNoticesSender)
+
# Relying on [1] is far from ideal, but this test is the only case where
# the ResourceLimitsServerNotices class needs to be isolated;
# general code should never have a reason to do so.
- self._rlsn = self.server_notices_sender._server_notices[1]
- if not isinstance(self._rlsn, ResourceLimitsServerNotices):
- raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+ rlsn = list(server_notices_sender._server_notices)[1]
+ assert isinstance(rlsn, ResourceLimitsServerNotices)
+ self._rlsn = rlsn
self.user_id = "@user_id:test"
- def test_server_notice_only_sent_once(self):
+ def test_server_notice_only_sent_once(self) -> None:
self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
self.store.user_last_seen_monthly_active = Mock(
@@ -306,7 +319,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.assertEqual(count, 1)
- def test_no_invite_without_notice(self):
+ def test_no_invite_without_notice(self) -> None:
"""Tests that a user doesn't get invited to a server notices room without a
server notice being sent.
@@ -328,7 +341,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
m.assert_called_once_with(user_id)
- def test_invite_with_notice(self):
+ def test_invite_with_notice(self) -> None:
"""Tests that, if the MAU limit is hit, the server notices user invites each user
to a room in which it has sent a notice.
"""
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 9f33afcca0..9606ecc43b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -120,6 +120,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# Persist the event which should invalidate or prefill the
# `have_seen_event` cache so we don't return stale values.
persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
self.get_success(
persistence.persist_event(
event,
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c070278db8..a10e5fa8b1 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -389,6 +389,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
"""
persist_events_store = self.hs.get_datastores().persist_events
+ assert persist_events_store is not None
for e in events:
e.internal_metadata.stream_ordering = self._next_stream_ordering
@@ -397,6 +398,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
def _persist(txn: LoggingTransaction) -> None:
# We need to persist the events to the events and state_events
# tables.
+ assert persist_events_store is not None
persist_events_store._store_event_txn(
txn,
[(e, EventContext(self.hs.get_storage_controllers())) for e in events],
@@ -540,7 +542,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
self.requester, events_and_context=[(event, context)]
)
)
- state1 = set(self.get_success(context.get_current_state_ids()).values())
+ state_ids1 = self.get_success(context.get_current_state_ids())
+ assert state_ids1 is not None
+ state1 = set(state_ids1.values())
event, context = self.get_success(
event_handler.create_event(
@@ -560,7 +564,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
self.requester, events_and_context=[(event, context)]
)
)
- state2 = set(self.get_success(context.get_current_state_ids()).values())
+ state_ids2 = self.get_success(context.get_current_state_ids())
+ assert state_ids2 is not None
+ state2 = set(state_ids2.values())
# Delete the chain cover info.
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 7fd3e01364..8fc7936ab0 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -54,6 +54,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
+ persist_events = hs.get_datastores().persist_events
+ assert persist_events is not None
+ self.persist_events = persist_events
def test_get_prev_events_for_room(self) -> None:
room_id = "@ROOM:local"
@@ -226,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -445,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
# Insert all events apart from 'B'
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -464,7 +467,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False},
)
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
)
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 05661a537d..e67dd0589d 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -40,7 +40,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler()
- self._persistence = self.hs.get_storage_controllers().persistence
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persistence = persistence
self._state_storage_controller = self.hs.get_storage_controllers().state
self.store = self.hs.get_datastores().main
@@ -374,7 +376,9 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler()
- self._persistence = self.hs.get_storage_controllers().persistence
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persistence = persistence
self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self) -> None:
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index aa4b5bd3b1..ba68171ad7 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -16,8 +16,6 @@ import signedjson.key
import signedjson.types
import unpaddedbase64
-from twisted.internet.defer import Deferred
-
from synapse.storage.keys import FetchKeyResult
import tests.unittest
@@ -44,20 +42,26 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:KEY_ID_2"
- d = store.store_server_verify_keys(
- "from_server",
- 10,
- [
- ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
- ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
- ],
+ self.get_success(
+ store.store_server_verify_keys(
+ "from_server",
+ 10,
+ [
+ ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
)
- self.get_success(d)
- d = store.get_server_verify_keys(
- [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
+ res = self.get_success(
+ store.get_server_verify_keys(
+ [
+ ("server1", key_id_1),
+ ("server1", key_id_2),
+ ("server1", "ed25519:key3"),
+ ]
+ )
)
- res = self.get_success(d)
self.assertEqual(len(res.keys()), 3)
res1 = res[("server1", key_id_1)]
@@ -82,18 +86,20 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"
- d = store.store_server_verify_keys(
- "from_server",
- 0,
- [
- ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
- ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
- ],
+ self.get_success(
+ store.store_server_verify_keys(
+ "from_server",
+ 0,
+ [
+ ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
)
- self.get_success(d)
- d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- res = self.get_success(d)
+ res = self.get_success(
+ store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+ )
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
@@ -105,9 +111,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit
- res = store.get_server_verify_keys([("srv1", key_id_1)])
- if isinstance(res, Deferred):
- res = self.successResultOf(res)
+ res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)]))
self.assertEqual(len(res.keys()), 1)
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
@@ -119,8 +123,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.get_success(d)
- d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- res = self.get_success(d)
+ res = self.get_success(
+ store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+ )
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 010cc74c31..d8f42c5d05 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -112,7 +112,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id, "m.room.create", ""
)
)
- self.assertIsNotNone(create_event)
+ assert create_event is not None
# Purge everything before this topological token
self.get_success(
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index d8d84152dc..12c17f1073 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -37,9 +37,9 @@ class ReceiptTestCase(HomeserverTestCase):
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()
- self.persist_event_storage_controller = (
- self.hs.get_storage_controllers().persistence
- )
+ persist_event_storage_controller = self.hs.get_storage_controllers().persistence
+ assert persist_event_storage_controller is not None
+ self.persist_event_storage_controller = persist_event_storage_controller
# Create a test user
self.ourUser = UserID.from_string(OUR_USER_ID)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index df4740f9d9..0100f7da14 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata
- event_1, context_1 = self.get_success(
+ event_1, unpersisted_context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
+ context_1 = self.get_success(unpersisted_context_1.persist(event_1))
+
self.get_success(self._persistence.persist_event(event_1, context_1))
- event_2, context_2 = self.get_success(
+ event_2, unpersisted_context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
+
+ context_2 = self.get_success(unpersisted_context_2.persist(event_2))
self.get_success(self._persistence.persist_event(event_2, context_2))
# fetch one of the redactions
@@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- redaction_event, context = self.get_success(
+ redaction_event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(redaction_event))
+
self.get_success(self._persistence.persist_event(redaction_event, context))
# Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 14d872514d..f183c38477 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -119,7 +119,6 @@ class EventSearchInsertionTest(HomeserverTestCase):
"content": {"msgtype": "m.text", "body": 2},
"room_id": room_id,
"sender": user_id,
- "depth": prev_event.depth + 1,
"prev_events": prev_event_ids,
"origin_server_ts": self.clock.time_msec(),
}
@@ -134,7 +133,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_state_map,
for_verification=False,
),
- depth=event_dict["depth"],
+ depth=prev_event.depth + 1,
)
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bad7f0bc60..f730b888f7 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context))
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index bc090ebce0..05dc4f64b8 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -16,7 +16,7 @@ from typing import List
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.filtering import Filter
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -128,7 +128,7 @@ class PaginationTestCase(HomeserverTestCase):
room_id=self.room_id,
from_key=self.from_token.room_key,
to_key=None,
- direction="f",
+ direction=Direction.FORWARDS,
limit=10,
event_filter=Filter(self.hs, filter),
)
diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
index ba53c22818..19da8a9b09 100644
--- a/tests/storage/test_unsafe_locale.py
+++ b/tests/storage/test_unsafe_locale.py
@@ -14,6 +14,7 @@
from unittest.mock import MagicMock, patch
from synapse.storage.database import make_conn
+from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IncorrectDatabaseSetup
from tests.unittest import HomeserverTestCase
@@ -38,6 +39,7 @@ class UnsafeLocaleTest(HomeserverTestCase):
def test_safe_locale(self) -> None:
database = self.hs.get_datastores().databases[0]
+ assert isinstance(database.engine, PostgresEngine)
db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
with db_conn.cursor() as txn:
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index f1ca523d23..2d169684cf 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -25,6 +25,11 @@ from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.background_updates import _BackgroundUpdateHandler
+from synapse.storage.databases.main import user_directory
+from synapse.storage.databases.main.user_directory import (
+ _parse_words_with_icu,
+ _parse_words_with_regex,
+)
from synapse.storage.roommember import ProfileInfo
from synapse.util import Clock
@@ -42,7 +47,7 @@ ALICE = "@alice:a"
BOB = "@bob:b"
BOBBY = "@bobby:a"
# The localpart isn't 'Bela' on purpose so we can test looking up display names.
-BELA = "@somenickname:a"
+BELA = "@somenickname:example.org"
class GetUserDirectoryTables:
@@ -423,6 +428,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
class UserDirectoryStoreTestCase(HomeserverTestCase):
+ use_icu = False
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
@@ -434,6 +441,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
+ self._restore_use_icu = user_directory.USE_ICU
+ user_directory.USE_ICU = self.use_icu
+
+ def tearDown(self) -> None:
+ user_directory.USE_ICU = self._restore_use_icu
+
def test_search_user_dir(self) -> None:
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
@@ -478,6 +491,26 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
{"user_id": BELA, "display_name": "Bela", "avatar_url": None},
)
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_start_of_user_id(self) -> None:
+ """Tests that a user can look up another user by searching for the start
+ of their user ID.
+ """
+ r = self.get_success(self.store.search_user_dir(ALICE, "somenickname:exa", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
+
+
+class UserDirectoryStoreTestCaseWithIcu(UserDirectoryStoreTestCase):
+ use_icu = True
+
+ if not icu:
+ skip = "Requires PyICU"
+
class UserDirectoryICUTestCase(HomeserverTestCase):
if not icu:
@@ -513,3 +546,31 @@ class UserDirectoryICUTestCase(HomeserverTestCase):
r["results"][0],
{"user_id": ALICE, "display_name": display_name, "avatar_url": None},
)
+
+ def test_icu_word_boundary_punctuation(self) -> None:
+ """
+ Tests the behaviour of punctuation with the ICU tokeniser.
+
+ Seems to depend on underlying version of ICU.
+ """
+
+ # Note: either tokenisation is fine, because Postgres actually splits
+ # words itself afterwards.
+ self.assertIn(
+ _parse_words_with_icu("lazy'fox jumped:over the.dog"),
+ (
+ # ICU 66 on Ubuntu 20.04
+ ["lazy'fox", "jumped", "over", "the", "dog"],
+ # ICU 70 on Ubuntu 22.04
+ ["lazy'fox", "jumped:over", "the.dog"],
+ ),
+ )
+
+ def test_regex_word_boundary_punctuation(self) -> None:
+ """
+ Tests the behaviour of punctuation with the non-ICU tokeniser.
+ """
+ self.assertEqual(
+ _parse_words_with_regex("lazy'fox jumped:over the.dog"),
+ ["lazy", "fox", "jumped", "over", "the", "dog"],
+ )
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 31546ea52b..a248f1d277 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -21,10 +21,10 @@ from . import unittest
class DistributorTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.dist = Distributor()
- def test_signal_dispatch(self):
+ def test_signal_dispatch(self) -> None:
self.dist.declare("alert")
observer = Mock()
@@ -33,7 +33,7 @@ class DistributorTestCase(unittest.TestCase):
self.dist.fire("alert", 1, 2, 3)
observer.assert_called_with(1, 2, 3)
- def test_signal_catch(self):
+ def test_signal_catch(self) -> None:
self.dist.declare("alarm")
observers = [Mock() for i in (1, 2)]
@@ -51,7 +51,7 @@ class DistributorTestCase(unittest.TestCase):
self.assertEqual(mock_logger.warning.call_count, 1)
self.assertIsInstance(mock_logger.warning.call_args[0][0], str)
- def test_signal_prereg(self):
+ def test_signal_prereg(self) -> None:
observer = Mock()
self.dist.observe("flare", observer)
@@ -60,8 +60,8 @@ class DistributorTestCase(unittest.TestCase):
observer.assert_called_with(4, 5)
- def test_signal_undeclared(self):
- def code():
+ def test_signal_undeclared(self) -> None:
+ def code() -> None:
self.dist.fire("notification")
self.assertRaises(KeyError, code)
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 0a7937f1cc..2860564afc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -31,13 +31,13 @@ from tests.test_utils import get_awaitable_result
class _StubEventSourceStore:
"""A stub implementation of the EventSourceStore"""
- def __init__(self):
+ def __init__(self) -> None:
self._store: Dict[str, EventBase] = {}
- def add_event(self, event: EventBase):
+ def add_event(self, event: EventBase) -> None:
self._store[event.event_id] = event
- def add_events(self, events: Iterable[EventBase]):
+ def add_events(self, events: Iterable[EventBase]) -> None:
for event in events:
self._store[event.event_id] = event
@@ -59,7 +59,7 @@ class _StubEventSourceStore:
class EventAuthTestCase(unittest.TestCase):
- def test_rejected_auth_events(self):
+ def test_rejected_auth_events(self) -> None:
"""
Events that refer to rejected events in their auth events are rejected
"""
@@ -109,7 +109,7 @@ class EventAuthTestCase(unittest.TestCase):
)
)
- def test_create_event_with_prev_events(self):
+ def test_create_event_with_prev_events(self) -> None:
"""A create event with prev_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -150,7 +150,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event)
)
- def test_duplicate_auth_events(self):
+ def test_duplicate_auth_events(self) -> None:
"""Events with duplicate auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -196,7 +196,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event2)
)
- def test_unexpected_auth_events(self):
+ def test_unexpected_auth_events(self) -> None:
"""Events with excess auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -236,7 +236,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event)
)
- def test_random_users_cannot_send_state_before_first_pl(self):
+ def test_random_users_cannot_send_state_before_first_pl(self) -> None:
"""
Check that, before the first PL lands, the creator is the only user
that can send a state event.
@@ -263,7 +263,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_state_default_level(self):
+ def test_state_default_level(self) -> None:
"""
Check that users above the state_default level can send state and
those below cannot
@@ -298,7 +298,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_alias_event(self):
+ def test_alias_event(self) -> None:
"""Alias events have special behavior up through room version 6."""
creator = "@creator:example.com"
other = "@other:example.com"
@@ -333,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_msc2432_alias_event(self):
+ def test_msc2432_alias_event(self) -> None:
"""After MSC2432, alias events have no special behavior."""
creator = "@creator:example.com"
other = "@other:example.com"
@@ -366,7 +366,9 @@ class EventAuthTestCase(unittest.TestCase):
)
@parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)])
- def test_notifications(self, room_version: RoomVersion, allow_modification: bool):
+ def test_notifications(
+ self, room_version: RoomVersion, allow_modification: bool
+ ) -> None:
"""
Notifications power levels get checked due to MSC2209.
"""
@@ -395,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase):
with self.assertRaises(AuthError):
event_auth.check_state_dependent_auth_rules(pl_event, auth_events)
- def test_join_rules_public(self):
+ def test_join_rules_public(self) -> None:
"""
Test joining a public room.
"""
@@ -460,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events.values(),
)
- def test_join_rules_invite(self):
+ def test_join_rules_invite(self) -> None:
"""
Test joining an invite only room.
"""
@@ -835,7 +837,7 @@ def _power_levels_event(
)
-def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
+def _alias_event(room_version: RoomVersion, sender: str, **kwargs: Any) -> EventBase:
data = {
"room_id": TEST_ROOM_ID,
**_maybe_get_event_id_dict_for_room_version(room_version),
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 80e5c590d8..82dfd88b99 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -12,53 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Collection, List, Optional, Union
from unittest.mock import Mock
-from twisted.internet.defer import succeed
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import FederationError
-from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import event_from_pdu_json
+from synapse.handlers.device import DeviceListUpdater
+from synapse.http.types import QueryParams
from synapse.logging.context import LoggingContext
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
-from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase):
- def setUp(self):
-
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock()
- self.reactor = ThreadedMemoryReactorClock()
- self.hs_clock = Clock(self.reactor)
- self.homeserver = setup_test_homeserver(
- self.addCleanup,
- federation_http_client=self.http_client,
- clock=self.hs_clock,
- reactor=self.reactor,
- )
+ return self.setup_test_homeserver(federation_http_client=self.http_client)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
user_id = UserID("us", "test")
our_user = create_requester(user_id)
- room_creator = self.homeserver.get_room_creation_handler()
+ room_creator = self.hs.get_room_creation_handler()
self.room_id = self.get_success(
room_creator.create_room(
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
)
)[0]["room_id"]
- self.store = self.homeserver.get_datastores().main
+ self.store = self.hs.get_datastores().main
# Figure out what the most recent event is
most_recent = self.get_success(
- self.homeserver.get_datastores().main.get_latest_event_ids_in_room(
- self.room_id
- )
+ self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)[0]
join_event = make_event_from_dict(
@@ -78,17 +73,23 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- self.handler = self.homeserver.get_federation_handler()
- federation_event_handler = self.homeserver.get_federation_event_handler()
+ self.handler = self.hs.get_federation_handler()
+ federation_event_handler = self.hs.get_federation_event_handler()
- async def _check_event_auth(origin, event, context):
+ async def _check_event_auth(
+ origin: Optional[str], event: EventBase, context: EventContext
+ ) -> None:
pass
- federation_event_handler._check_event_auth = _check_event_auth
- self.client = self.homeserver.get_federation_client()
- self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
- lambda dest, pdus, **k: succeed(pdus)
- )
+ federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment]
+ self.client = self.hs.get_federation_client()
+
+ async def _check_sigs_and_hash_for_pulled_events_and_fetch(
+ dest: str, pdus: Collection[EventBase], room_version: RoomVersion
+ ) -> List[EventBase]:
+ return list(pdus)
+
+ self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
# Send the join, it should return None (which is not an error)
self.assertEqual(
@@ -104,16 +105,25 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
"$join:test.serv",
)
- def test_cant_hide_direct_ancestors(self):
+ def test_cant_hide_direct_ancestors(self) -> None:
"""
If you send a message, you must be able to provide the direct
prev_events that said event references.
"""
- async def post_json(destination, path, data, headers=None, timeout=0):
+ async def post_json(
+ destination: str,
+ path: str,
+ data: Optional[JsonDict] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryParams] = None,
+ ) -> Union[JsonDict, list]:
# If it asks us for new missing events, give them NOTHING
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
return {"events": []}
+ return {}
self.http_client.post_json = post_json
@@ -138,7 +148,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- federation_event_handler = self.homeserver.get_federation_event_handler()
+ federation_event_handler = self.hs.get_federation_event_handler()
with LoggingContext("test-context"):
failure = self.get_failure(
federation_event_handler.on_receive_pdu("test.serv", lying_event),
@@ -158,7 +168,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
self.assertEqual(extrem[0], "$join:test.serv")
- def test_retry_device_list_resync(self):
+ def test_retry_device_list_resync(self) -> None:
"""Tests that device lists are marked as stale if they couldn't be synced, and
that stale device lists are retried periodically.
"""
@@ -171,24 +181,27 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# When this function is called, increment the number of resync attempts (only if
# we're querying devices for the right user ID), then raise a
# NotRetryingDestination error to fail the resync gracefully.
- def query_user_devices(destination, user_id):
+ def query_user_devices(
+ destination: str, user_id: str, timeout: int = 30000
+ ) -> JsonDict:
if user_id == remote_user_id:
self.resync_attempts += 1
raise NotRetryingDestination(0, 0, destination)
# Register the mock on the federation client.
- federation_client = self.homeserver.get_federation_client()
- federation_client.query_user_devices = Mock(side_effect=query_user_devices)
+ federation_client = self.hs.get_federation_client()
+ federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment]
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
- store = self.homeserver.get_datastores().main
+ store = self.hs.get_datastores().main
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
- device_list_updater = self.homeserver.get_device_handler().device_list_updater
+ device_list_updater = self.hs.get_device_handler().device_list_updater
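+        # Narrow for mypy: incoming_device_list_update is only defined on the
+        # full DeviceListUpdater, not on the worker variant.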
+ assert isinstance(device_list_updater, DeviceListUpdater)
self.get_success(
device_list_updater.incoming_device_list_update(
origin=remote_origin,
@@ -218,7 +231,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.reactor.advance(30)
self.assertEqual(self.resync_attempts, 2)
- def test_cross_signing_keys_retry(self):
+ def test_cross_signing_keys_retry(self) -> None:
"""Tests that resyncing a device list correctly processes cross-signing keys from
the remote server.
"""
@@ -227,8 +240,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
# Register mock device list retrieval on the federation client.
- federation_client = self.homeserver.get_federation_client()
- federation_client.query_user_devices = Mock(
+ federation_client = self.hs.get_federation_client()
+ federation_client.query_user_devices = Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"user_id": remote_user_id,
@@ -252,7 +265,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
# Resync the device list.
- device_handler = self.homeserver.get_device_handler()
+ device_handler = self.hs.get_device_handler()
self.get_success(
device_handler.device_list_updater.user_device_resync(remote_user_id),
)
@@ -261,16 +274,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
keys = self.get_success(
self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
)
- self.assertTrue(remote_user_id in keys)
+ self.assertIn(remote_user_id, keys)
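+        # The bulk query maps each user ID to an Optional result, so narrow
+        # the value before indexing into it.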
+ key = keys[remote_user_id]
+ assert key is not None
# Check that the master key is the one returned by the mock.
- master_key = keys[remote_user_id]["master"]
+ master_key = key["master"]
self.assertEqual(len(master_key["keys"]), 1)
self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
self.assertTrue(remote_master_key in master_key["keys"].values())
# Check that the self-signing key is the one returned by the mock.
- self_signing_key = keys[remote_user_id]["self_signing"]
+ self_signing_key = key["self_signing"]
self.assertEqual(len(self_signing_key["keys"]), 1)
self.assertTrue(
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
@@ -279,7 +294,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
class StripUnsignedFromEventsTestCase(unittest.TestCase):
- def test_strip_unauthorized_unsigned_values(self):
+ def test_strip_unauthorized_unsigned_values(self) -> None:
event1 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
@@ -296,7 +311,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
# Make sure unauthorized fields are stripped from unsigned
self.assertNotIn("more warez", filtered_event.unsigned)
- def test_strip_event_maintains_allowed_fields(self):
+ def test_strip_event_maintains_allowed_fields(self) -> None:
event2 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
@@ -323,7 +338,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
self.assertIn("invite_room_state", filtered_event2.unsigned)
self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
- def test_strip_event_removes_fields_based_on_event_type(self):
+ def test_strip_event_removes_fields_based_on_event_type(self) -> None:
event3 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
diff --git a/tests/test_mau.py b/tests/test_mau.py
index f14fcb7db9..4e7665a22b 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -14,12 +14,17 @@
"""Tests REST events for /rooms paths."""
-from typing import List
+from typing import List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
from synapse.rest.client import register, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
@@ -30,7 +35,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
servlets = [register.register_servlets, sync.register_servlets]
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = default_config("test")
config.update(
@@ -53,10 +58,12 @@ class TestMauLimit(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
- def test_simple_deny_mau(self):
+ def test_simple_deny_mau(self) -> None:
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -75,7 +82,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def test_as_ignores_mau(self):
+ def test_as_ignores_mau(self) -> None:
"""Test that application services can still create users when the MAU
limit has been reached. This only works when application service
user ip tracking is disabled.
@@ -113,7 +120,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.create_user("as_kermit4", token=as_token, appservice=True)
- def test_allowed_after_a_month_mau(self):
+ def test_allowed_after_a_month_mau(self) -> None:
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -132,7 +139,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.do_sync_for_user(token3)
@override_config({"mau_trial_days": 1})
- def test_trial_delay(self):
+ def test_trial_delay(self) -> None:
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -165,7 +172,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@override_config({"mau_trial_days": 1})
- def test_trial_users_cant_come_back(self):
+ def test_trial_users_cant_come_back(self) -> None:
self.hs.config.server.mau_trial_days = 1
# We should be able to register more than the limit initially
@@ -216,7 +223,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# max_mau_value should not matter
{"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
)
- def test_tracked_but_not_limited(self):
+ def test_tracked_but_not_limited(self) -> None:
# Simply being able to create 2 users indicates that the
# limit was not reached.
token1 = self.create_user("kermit1")
@@ -236,10 +243,10 @@ class TestMauLimit(unittest.HomeserverTestCase):
"mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2},
}
)
- def test_as_trial_days(self):
+ def test_as_trial_days(self) -> None:
user_tokens: List[str] = []
- def advance_time_and_sync():
+ def advance_time_and_sync() -> None:
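+            # Advance just over 24 hours, then sync everyone to bump MAU.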
self.reactor.advance(24 * 60 * 61)
for token in user_tokens:
self.do_sync_for_user(token)
@@ -300,7 +307,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
},
)
- def create_user(self, localpart, token=None, appservice=False):
+ def create_user(
+ self, localpart: str, token: Optional[str] = None, appservice: bool = False
+ ) -> str:
request_data = {
"username": localpart,
"password": "monkey",
@@ -326,7 +335,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
return access_token
- def do_sync_for_user(self, token):
+ def do_sync_for_user(self, token: str) -> None:
channel = self.make_request("GET", "/sync", access_token=token)
if channel.code != 200:
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index cc1a98f1c4..3f899b0d91 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -33,7 +33,7 @@ class PhoneHomeStatsTestCase(HomeserverTestCase):
If time doesn't move, don't error out.
"""
past_stats = [
- (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))
+ (int(self.hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
]
stats: JsonDict = {}
self.get_success(phone_stats_home(self.hs, stats, past_stats))
diff --git a/tests/test_rust.py b/tests/test_rust.py
index 55d8b6b28c..67443b6280 100644
--- a/tests/test_rust.py
+++ b/tests/test_rust.py
@@ -6,6 +6,6 @@ from tests import unittest
class RustTestCase(unittest.TestCase):
"""Basic tests to ensure that we can call into Rust code."""
- def test_basic(self):
+ def test_basic(self) -> None:
result = sum_as_string(1, 2)
self.assertEqual("3", result)
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index d04bcae0fa..5cd698147e 100644
--- a/tests/test_test_utils.py
+++ b/tests/test_test_utils.py
@@ -17,25 +17,25 @@ from tests.utils import MockClock
class MockClockTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock = MockClock()
- def test_advance_time(self):
+ def test_advance_time(self) -> None:
start_time = self.clock.time()
self.clock.advance_time(20)
self.assertEqual(20, self.clock.time() - start_time)
- def test_later(self):
+ def test_later(self) -> None:
invoked = [0, 0]
- def _cb0():
+ def _cb0() -> None:
invoked[0] = 1
self.clock.call_later(10, _cb0)
- def _cb1():
+ def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
@@ -51,15 +51,15 @@ class MockClockTestCase(unittest.TestCase):
self.assertTrue(invoked[1])
- def test_cancel_later(self):
+ def test_cancel_later(self) -> None:
invoked = [0, 0]
- def _cb0():
+ def _cb0() -> None:
invoked[0] = 1
t0 = self.clock.call_later(10, _cb0)
- def _cb1():
+ def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
diff --git a/tests/test_types.py b/tests/test_types.py
index 1111169384..c491cc9a96 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -43,34 +43,34 @@ class IsMineIDTests(unittest.HomeserverTestCase):
class UserIDTestCase(unittest.HomeserverTestCase):
- def test_parse(self):
+ def test_parse(self) -> None:
user = UserID.from_string("@1234abcd:test")
self.assertEqual("1234abcd", user.localpart)
self.assertEqual("test", user.domain)
self.assertEqual(True, self.hs.is_mine(user))
- def test_parse_rejects_empty_id(self):
+ def test_parse_rejects_empty_id(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("")
- def test_parse_rejects_missing_sigil(self):
+ def test_parse_rejects_missing_sigil(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("alice:example.com")
- def test_parse_rejects_missing_separator(self):
+ def test_parse_rejects_missing_separator(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("@alice.example.com")
- def test_validation_rejects_missing_domain(self):
+ def test_validation_rejects_missing_domain(self) -> None:
self.assertFalse(UserID.is_valid("@alice:"))
- def test_build(self):
+ def test_build(self) -> None:
user = UserID("5678efgh", "my.domain")
self.assertEqual(user.to_string(), "@5678efgh:my.domain")
- def test_compare(self):
+ def test_compare(self) -> None:
userA = UserID.from_string("@userA:my.domain")
userAagain = UserID.from_string("@userA:my.domain")
userB = UserID.from_string("@userB:my.domain")
@@ -80,43 +80,43 @@ class UserIDTestCase(unittest.HomeserverTestCase):
class RoomAliasTestCase(unittest.HomeserverTestCase):
- def test_parse(self):
+ def test_parse(self) -> None:
room = RoomAlias.from_string("#channel:test")
self.assertEqual("channel", room.localpart)
self.assertEqual("test", room.domain)
self.assertEqual(True, self.hs.is_mine(room))
- def test_build(self):
+ def test_build(self) -> None:
room = RoomAlias("channel", "my.domain")
self.assertEqual(room.to_string(), "#channel:my.domain")
- def test_validate(self):
+ def test_validate(self) -> None:
id_string = "#test:domain,test"
self.assertFalse(RoomAlias.is_valid(id_string))
class MapUsernameTestCase(unittest.TestCase):
- def testPassThrough(self):
+    def test_pass_through(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
- def testUpperCase(self):
+ def test_upper_case(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
self.assertEqual(
map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
"t_e_s_t__1234",
)
- def testSymbols(self):
+ def test_symbols(self) -> None:
self.assertEqual(
map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
)
- def testLeadingUnderscore(self):
+ def test_leading_underscore(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
- def testNonAscii(self):
+ def test_non_ascii(self) -> None:
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e62ebcc6a5..e5dae670a7 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -20,12 +20,13 @@ import sys
import warnings
from asyncio import Future
from binascii import unhexlify
-from typing import Awaitable, Callable, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
from unittest.mock import Mock
import attr
import zope.interface
+from twisted.internet.interfaces import IProtocol
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.web.http import RESPONSES
@@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse
from synapse.types import JsonDict
+if TYPE_CHECKING:
+ from sys import UnraisableHookArgs
+
TV = TypeVar("TV")
@@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]:
unraisable_exceptions = []
orig_unraisablehook = sys.unraisablehook
- def unraisablehook(unraisable):
+ def unraisablehook(unraisable: "UnraisableHookArgs") -> None:
unraisable_exceptions.append(unraisable.exc_value)
- def cleanup():
+ def cleanup() -> None:
"""
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
"""
sys.unraisablehook = orig_unraisablehook
if unraisable_exceptions:
- raise unraisable_exceptions.pop()
+ exc = unraisable_exceptions.pop()
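+            # exc_value is typed as Optional, so narrow it for mypy before
+            # re-raising.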
+ assert exc is not None
+ raise exc
sys.unraisablehook = unraisablehook
return cleanup
-def simple_async_mock(return_value=None, raises=None) -> Mock:
+def simple_async_mock(
+ return_value: Optional[TV] = None, raises: Optional[Exception] = None
+) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
+ async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
if raises:
raise raises
return return_value
@@ -125,14 +133,14 @@ class FakeResponse: # type: ignore[misc]
headers: Headers = attr.Factory(Headers)
@property
- def phrase(self):
+ def phrase(self) -> bytes:
return RESPONSES.get(self.code, b"Unknown Status")
@property
- def length(self):
+ def length(self) -> int:
return len(self.body)
- def deliverBody(self, protocol):
+ def deliverBody(self, protocol: IProtocol) -> None:
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8027c7a856..a6330ed840 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
@@ -32,7 +32,7 @@ async def inject_member_event(
membership: str,
target: Optional[str] = None,
extra_content: Optional[dict] = None,
- **kwargs,
+ **kwargs: Any,
) -> EventBase:
"""Inject a membership event into a room."""
if target is None:
@@ -57,7 +57,7 @@ async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs,
+ **kwargs: Any,
) -> EventBase:
"""Inject a generic event into a room
@@ -82,7 +82,7 @@ async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs,
+ **kwargs: Any,
) -> Tuple[EventBase, EventContext]:
if room_version is None:
room_version = await hs.get_datastores().main.get_room_version_id(
@@ -92,8 +92,13 @@ async def create_event(
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
- event, context = await hs.get_event_creation_handler().create_new_client_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
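+    # create_new_client_event now returns an unpersisted context; persist it
+    # to obtain the EventContext this helper promises its callers.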
+ context = await unpersisted_context.persist(event)
+
return event, context
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index e878af5f12..189c697efb 100644
--- a/tests/test_utils/html_parsers.py
+++ b/tests/test_utils/html_parsers.py
@@ -13,13 +13,13 @@
# limitations under the License.
from html.parser import HTMLParser
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, NoReturn, Optional, Tuple
class TestHtmlParser(HTMLParser):
"""A generic HTML page parser which extracts useful things from the HTML"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
# a list of links found in the doc
@@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser):
assert input_name
self.hiddens[input_name] = attr_dict["value"]
- def error(_, message):
+ def error(self, message: str) -> NoReturn:
raise AssertionError(message)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 304c7b98c5..b522163a34 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler):
tx_log = twisted.logger.Logger()
- def emit(self, record):
+ def emit(self, record: logging.LogRecord) -> None:
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit(
@@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler):
)
-def setup_logging():
+def setup_logging() -> None:
"""Configure the python logging appropriately for the tests.
(Logs will end up in _trial_temp.)
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 1461d23ee8..d555b24255 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -14,7 +14,7 @@
import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
from urllib.parse import parse_qs
@@ -77,14 +77,14 @@ class FakeOidcServer:
self._id_token_overrides: Dict[str, Any] = {}
- def reset_mocks(self):
+ def reset_mocks(self) -> None:
self.request.reset_mock()
self.get_jwks_handler.reset_mock()
self.get_metadata_handler.reset_mock()
self.get_userinfo_handler.reset_mock()
self.post_token_handler.reset_mock()
- def patch_homeserver(self, hs: HomeServer):
+ def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
"""Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
This patch should be used whenever the HS is expected to perform request to the
@@ -188,7 +188,7 @@ class FakeOidcServer:
return self._sign(logout_token)
- def id_token_override(self, overrides: dict):
+ def id_token_override(self, overrides: dict) -> ContextManager[dict]:
"""Temporarily patch the ID token generated by the token endpoint."""
return patch.object(self, "_id_token_overrides", overrides)
@@ -247,7 +247,7 @@ class FakeOidcServer:
metadata: bool = False,
token: bool = False,
userinfo: bool = False,
- ):
+ ) -> ContextManager[Dict[str, Mock]]:
"""A context which makes a set of endpoints return a 500 error.
Args:
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index d0b9ad5454..2801a950a8 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -35,6 +35,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self._storage_controllers = self.hs.get_storage_controllers()
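+        # The persistence controller is Optional (absent on some workers), so
+        # assert it exists once and keep a non-Optional handle for the helpers.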
+ assert self._storage_controllers.persistence is not None
+ self._persistence = self._storage_controllers.persistence
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@@ -175,12 +177,11 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ context = self.get_success(unpersisted_context.persist(event))
+ self.get_success(self._persistence.persist_event(event, context))
return event
def _inject_room_member(
@@ -202,13 +203,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ self.get_success(self._persistence.persist_event(event, context))
return event
def _inject_message(
@@ -226,13 +226,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ self.get_success(self._persistence.persist_event(event, context))
return event
def _inject_outlier(self) -> EventBase:
@@ -250,7 +249,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True
self.get_success(
- self._storage_controllers.persistence.persist_event(
+ self._persistence.persist_event(
event, EventContext.for_outlier(self._storage_controllers)
)
)
@@ -258,7 +257,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
- def test_out_of_band_invite_rejection(self):
+ def test_out_of_band_invite_rejection(self) -> None:
# this is where we have received an invite event over federation, and then
# rejected it.
invite_pdu = {
diff --git a/tests/unittest.py b/tests/unittest.py
index fa92dd94eb..b21e7f1221 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
-from twisted.test.proto_helpers import MemoryReactor
+from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
from twisted.trial import unittest
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -82,7 +82,7 @@ from tests.server import (
)
from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
-from tests.utils import default_config, setupdb
+from tests.utils import checked_cast, default_config, setupdb
setupdb()
setup_logging()
@@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.utils import RestHelper
- self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
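+        # RestHelper now takes the reactor explicitly; checked_cast narrows it
+        # from the interface type to MemoryReactorClock via runtime isinstance.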
+ self.helper = RestHelper(
+ self.hs,
+ checked_cast(MemoryReactorClock, self.hs.get_reactor()),
+ self.site,
+ getattr(self, "user_id", None),
+ )
if hasattr(self, "user_id"):
if self.hijack_auth:
@@ -315,7 +320,7 @@ class HomeserverTestCase(TestCase):
# This has to be a function and not just a Mock, because
# `self.helper.auth_user_id` is temporarily reassigned in some tests
- async def get_requester(*args, **kwargs) -> Requester:
+ async def get_requester(*args: Any, **kwargs: Any) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
user_id=UserID.from_string(self.helper.auth_user_id),
@@ -361,7 +366,9 @@ class HomeserverTestCase(TestCase):
store.db_pool.updates.do_next_background_update(False), by=0.1
)
- def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock):
+ def make_homeserver(
+ self, reactor: ThreadedMemoryReactorClock, clock: Clock
+ ) -> HomeServer:
"""
Make and return a homeserver.
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9529ee53c8..5f8f4e76b5 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -54,6 +54,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
self.pump()
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ assert new_timings is not None
self.assertEqual(new_timings.failure_ts, failure_ts)
self.assertEqual(new_timings.retry_last_ts, failure_ts)
self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
@@ -82,6 +83,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
self.pump()
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ assert new_timings is not None
self.assertEqual(new_timings.failure_ts, failure_ts)
self.assertEqual(new_timings.retry_last_ts, retry_ts)
self.assertGreaterEqual(
diff --git a/tests/utils.py b/tests/utils.py
index d76bf9716a..a0ac11bc5c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,7 +15,7 @@
import atexit
import os
-from typing import Any, Callable, Dict, List, Tuple, Union, overload
+from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload
import attr
from typing_extensions import Literal, ParamSpec
@@ -335,6 +335,33 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
},
)
- event, context = await event_creation_handler.create_new_client_event(builder)
+ event, unpersisted_context = await event_creation_handler.create_new_client_event(
+ builder
+ )
+ context = await unpersisted_context.persist(event)
await persistence_store.persist_event(event, context)
+
+
+T = TypeVar("T")
+
+
+def checked_cast(type: Type[T], x: object) -> T:
+ """A version of typing.cast that is checked at runtime.
+
+ We have our own function for this for two reasons:
+
+ 1. typing.cast itself is deliberately a no-op at runtime, see
+ https://docs.python.org/3/library/typing.html#typing.cast
+    2. To help work around a mypy-zope bug https://github.com/Shoobx/mypy-zope/issues/91
+ where mypy would erroneously consider `isinstance(x, type)` to be false in all
+ circumstances.
+
+ For this to make sense, `T` needs to be something that `isinstance` can check; see
+ https://docs.python.org/3/library/functions.html?highlight=isinstance#isinstance
+ https://docs.python.org/3/glossary.html#term-abstract-base-class
+ https://docs.python.org/3/library/typing.html#typing.runtime_checkable
+ for more details.
+ """
+ assert isinstance(x, type)
+ return x
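+
+
+# A minimal usage sketch, mirroring tests/unittest.py above:
+#
+#     reactor = checked_cast(MemoryReactorClock, hs.get_reactor())
+#
+# Unlike typing.cast, a wrong type fails loudly with an AssertionError here
+# instead of letting a mistyped value propagate through the test.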