diff --git a/CHANGES.md b/CHANGES.md index 5de819ea1e..5d4e80499e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,121 @@ +Synapse 1.21.0rc2 (2020-10-02) +============================== + +Features +-------- + +- Convert additional templates from inline HTML to Jinja2 templates. ([\#8444](https://github.com/matrix-org/synapse/issues/8444)) + +Bugfixes +-------- + +- Fix a regression in v1.21.0rc1 which broke thumbnails of remote media. ([\#8438](https://github.com/matrix-org/synapse/issues/8438)) +- Do not expose the experimental `uk.half-shot.msc2778.login.application_service` flow in the login API, which caused a compatibility problem with Element iOS. ([\#8440](https://github.com/matrix-org/synapse/issues/8440)) +- Fix malformed log line in new federation "catch up" logic. ([\#8442](https://github.com/matrix-org/synapse/issues/8442)) +- Fix DB query on startup for negative streams which caused long start up times. Introduced in [\#8374](https://github.com/matrix-org/synapse/issues/8374). ([\#8447](https://github.com/matrix-org/synapse/issues/8447)) + + +Synapse 1.21.0rc1 (2020-10-01) +============================== + +Features +-------- + +- Require the user to confirm that their password should be reset after clicking the email confirmation link. ([\#8004](https://github.com/matrix-org/synapse/issues/8004)) +- Add an admin API `GET /_synapse/admin/v1/event_reports` to read entries of table `event_reports`. Contributed by @dklimpel. ([\#8217](https://github.com/matrix-org/synapse/issues/8217)) +- Consolidate the SSO error template across all configuration. ([\#8248](https://github.com/matrix-org/synapse/issues/8248), [\#8405](https://github.com/matrix-org/synapse/issues/8405)) +- Add a configuration option to specify a whitelist of domains that a user can be redirected to after validating their email or phone number. ([\#8275](https://github.com/matrix-org/synapse/issues/8275), [\#8417](https://github.com/matrix-org/synapse/issues/8417)) +- Add experimental support for sharding event persister. ([\#8294](https://github.com/matrix-org/synapse/issues/8294), [\#8387](https://github.com/matrix-org/synapse/issues/8387), [\#8396](https://github.com/matrix-org/synapse/issues/8396), [\#8419](https://github.com/matrix-org/synapse/issues/8419)) +- Add the room topic and avatar to the room details admin API. ([\#8305](https://github.com/matrix-org/synapse/issues/8305)) +- Add an admin API for querying rooms where a user is a member. Contributed by @dklimpel. ([\#8306](https://github.com/matrix-org/synapse/issues/8306)) +- Add `uk.half-shot.msc2778.login.application_service` login type to allow appservices to login. ([\#8320](https://github.com/matrix-org/synapse/issues/8320)) +- Add a configuration option that allows existing users to log in with OpenID Connect. Contributed by @BBBSnowball and @OmmyZhang. ([\#8345](https://github.com/matrix-org/synapse/issues/8345)) +- Add prometheus metrics for replication requests. ([\#8406](https://github.com/matrix-org/synapse/issues/8406)) +- Support passing additional single sign-on parameters to the client. ([\#8413](https://github.com/matrix-org/synapse/issues/8413)) +- Add experimental reporting of metrics on expensive rooms for state-resolution. ([\#8420](https://github.com/matrix-org/synapse/issues/8420)) +- Add experimental prometheus metric to track numbers of "large" rooms for state resolutiom. ([\#8425](https://github.com/matrix-org/synapse/issues/8425)) +- Add prometheus metrics to track federation delays. ([\#8430](https://github.com/matrix-org/synapse/issues/8430)) + + +Bugfixes +-------- + +- Fix a bug in the media repository where remote thumbnails with the same size but different crop methods would overwrite each other. Contributed by @deepbluev7. ([\#7124](https://github.com/matrix-org/synapse/issues/7124)) +- Fix inconsistent handling of non-existent push rules, and stop tracking the `enabled` state of removed push rules. ([\#7796](https://github.com/matrix-org/synapse/issues/7796)) +- Fix a longstanding bug when storing a media file with an empty `upload_name`. ([\#7905](https://github.com/matrix-org/synapse/issues/7905)) +- Fix messages not being sent over federation until an event is sent into the same room. ([\#8230](https://github.com/matrix-org/synapse/issues/8230), [\#8247](https://github.com/matrix-org/synapse/issues/8247), [\#8258](https://github.com/matrix-org/synapse/issues/8258), [\#8272](https://github.com/matrix-org/synapse/issues/8272), [\#8322](https://github.com/matrix-org/synapse/issues/8322)) +- Fix a longstanding bug where files that could not be thumbnailed would result in an Internal Server Error. ([\#8236](https://github.com/matrix-org/synapse/issues/8236), [\#8435](https://github.com/matrix-org/synapse/issues/8435)) +- Upgrade minimum version of `canonicaljson` to version 1.4.0, to fix an unicode encoding issue. ([\#8262](https://github.com/matrix-org/synapse/issues/8262)) +- Fix longstanding bug which could lead to incomplete database upgrades on SQLite. ([\#8265](https://github.com/matrix-org/synapse/issues/8265)) +- Fix stack overflow when stderr is redirected to the logging system, and the logging system encounters an error. ([\#8268](https://github.com/matrix-org/synapse/issues/8268)) +- Fix a bug which cause the logging system to report errors, if `DEBUG` was enabled and no `context` filter was applied. ([\#8278](https://github.com/matrix-org/synapse/issues/8278)) +- Fix edge case where push could get delayed for a user until a later event was pushed. ([\#8287](https://github.com/matrix-org/synapse/issues/8287)) +- Fix fetching malformed events from remote servers. ([\#8324](https://github.com/matrix-org/synapse/issues/8324)) +- Fix `UnboundLocalError` from occuring when appservices send a malformed register request. ([\#8329](https://github.com/matrix-org/synapse/issues/8329)) +- Don't send push notifications to expired user accounts. ([\#8353](https://github.com/matrix-org/synapse/issues/8353)) +- Fix a regression in v1.19.0 with reactivating users through the admin API. ([\#8362](https://github.com/matrix-org/synapse/issues/8362)) +- Fix a bug where during device registration the length of the device name wasn't limited. ([\#8364](https://github.com/matrix-org/synapse/issues/8364)) +- Include `guest_access` in the fields that are checked for null bytes when updating `room_stats_state`. Broke in v1.7.2. ([\#8373](https://github.com/matrix-org/synapse/issues/8373)) +- Fix theoretical race condition where events are not sent down `/sync` if the synchrotron worker is restarted without restarting other workers. ([\#8374](https://github.com/matrix-org/synapse/issues/8374)) +- Fix a bug which could cause errors in rooms with malformed membership events, on servers using sqlite. ([\#8385](https://github.com/matrix-org/synapse/issues/8385)) +- Fix "Re-starting finished log context" warning when receiving an event we already had over federation. ([\#8398](https://github.com/matrix-org/synapse/issues/8398)) +- Fix incorrect handling of timeouts on outgoing HTTP requests. ([\#8400](https://github.com/matrix-org/synapse/issues/8400)) +- Fix a regression in v1.20.0 in the `synapse_port_db` script regarding the `ui_auth_sessions_ips` table. ([\#8410](https://github.com/matrix-org/synapse/issues/8410)) +- Remove unnecessary 3PID registration check when resetting password via an email address. Bug introduced in v0.34.0rc2. ([\#8414](https://github.com/matrix-org/synapse/issues/8414)) + + +Improved Documentation +---------------------- + +- Add `/_synapse/client` to the reverse proxy documentation. ([\#8227](https://github.com/matrix-org/synapse/issues/8227)) +- Add note to the reverse proxy settings documentation about disabling Apache's mod_security2. Contributed by Julian Fietkau (@jfietkau). ([\#8375](https://github.com/matrix-org/synapse/issues/8375)) +- Improve description of `server_name` config option in `homserver.yaml`. ([\#8415](https://github.com/matrix-org/synapse/issues/8415)) + + +Deprecations and Removals +------------------------- + +- Drop support for `prometheus_client` older than 0.4.0. ([\#8426](https://github.com/matrix-org/synapse/issues/8426)) + + +Internal Changes +---------------- + +- Fix tests on distros which disable TLSv1.0. Contributed by @danc86. ([\#8208](https://github.com/matrix-org/synapse/issues/8208)) +- Simplify the distributor code to avoid unnecessary work. ([\#8216](https://github.com/matrix-org/synapse/issues/8216)) +- Remove the `populate_stats_process_rooms_2` background job and restore functionality to `populate_stats_process_rooms`. ([\#8243](https://github.com/matrix-org/synapse/issues/8243)) +- Clean up type hints for `PaginationConfig`. ([\#8250](https://github.com/matrix-org/synapse/issues/8250), [\#8282](https://github.com/matrix-org/synapse/issues/8282)) +- Track the latest event for every destination and room for catch-up after federation outage. ([\#8256](https://github.com/matrix-org/synapse/issues/8256)) +- Fix non-user visible bug in implementation of `MultiWriterIdGenerator.get_current_token_for_writer`. ([\#8257](https://github.com/matrix-org/synapse/issues/8257)) +- Switch to the JSON implementation from the standard library. ([\#8259](https://github.com/matrix-org/synapse/issues/8259)) +- Add type hints to `synapse.util.async_helpers`. ([\#8260](https://github.com/matrix-org/synapse/issues/8260)) +- Simplify tests that mock asynchronous functions. ([\#8261](https://github.com/matrix-org/synapse/issues/8261)) +- Add type hints to `StreamToken` and `RoomStreamToken` classes. ([\#8279](https://github.com/matrix-org/synapse/issues/8279)) +- Change `StreamToken.room_key` to be a `RoomStreamToken` instance. ([\#8281](https://github.com/matrix-org/synapse/issues/8281)) +- Refactor notifier code to correctly use the max event stream position. ([\#8288](https://github.com/matrix-org/synapse/issues/8288)) +- Use slotted classes where possible. ([\#8296](https://github.com/matrix-org/synapse/issues/8296)) +- Support testing the local Synapse checkout against the [Complement homeserver test suite](https://github.com/matrix-org/complement/). ([\#8317](https://github.com/matrix-org/synapse/issues/8317)) +- Update outdated usages of `metaclass` to python 3 syntax. ([\#8326](https://github.com/matrix-org/synapse/issues/8326)) +- Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this. ([\#8330](https://github.com/matrix-org/synapse/issues/8330), [\#8377](https://github.com/matrix-org/synapse/issues/8377)) +- Use the `admin_patterns` helper in additional locations. ([\#8331](https://github.com/matrix-org/synapse/issues/8331)) +- Fix test logging to allow braces in log output. ([\#8335](https://github.com/matrix-org/synapse/issues/8335)) +- Remove `__future__` imports related to Python 2 compatibility. ([\#8337](https://github.com/matrix-org/synapse/issues/8337)) +- Simplify `super()` calls to Python 3 syntax. ([\#8344](https://github.com/matrix-org/synapse/issues/8344)) +- Fix bad merge from `release-v1.20.0` branch to `develop`. ([\#8354](https://github.com/matrix-org/synapse/issues/8354)) +- Factor out a `_send_dummy_event_for_room` method. ([\#8370](https://github.com/matrix-org/synapse/issues/8370)) +- Improve logging of state resolution. ([\#8371](https://github.com/matrix-org/synapse/issues/8371)) +- Add type annotations to `SimpleHttpClient`. ([\#8372](https://github.com/matrix-org/synapse/issues/8372)) +- Refactor ID generators to use `async with` syntax. ([\#8383](https://github.com/matrix-org/synapse/issues/8383)) +- Add `EventStreamPosition` type. ([\#8388](https://github.com/matrix-org/synapse/issues/8388)) +- Create a mechanism for marking tests "logcontext clean". ([\#8399](https://github.com/matrix-org/synapse/issues/8399)) +- A pair of tiny cleanups in the federation request code. ([\#8401](https://github.com/matrix-org/synapse/issues/8401)) +- Add checks on startup that PostgreSQL sequences are consistent with their associated tables. ([\#8402](https://github.com/matrix-org/synapse/issues/8402)) +- Do not include appservice users when calculating the total MAU for a server. ([\#8404](https://github.com/matrix-org/synapse/issues/8404)) +- Typing fixes for `synapse.handlers.federation`. ([\#8422](https://github.com/matrix-org/synapse/issues/8422)) +- Various refactors to simplify stream token handling. ([\#8423](https://github.com/matrix-org/synapse/issues/8423)) +- Make stream token serializing/deserializing async. ([\#8427](https://github.com/matrix-org/synapse/issues/8427)) + + Synapse 1.20.1 (2020-09-24) =========================== diff --git a/UPGRADE.rst b/UPGRADE.rst index 49e86e628f..5a68312217 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,23 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.22.0 +==================== + +ThirdPartyEventRules breaking changes +------------------------------------- + +This release introduces a backwards-incompatible change to modules making use of +``ThirdPartyEventRules`` in Synapse. If you make use of a module defined under the +``third_party_event_rules`` config option, please make sure it is updated to handle +the below change: + +The ``http_client`` argument is no longer passed to modules as they are initialised. Instead, +modules are expected to make use of the ``http_client`` property on the ``ModuleApi`` class. +Modules are now passed a ``module_api`` argument during initialisation, which is an instance of +``ModuleApi``. ``ModuleApi`` instances have a ``http_client`` property which acts the same as +the ``http_client`` argument previously passed to ``ThirdPartyEventRules`` modules. + Upgrading to v1.21.0 ==================== diff --git a/changelog.d/7124.bugfix b/changelog.d/7124.bugfix deleted file mode 100644 index 8fd177780d..0000000000 --- a/changelog.d/7124.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug in the media repository where remote thumbnails with the same size but different crop methods would overwrite each other. Contributed by @deepbluev7. diff --git a/changelog.d/7658.feature b/changelog.d/7658.feature new file mode 100644 index 0000000000..fbf345988d --- /dev/null +++ b/changelog.d/7658.feature @@ -0,0 +1 @@ +Add a configuration option for always using the "userinfo endpoint" for OpenID Connect. This fixes support for some identity providers, e.g. GitLab. Contributed by Benjamin Koch. diff --git a/changelog.d/7796.bugfix b/changelog.d/7796.bugfix deleted file mode 100644 index 65e5eb42a2..0000000000 --- a/changelog.d/7796.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix inconsistent handling of non-existent push rules, and stop tracking the `enabled` state of removed push rules. diff --git a/changelog.d/8004.feature b/changelog.d/8004.feature deleted file mode 100644 index a91b75e0e0..0000000000 --- a/changelog.d/8004.feature +++ /dev/null @@ -1 +0,0 @@ -Require the user to confirm that their password should be reset after clicking the email confirmation link. \ No newline at end of file diff --git a/changelog.d/8208.misc b/changelog.d/8208.misc deleted file mode 100644 index e65da88c46..0000000000 --- a/changelog.d/8208.misc +++ /dev/null @@ -1 +0,0 @@ -Fix tests on distros which disable TLSv1.0. Contributed by @danc86. diff --git a/changelog.d/8216.misc b/changelog.d/8216.misc deleted file mode 100644 index b38911b0e5..0000000000 --- a/changelog.d/8216.misc +++ /dev/null @@ -1 +0,0 @@ -Simplify the distributor code to avoid unnecessary work. diff --git a/changelog.d/8217.feature b/changelog.d/8217.feature deleted file mode 100644 index 899cbf14ef..0000000000 --- a/changelog.d/8217.feature +++ /dev/null @@ -1 +0,0 @@ -Add an admin API `GET /_synapse/admin/v1/event_reports` to read entries of table `event_reports`. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/8227.doc b/changelog.d/8227.doc deleted file mode 100644 index 4a43015a83..0000000000 --- a/changelog.d/8227.doc +++ /dev/null @@ -1 +0,0 @@ -Add `/_synapse/client` to the reverse proxy documentation. diff --git a/changelog.d/8230.bugfix b/changelog.d/8230.bugfix deleted file mode 100644 index 532d0e22fe..0000000000 --- a/changelog.d/8230.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix messages over federation being lost until an event is sent into the same room. diff --git a/changelog.d/8236.bugfix b/changelog.d/8236.bugfix deleted file mode 100644 index 6f04871015..0000000000 --- a/changelog.d/8236.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a longstanding bug where files that could not be thumbnailed would result in an Internal Server Error. diff --git a/changelog.d/8243.misc b/changelog.d/8243.misc deleted file mode 100644 index f7375d32d3..0000000000 --- a/changelog.d/8243.misc +++ /dev/null @@ -1 +0,0 @@ -Remove the 'populate_stats_process_rooms_2' background job and restore functionality to 'populate_stats_process_rooms'. \ No newline at end of file diff --git a/changelog.d/8247.bugfix b/changelog.d/8247.bugfix deleted file mode 100644 index 532d0e22fe..0000000000 --- a/changelog.d/8247.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix messages over federation being lost until an event is sent into the same room. diff --git a/changelog.d/8248.feature b/changelog.d/8248.feature deleted file mode 100644 index f3c4a74bc7..0000000000 --- a/changelog.d/8248.feature +++ /dev/null @@ -1 +0,0 @@ -Consolidate the SSO error template across all configuration. diff --git a/changelog.d/8250.misc b/changelog.d/8250.misc deleted file mode 100644 index b6896a9300..0000000000 --- a/changelog.d/8250.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up type hints for `PaginationConfig`. diff --git a/changelog.d/8256.misc b/changelog.d/8256.misc deleted file mode 100644 index bf0ba76730..0000000000 --- a/changelog.d/8256.misc +++ /dev/null @@ -1 +0,0 @@ -Track the latest event for every destination and room for catch-up after federation outage. diff --git a/changelog.d/8257.misc b/changelog.d/8257.misc deleted file mode 100644 index 47ac583eb4..0000000000 --- a/changelog.d/8257.misc +++ /dev/null @@ -1 +0,0 @@ -Fix non-user visible bug in implementation of `MultiWriterIdGenerator.get_current_token_for_writer`. diff --git a/changelog.d/8258.bugfix b/changelog.d/8258.bugfix deleted file mode 100644 index 532d0e22fe..0000000000 --- a/changelog.d/8258.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix messages over federation being lost until an event is sent into the same room. diff --git a/changelog.d/8259.misc b/changelog.d/8259.misc deleted file mode 100644 index a26779a664..0000000000 --- a/changelog.d/8259.misc +++ /dev/null @@ -1 +0,0 @@ -Switch to the JSON implementation from the standard library. diff --git a/changelog.d/8260.misc b/changelog.d/8260.misc deleted file mode 100644 index 164eea8b59..0000000000 --- a/changelog.d/8260.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `synapse.util.async_helpers`. diff --git a/changelog.d/8261.misc b/changelog.d/8261.misc deleted file mode 100644 index bc91e9375c..0000000000 --- a/changelog.d/8261.misc +++ /dev/null @@ -1 +0,0 @@ -Simplify tests that mock asynchronous functions. diff --git a/changelog.d/8262.bugfix b/changelog.d/8262.bugfix deleted file mode 100644 index 2b84927de3..0000000000 --- a/changelog.d/8262.bugfix +++ /dev/null @@ -1 +0,0 @@ -Upgrade canonicaljson to version 1.4.0 to fix an unicode encoding issue. diff --git a/changelog.d/8265.bugfix b/changelog.d/8265.bugfix deleted file mode 100644 index 981a836d21..0000000000 --- a/changelog.d/8265.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix logstanding bug which could lead to incomplete database upgrades on SQLite. diff --git a/changelog.d/8268.bugfix b/changelog.d/8268.bugfix deleted file mode 100644 index 4b15a60253..0000000000 --- a/changelog.d/8268.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix stack overflow when stderr is redirected to the logging system, and the logging system encounters an error. diff --git a/changelog.d/8272.bugfix b/changelog.d/8272.bugfix deleted file mode 100644 index 532d0e22fe..0000000000 --- a/changelog.d/8272.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix messages over federation being lost until an event is sent into the same room. diff --git a/changelog.d/8275.feature b/changelog.d/8275.feature deleted file mode 100644 index 17549c3df3..0000000000 --- a/changelog.d/8275.feature +++ /dev/null @@ -1 +0,0 @@ -Add a config option to specify a whitelist of domains that a user can be redirected to after validating their email or phone number. \ No newline at end of file diff --git a/changelog.d/8278.bugfix b/changelog.d/8278.bugfix deleted file mode 100644 index 50e40ca2a9..0000000000 --- a/changelog.d/8278.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug which cause the logging system to report errors, if `DEBUG` was enabled and no `context` filter was applied. diff --git a/changelog.d/8279.misc b/changelog.d/8279.misc deleted file mode 100644 index 99f669001f..0000000000 --- a/changelog.d/8279.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `StreamToken` and `RoomStreamToken` classes. diff --git a/changelog.d/8281.misc b/changelog.d/8281.misc deleted file mode 100644 index 74357120a7..0000000000 --- a/changelog.d/8281.misc +++ /dev/null @@ -1 +0,0 @@ -Change `StreamToken.room_key` to be a `RoomStreamToken` instance. diff --git a/changelog.d/8282.misc b/changelog.d/8282.misc deleted file mode 100644 index b6896a9300..0000000000 --- a/changelog.d/8282.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up type hints for `PaginationConfig`. diff --git a/changelog.d/8287.bugfix b/changelog.d/8287.bugfix deleted file mode 100644 index 839781aa07..0000000000 --- a/changelog.d/8287.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix edge case where push could get delayed for a user until a later event was pushed. diff --git a/changelog.d/8288.misc b/changelog.d/8288.misc deleted file mode 100644 index c08a53a5ee..0000000000 --- a/changelog.d/8288.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor notifier code to correctly use the max event stream position. diff --git a/changelog.d/8292.feature b/changelog.d/8292.feature new file mode 100644 index 0000000000..6d0335e2c8 --- /dev/null +++ b/changelog.d/8292.feature @@ -0,0 +1 @@ +Allow `ThirdPartyEventRules` modules to query and manipulate whether a room is in the public rooms directory. \ No newline at end of file diff --git a/changelog.d/8294.feature b/changelog.d/8294.feature deleted file mode 100644 index b363e929ea..0000000000 --- a/changelog.d/8294.feature +++ /dev/null @@ -1 +0,0 @@ -Add experimental support for sharding event persister. diff --git a/changelog.d/8296.misc b/changelog.d/8296.misc deleted file mode 100644 index f593a5b347..0000000000 --- a/changelog.d/8296.misc +++ /dev/null @@ -1 +0,0 @@ -Use slotted classes where possible. diff --git a/changelog.d/8305.feature b/changelog.d/8305.feature deleted file mode 100644 index 862dfdf959..0000000000 --- a/changelog.d/8305.feature +++ /dev/null @@ -1 +0,0 @@ -Add the room topic and avatar to the room details admin API. diff --git a/changelog.d/8306.feature b/changelog.d/8306.feature deleted file mode 100644 index 5c23da4030..0000000000 --- a/changelog.d/8306.feature +++ /dev/null @@ -1 +0,0 @@ -Add an admin API for querying rooms where a user is a member. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/8320.feature b/changelog.d/8320.feature deleted file mode 100644 index 475a5fe62d..0000000000 --- a/changelog.d/8320.feature +++ /dev/null @@ -1 +0,0 @@ -Add `uk.half-shot.msc2778.login.application_service` login type to allow appservices to login. diff --git a/changelog.d/8322.bugfix b/changelog.d/8322.bugfix deleted file mode 100644 index 532d0e22fe..0000000000 --- a/changelog.d/8322.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix messages over federation being lost until an event is sent into the same room. diff --git a/changelog.d/8324.bugfix b/changelog.d/8324.bugfix deleted file mode 100644 index 32788a9284..0000000000 --- a/changelog.d/8324.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix fetching events from remote servers that are malformed. diff --git a/changelog.d/8326.misc b/changelog.d/8326.misc deleted file mode 100644 index 985d2c027a..0000000000 --- a/changelog.d/8326.misc +++ /dev/null @@ -1 +0,0 @@ -Update outdated usages of `metaclass` to python 3 syntax. \ No newline at end of file diff --git a/changelog.d/8329.bugfix b/changelog.d/8329.bugfix deleted file mode 100644 index 2f71f1f4b9..0000000000 --- a/changelog.d/8329.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix UnboundLocalError from occuring when appservices send malformed register request. \ No newline at end of file diff --git a/changelog.d/8330.misc b/changelog.d/8330.misc deleted file mode 100644 index fbfdd52473..0000000000 --- a/changelog.d/8330.misc +++ /dev/null @@ -1 +0,0 @@ -Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this. diff --git a/changelog.d/8331.misc b/changelog.d/8331.misc deleted file mode 100644 index 0e1bae20ef..0000000000 --- a/changelog.d/8331.misc +++ /dev/null @@ -1 +0,0 @@ -Use the `admin_patterns` helper in additional locations. diff --git a/changelog.d/8335.misc b/changelog.d/8335.misc deleted file mode 100644 index 7e0a4c7d83..0000000000 --- a/changelog.d/8335.misc +++ /dev/null @@ -1 +0,0 @@ -Fix test logging to allow braces in log output. \ No newline at end of file diff --git a/changelog.d/8337.misc b/changelog.d/8337.misc deleted file mode 100644 index 4daf272204..0000000000 --- a/changelog.d/8337.misc +++ /dev/null @@ -1 +0,0 @@ -Remove `__future__` imports related to Python 2 compatibility. \ No newline at end of file diff --git a/changelog.d/8344.misc b/changelog.d/8344.misc deleted file mode 100644 index 0b342d5137..0000000000 --- a/changelog.d/8344.misc +++ /dev/null @@ -1 +0,0 @@ -Simplify `super()` calls to Python 3 syntax. diff --git a/changelog.d/8353.bugfix b/changelog.d/8353.bugfix deleted file mode 100644 index 45fc0adb8d..0000000000 --- a/changelog.d/8353.bugfix +++ /dev/null @@ -1 +0,0 @@ -Don't send push notifications to expired user accounts. diff --git a/changelog.d/8354.misc b/changelog.d/8354.misc deleted file mode 100644 index 1d33cde2da..0000000000 --- a/changelog.d/8354.misc +++ /dev/null @@ -1 +0,0 @@ -Fix bad merge from `release-v1.20.0` branch to `develop`. diff --git a/changelog.d/8362.bugfix b/changelog.d/8362.bugfix deleted file mode 100644 index 4e50067c87..0000000000 --- a/changelog.d/8362.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fixed a regression in v1.19.0 with reactivating users through the admin API. diff --git a/changelog.d/8364.bugfix b/changelog.d/8364.bugfix deleted file mode 100644 index 7b82cbc388..0000000000 --- a/changelog.d/8364.bugfix +++ /dev/null @@ -1,2 +0,0 @@ -Fix a bug where during device registration the length of the device name wasn't -limited. diff --git a/changelog.d/8369.feature b/changelog.d/8369.feature new file mode 100644 index 0000000000..542993110b --- /dev/null +++ b/changelog.d/8369.feature @@ -0,0 +1 @@ +Allow running background tasks in a separate worker process. diff --git a/changelog.d/8370.misc b/changelog.d/8370.misc deleted file mode 100644 index 1aaac1e0bf..0000000000 --- a/changelog.d/8370.misc +++ /dev/null @@ -1 +0,0 @@ -Factor out a `_send_dummy_event_for_room` method. diff --git a/changelog.d/8371.misc b/changelog.d/8371.misc deleted file mode 100644 index 6a54a9496a..0000000000 --- a/changelog.d/8371.misc +++ /dev/null @@ -1 +0,0 @@ -Improve logging of state resolution. diff --git a/changelog.d/8372.misc b/changelog.d/8372.misc deleted file mode 100644 index a56e36de4b..0000000000 --- a/changelog.d/8372.misc +++ /dev/null @@ -1 +0,0 @@ -Add type annotations to `SimpleHttpClient`. diff --git a/changelog.d/8373.bugfix b/changelog.d/8373.bugfix deleted file mode 100644 index e9d66a2088..0000000000 --- a/changelog.d/8373.bugfix +++ /dev/null @@ -1 +0,0 @@ -Include `guest_access` in the fields that are checked for null bytes when updating `room_stats_state`. Broke in v1.7.2. \ No newline at end of file diff --git a/changelog.d/8374.bugfix b/changelog.d/8374.bugfix deleted file mode 100644 index 155bc3404f..0000000000 --- a/changelog.d/8374.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix theoretical race condition where events are not sent down `/sync` if the synchrotron worker is restarted without restarting other workers. diff --git a/changelog.d/8375.doc b/changelog.d/8375.doc deleted file mode 100644 index d291fb92fa..0000000000 --- a/changelog.d/8375.doc +++ /dev/null @@ -1 +0,0 @@ -Add note to the reverse proxy settings documentation about disabling Apache's mod_security2. Contributed by Julian Fietkau (@jfietkau). diff --git a/changelog.d/8377.misc b/changelog.d/8377.misc deleted file mode 100644 index fbfdd52473..0000000000 --- a/changelog.d/8377.misc +++ /dev/null @@ -1 +0,0 @@ -Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this. diff --git a/changelog.d/8383.misc b/changelog.d/8383.misc deleted file mode 100644 index cb8318bf57..0000000000 --- a/changelog.d/8383.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor ID generators to use `async with` syntax. diff --git a/changelog.d/8385.bugfix b/changelog.d/8385.bugfix deleted file mode 100644 index c42502a8e0..0000000000 --- a/changelog.d/8385.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug which could cause errors in rooms with malformed membership events, on servers using sqlite. diff --git a/changelog.d/8386.bugfix b/changelog.d/8386.bugfix deleted file mode 100644 index 24983a1e95..0000000000 --- a/changelog.d/8386.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. diff --git a/changelog.d/8387.feature b/changelog.d/8387.feature deleted file mode 100644 index b363e929ea..0000000000 --- a/changelog.d/8387.feature +++ /dev/null @@ -1 +0,0 @@ -Add experimental support for sharding event persister. diff --git a/changelog.d/8388.misc b/changelog.d/8388.misc deleted file mode 100644 index aaaef88b66..0000000000 --- a/changelog.d/8388.misc +++ /dev/null @@ -1 +0,0 @@ -Add `EventStreamPosition` type. diff --git a/changelog.d/8432.misc b/changelog.d/8432.misc new file mode 100644 index 0000000000..01fdad4caf --- /dev/null +++ b/changelog.d/8432.misc @@ -0,0 +1 @@ +Check for unreachable code with mypy. diff --git a/changelog.d/8433.misc b/changelog.d/8433.misc new file mode 100644 index 0000000000..05f8b5bbf4 --- /dev/null +++ b/changelog.d/8433.misc @@ -0,0 +1 @@ +Add unit test for event persister sharding. diff --git a/changelog.d/8443.misc b/changelog.d/8443.misc new file mode 100644 index 0000000000..633598e6b3 --- /dev/null +++ b/changelog.d/8443.misc @@ -0,0 +1 @@ +Configure `public_baseurl` when using demo scripts. diff --git a/changelog.d/8448.misc b/changelog.d/8448.misc new file mode 100644 index 0000000000..5ddda1803b --- /dev/null +++ b/changelog.d/8448.misc @@ -0,0 +1 @@ +Add SQL logging on queries that happen during startup. diff --git a/changelog.d/8450.misc b/changelog.d/8450.misc new file mode 100644 index 0000000000..4e04c523ab --- /dev/null +++ b/changelog.d/8450.misc @@ -0,0 +1 @@ +Speed up unit tests when using PostgreSQL. diff --git a/changelog.d/8452.misc b/changelog.d/8452.misc new file mode 100644 index 0000000000..8288d91c78 --- /dev/null +++ b/changelog.d/8452.misc @@ -0,0 +1 @@ +Remove redundant databae loads of stream_ordering for events we already have. diff --git a/changelog.d/8454.bugfix b/changelog.d/8454.bugfix new file mode 100644 index 0000000000..c06d490b6f --- /dev/null +++ b/changelog.d/8454.bugfix @@ -0,0 +1 @@ +Fix a longstanding bug where invalid ignored users in account data could break clients. diff --git a/changelog.d/8457.bugfix b/changelog.d/8457.bugfix new file mode 100644 index 0000000000..545b06d180 --- /dev/null +++ b/changelog.d/8457.bugfix @@ -0,0 +1 @@ +Fix a bug where backfilling a room with an event that was missing the `redacts` field would break. diff --git a/changelog.d/8462.doc b/changelog.d/8462.doc new file mode 100644 index 0000000000..cf84db6db7 --- /dev/null +++ b/changelog.d/8462.doc @@ -0,0 +1 @@ +Update the directions for using the manhole with coroutines. diff --git a/demo/start.sh b/demo/start.sh index 83396e5c33..f6b5ea137f 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -30,6 +30,8 @@ for port in 8080 8081 8082; do if ! grep -F "Customisation made by demo/start.sh" -q $DIR/etc/$port.config; then printf '\n\n# Customisation made by demo/start.sh\n' >> $DIR/etc/$port.config + echo "public_baseurl: http://localhost:$port/" >> $DIR/etc/$port.config + echo 'enable_registration: true' >> $DIR/etc/$port.config # Warning, this heredoc depends on the interaction of tabs and spaces. Please don't diff --git a/docs/manhole.md b/docs/manhole.md index 7375f5ad46..75b6ae40e0 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -35,9 +35,12 @@ This gives a Python REPL in which `hs` gives access to the `synapse.server.HomeServer` object - which in turn gives access to many other parts of the process. +Note that any call which returns a coroutine will need to be wrapped in `ensureDeferred`. + As a simple example, retrieving an event from the database: -``` ->>> hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org') +```pycon +>>> from twisted.internet import defer +>>> defer.ensureDeferred(hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org')) > ``` diff --git a/docs/openid.md b/docs/openid.md index 70b37f858b..4873681999 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -238,13 +238,36 @@ Synapse config: ```yaml oidc_config: - enabled: true - issuer: "https://id.twitch.tv/oauth2/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - client_auth_method: "client_secret_post" - user_mapping_provider: - config: - localpart_template: '{{ user.preferred_username }}' - display_name_template: '{{ user.name }}' + enabled: true + issuer: "https://id.twitch.tv/oauth2/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" +``` + +### GitLab + +1. Create a [new application](https://gitlab.com/profile/applications). +2. Add the `read_user` and `openid` scopes. +3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback` + +Synapse config: + +```yaml +oidc_config: + enabled: true + issuer: "https://gitlab.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + scopes: ["openid", "read_user"] + user_profile_method: "userinfo_endpoint" + user_mapping_provider: + config: + localpart_template: '{{ user.nickname }}' + display_name_template: '{{ user.name }}' ``` diff --git a/docs/postgres.md b/docs/postgres.md index e71a1975d8..c30cc1fd8c 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -106,6 +106,17 @@ Note that the above may fail with an error about duplicate rows if corruption has already occurred, and such duplicate rows will need to be manually removed. +## Fixing inconsistent sequences error + +Synapse uses Postgres sequences to generate IDs for various tables. A sequence +and associated table can get out of sync if, for example, Synapse has been +downgraded and then upgraded again. + +To fix the issue shut down Synapse (including any and all workers) and run the +SQL command included in the error message. Once done Synapse should start +successfully. + + ## Tuning Postgres The default settings should be fine for most deployments. For larger diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index fb04ff283d..7126ade2de 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -33,10 +33,23 @@ ## Server ## -# The domain name of the server, with optional explicit port. -# This is used by remote servers to connect to this server, -# e.g. matrix.org, localhost:8080, etc. -# This is also the last part of your UserID. +# The public-facing domain of the server +# +# The server_name name will appear at the end of usernames and room addresses +# created on this server. For example if the server_name was example.com, +# usernames on this server would be in the format @user:example.com +# +# In most cases you should avoid using a matrix specific subdomain such as +# matrix.example.com or synapse.example.com as the server_name for the same +# reasons you wouldn't use user@email.example.com as your email address. +# See https://github.com/matrix-org/synapse/blob/master/docs/delegate.md +# for information on how to host Synapse on a subdomain while preserving +# a clean server_name. +# +# The server_name cannot be changed later so it is important to +# configure this correctly before you start Synapse. It should be all +# lowercase and may contain an explicit port. +# Examples: matrix.org, localhost:8080 # server_name: "SERVERNAME" @@ -616,6 +629,7 @@ acme: #tls_fingerprints: [{"sha256": ""}] +## Federation ## # Restrict federation to the following whitelist of domains. # N.B. we recommend also firewalling your federation listener to limit @@ -649,6 +663,17 @@ federation_ip_range_blacklist: - 'fe80::/64' - 'fc00::/7' +# Report prometheus metrics on the age of PDUs being sent to and received from +# the following domains. This can be used to give an idea of "delay" on inbound +# and outbound federation, though be aware that any delay can be due to problems +# at either end or with the intermediate network. +# +# By default, no domains are monitored in this way. +# +#federation_metrics_domains: +# - matrix.org +# - example.com + ## Caching ## @@ -1689,6 +1714,19 @@ oidc_config: # #skip_verification: true + # Whether to fetch the user profile from the userinfo endpoint. Valid + # values are: "auto" or "userinfo_endpoint". + # + # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included + # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. + # + #user_profile_method: "userinfo_endpoint" + + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead + # of failing. This could be used if switching from password logins to OIDC. Defaults to false. + # + #allow_existing_users: true + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. # @@ -1730,6 +1768,14 @@ oidc_config: # #display_name_template: "{{ user.given_name }} {{ user.last_name }}" + # Jinja2 templates for extra attributes to send back to the client during + # login. + # + # Note that these are non-standard and clients will ignore them without modifications. + # + #extra_attributes: + #birthdate: "{{ user.birthdate }}" + # Enable CAS for registration and login. @@ -2458,6 +2504,11 @@ opentracing: # events: worker1 # typing: worker1 +# The worker that is used to run background tasks (e.g. cleaning up expired +# data). If not provided this defaults to the main process. +# +#run_background_tasks_on: worker1 + # Configuration for Redis when using workers. This *must* be enabled when # using workers (unless using old style direct TCP configuration). diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index abea432343..32b06aa2c5 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -57,7 +57,7 @@ A custom mapping provider must specify the following methods: - This method must return a string, which is the unique identifier for the user. Commonly the ``sub`` claim of the response. * `map_user_attributes(self, userinfo, token)` - - This method should be async. + - This method must be async. - Arguments: - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user information from. @@ -66,6 +66,18 @@ A custom mapping provider must specify the following methods: - Returns a dictionary with two keys: - localpart: A required string, used to generate the Matrix ID. - displayname: An optional string, the display name for the user. +* `get_extra_attributes(self, userinfo, token)` + - This method must be async. + - Arguments: + - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user + information from. + - `token` - A dictionary which includes information necessary to make + further requests to the OpenID provider. + - Returns a dictionary that is suitable to be serialized to JSON. This + will be returned as part of the response during a successful login. + + Note that care should be taken to not overwrite any of the parameters + usually returned as part of the [login response](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login). ### Default OpenID Mapping Provider diff --git a/docs/workers.md b/docs/workers.md index df0ac84d94..84a9759e34 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -243,6 +243,22 @@ for the room are in flight: ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$ +Additionally, the following endpoints should be included if Synapse is configured +to use SSO (you only need to include the ones for whichever SSO provider you're +using): + + # OpenID Connect requests. + ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ + ^/_synapse/oidc/callback$ + + # SAML requests. + ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ + ^/_matrix/saml2/authn_response$ + + # CAS requests. + ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$ + ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$ + Note that a HTTP listener with `client` and `federation` resources must be configured in the `worker_listeners` option in the worker config. @@ -303,6 +319,23 @@ stream_writers: events: event_persister1 ``` +#### Background tasks + +There is also *experimental* support for moving background tasks to a separate +worker. Background tasks are run periodically or started via replication. Exactly +which tasks are configured to run depends on your Synapse configuration (e.g. if +stats is enabled). + +To enable this, the worker must have a `worker_name` and can be configured to run +background tasks. For example, to move background tasks to a dedicated worker, +the shared configuration would include: + +```yaml +run_background_tasks_on: background_worker +``` + +You might also wish to investigate the `update_user_directory` and +`media_instance_running_background_jobs` settings. ### `synapse.app.pusher` diff --git a/mypy.ini b/mypy.ini index 3fedfa9bbb..a7ffb81ef1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -6,6 +6,7 @@ check_untyped_defs = True show_error_codes = True show_traceback = True mypy_path = stubs +warn_unreachable = True files = synapse/api, synapse/appservice, @@ -143,3 +144,6 @@ ignore_missing_imports = True [mypy-nacl.*] ignore_missing_imports = True + +[mypy-hiredis] +ignore_missing_imports = True diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh new file mode 100755 index 0000000000..3cde53f5c0 --- /dev/null +++ b/scripts-dev/complement.sh @@ -0,0 +1,22 @@ +#! /bin/bash -eu +# This script is designed for developers who want to test their code +# against Complement. +# +# It makes a Synapse image which represents the current checkout, +# then downloads Complement and runs it with that image. + +cd "$(dirname $0)/.." + +# Build the base Synapse image from the local checkout +docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile . + +# Download Complement +wget -N https://github.com/matrix-org/complement/archive/master.tar.gz +tar -xzf master.tar.gz +cd complement-master + +# Build the Synapse image from Complement, based on the above image we just built +docker build -t complement-synapse -f dockerfiles/Synapse.Dockerfile ./dockerfiles + +# Run the tests on the resulting image! +COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -count=1 ./tests diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 684a518b8e..7e12f5440c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -145,6 +145,7 @@ IGNORED_TABLES = { # the sessions are transient anyway, so ignore them. "ui_auth_sessions", "ui_auth_sessions_credentials", + "ui_auth_sessions_ips", } @@ -488,7 +489,7 @@ class Porter(object): hs = MockHomeserver(self.hs_config) - with make_conn(db_config, engine) as db_conn: + with make_conn(db_config, engine, "portdb") as db_conn: engine.check_database( db_conn, allow_outdated_version=allow_outdated_version ) diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index c66413f003..522244bb57 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -16,7 +16,7 @@ """Contains *incomplete* type hints for txredisapi. """ -from typing import List, Optional, Union +from typing import List, Optional, Union, Type class RedisProtocol: def publish(self, channel: str, message: bytes): ... @@ -42,3 +42,21 @@ def lazyConnection( class SubscriberFactory: def buildProtocol(self, addr): ... + +class ConnectionHandler: ... + +class RedisFactory: + continueTrying: bool + handler: RedisProtocol + def __init__( + self, + uuid: str, + dbid: Optional[int], + poolsize: int, + isLazy: bool = False, + handler: Type = ConnectionHandler, + charset: str = "utf-8", + password: Optional[str] = None, + replyTimeout: Optional[int] = None, + convertNumbers: Optional[int] = True, + ): ... diff --git a/synapse/__init__.py b/synapse/__init__.py index e40b582bd5..500558bbdf 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.20.1" +__version__ = "1.21.0rc2" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 46013cde15..592abd844b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -155,3 +155,8 @@ class EventContentFields: class RoomEncryptionAlgorithms: MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" DEFAULT = MEGOLM_V1_AES_SHA2 + + +class AccountDataTypes: + DIRECT = "m.direct" + IGNORED_USER_LIST = "m.ignored_user_list" diff --git a/synapse/app/_base.py b/synapse/app/_base.py index fb476ddaf5..f6f7b2bf42 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -28,6 +28,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory import synapse from synapse.app import check_bind_error +from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config.server import ListenerConfig from synapse.crypto import context_factory from synapse.logging.context import PreserveLoggingContext @@ -271,9 +272,19 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): hs.get_datastore().db_pool.start_profiling() hs.get_pusherpool().start() + # Log when we start the shut down process. + hs.get_reactor().addSystemEventTrigger( + "before", "shutdown", logger.info, "Shutting down..." + ) + setup_sentry(hs) setup_sdnotify(hs) + # If background tasks are running on the main process, start collecting the + # phone home stats. + if hs.config.run_background_tasks: + start_phone_stats_home(hs) + # We now freeze all allocated objects in the hopes that (almost) # everything currently allocated are things that will be used for the # rest of time. Doing so means less work each GC (hopefully). diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 7d309b1bb0..f0d65d08d7 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -208,6 +208,7 @@ def start(config_options): # Explicitly disable background processes config.update_user_directory = False + config.run_background_tasks = False config.start_pushers = False config.send_federation = False diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index c38413c893..fc5188ce95 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -128,11 +128,13 @@ from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer, cache_in_self from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.media_repository import MediaRepositoryStore +from synapse.storage.databases.main.metrics import ServerMetricsStore from synapse.storage.databases.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.databases.main.presence import UserPresenceState from synapse.storage.databases.main.search import SearchWorkerStore +from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt @@ -454,6 +456,7 @@ class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. UserDirectoryStore, + StatsStore, UIAuthWorkerStore, SlavedDeviceInboxStore, SlavedDeviceStore, @@ -476,6 +479,7 @@ class GenericWorkerSlavedStore( SlavedFilteringStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, + ServerMetricsStore, SearchWorkerStore, BaseSlavedStore, ): diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index dff739e106..4ed4a2c253 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -17,14 +17,10 @@ import gc import logging -import math import os -import resource import sys from typing import Iterable -from prometheus_client import Gauge - from twisted.application import service from twisted.internet import defer, reactor from twisted.python.failure import Failure @@ -60,7 +56,6 @@ from synapse.http.server import ( from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.python_dependencies import check_requirements from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource @@ -334,20 +329,6 @@ class SynapseHomeServer(HomeServer): logger.warning("Unrecognized listener type: %s", listener.type) -# Gauges to expose monthly active user control metrics -current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") -current_mau_by_service_gauge = Gauge( - "synapse_admin_mau_current_mau_by_service", - "Current MAU by service", - ["app_service"], -) -max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") -registered_reserved_users_mau_gauge = Gauge( - "synapse_admin_mau:registered_reserved_users", - "Registered users with reserved threepids", -) - - def setup(config_options): """ Args: @@ -389,8 +370,6 @@ def setup(config_options): except UpgradeDatabaseException as e: quit_with_error("Failed to upgrade database: %s" % (e,)) - hs.setup_master() - async def do_acme() -> bool: """ Reprovision an ACME certificate, if it's required. @@ -486,92 +465,6 @@ class SynapseService(service.Service): return self._port.stopListening() -# Contains the list of processes we will be monitoring -# currently either 0 or 1 -_stats_process = [] - - -async def phone_stats_home(hs, stats, stats_process=_stats_process): - logger.info("Gathering stats for reporting") - now = int(hs.get_clock().time()) - uptime = int(now - hs.start_time) - if uptime < 0: - uptime = 0 - - # - # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test. - # - old = stats_process[0] - new = (now, resource.getrusage(resource.RUSAGE_SELF)) - stats_process[0] = new - - # Get RSS in bytes - stats["memory_rss"] = new[1].ru_maxrss - - # Get CPU time in % of a single core, not % of all cores - used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( - old[1].ru_utime + old[1].ru_stime - ) - if used_cpu_time == 0 or new[0] == old[0]: - stats["cpu_average"] = 0 - else: - stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) - - # - # General statistics - # - - stats["homeserver"] = hs.config.server_name - stats["server_context"] = hs.config.server_context - stats["timestamp"] = now - stats["uptime_seconds"] = uptime - version = sys.version_info - stats["python_version"] = "{}.{}.{}".format( - version.major, version.minor, version.micro - ) - stats["total_users"] = await hs.get_datastore().count_all_users() - - total_nonbridged_users = await hs.get_datastore().count_nonbridged_users() - stats["total_nonbridged_users"] = total_nonbridged_users - - daily_user_type_results = await hs.get_datastore().count_daily_user_type() - for name, count in daily_user_type_results.items(): - stats["daily_user_type_" + name] = count - - room_count = await hs.get_datastore().get_room_count() - stats["total_room_count"] = room_count - - stats["daily_active_users"] = await hs.get_datastore().count_daily_users() - stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() - stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() - stats["daily_messages"] = await hs.get_datastore().count_daily_messages() - - r30_results = await hs.get_datastore().count_r30_users() - for name, count in r30_results.items(): - stats["r30_users_" + name] = count - - daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() - stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = hs.config.caches.global_factor - stats["event_cache_size"] = hs.config.caches.event_cache_size - - # - # Database version - # - - # This only reports info about the *main* database. - stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__ - stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version - - logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) - try: - await hs.get_proxied_http_client().put_json( - hs.config.report_stats_endpoint, stats - ) - except Exception as e: - logger.warning("Error reporting stats: %s", e) - - def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: @@ -597,81 +490,6 @@ def run(hs): ThreadPool._worker = profile(ThreadPool._worker) reactor.run = profile(reactor.run) - clock = hs.get_clock() - - stats = {} - - def performance_stats_init(): - _stats_process.clear() - _stats_process.append( - (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) - ) - - def start_phone_stats_home(): - return run_as_background_process( - "phone_stats_home", phone_stats_home, hs, stats - ) - - def generate_user_daily_visit_stats(): - return run_as_background_process( - "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits - ) - - # Rather than update on per session basis, batch up the requests. - # If you increase the loop period, the accuracy of user_daily_visits - # table will decrease - clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) - - # monthly active user limiting functionality - def reap_monthly_active_users(): - return run_as_background_process( - "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users - ) - - clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) - reap_monthly_active_users() - - async def generate_monthly_active_users(): - current_mau_count = 0 - current_mau_count_by_service = {} - reserved_users = () - store = hs.get_datastore() - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: - current_mau_count = await store.get_monthly_active_count() - current_mau_count_by_service = ( - await store.get_monthly_active_count_by_service() - ) - reserved_users = await store.get_registered_reserved_users() - current_mau_gauge.set(float(current_mau_count)) - - for app_service, count in current_mau_count_by_service.items(): - current_mau_by_service_gauge.labels(app_service).set(float(count)) - - registered_reserved_users_mau_gauge.set(float(len(reserved_users))) - max_mau_gauge.set(float(hs.config.max_mau_value)) - - def start_generate_monthly_active_users(): - return run_as_background_process( - "generate_monthly_active_users", generate_monthly_active_users - ) - - start_generate_monthly_active_users() - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: - clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000) - # End of monthly active user settings - - if hs.config.report_stats: - logger.info("Scheduling stats reporting for 3 hour intervals") - clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) - - # We need to defer this init for the cases that we daemonize - # otherwise the process ID we get is that of the non-daemon process - clock.call_later(0, performance_stats_init) - - # We wait 5 minutes to send the first set of stats as the server can - # be quite busy the first few minutes - clock.call_later(5 * 60, start_phone_stats_home) - _base.start_reactor( "synapse-homeserver", soft_file_limit=hs.config.soft_file_limit, diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py new file mode 100644 index 0000000000..2c8e14a8c0 --- /dev/null +++ b/synapse/app/phone_stats_home.py @@ -0,0 +1,202 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import math +import resource +import sys + +from prometheus_client import Gauge + +from synapse.metrics.background_process_metrics import run_as_background_process + +logger = logging.getLogger("synapse.app.homeserver") + +# Contains the list of processes we will be monitoring +# currently either 0 or 1 +_stats_process = [] + +# Gauges to expose monthly active user control metrics +current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") +current_mau_by_service_gauge = Gauge( + "synapse_admin_mau_current_mau_by_service", + "Current MAU by service", + ["app_service"], +) +max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") +registered_reserved_users_mau_gauge = Gauge( + "synapse_admin_mau:registered_reserved_users", + "Registered users with reserved threepids", +) + + +async def phone_stats_home(hs, stats, stats_process=_stats_process): + logger.info("Gathering stats for reporting") + now = int(hs.get_clock().time()) + uptime = int(now - hs.start_time) + if uptime < 0: + uptime = 0 + + # + # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test. + # + old = stats_process[0] + new = (now, resource.getrusage(resource.RUSAGE_SELF)) + stats_process[0] = new + + # Get RSS in bytes + stats["memory_rss"] = new[1].ru_maxrss + + # Get CPU time in % of a single core, not % of all cores + used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( + old[1].ru_utime + old[1].ru_stime + ) + if used_cpu_time == 0 or new[0] == old[0]: + stats["cpu_average"] = 0 + else: + stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) + + # + # General statistics + # + + stats["homeserver"] = hs.config.server_name + stats["server_context"] = hs.config.server_context + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + version = sys.version_info + stats["python_version"] = "{}.{}.{}".format( + version.major, version.minor, version.micro + ) + stats["total_users"] = await hs.get_datastore().count_all_users() + + total_nonbridged_users = await hs.get_datastore().count_nonbridged_users() + stats["total_nonbridged_users"] = total_nonbridged_users + + daily_user_type_results = await hs.get_datastore().count_daily_user_type() + for name, count in daily_user_type_results.items(): + stats["daily_user_type_" + name] = count + + room_count = await hs.get_datastore().get_room_count() + stats["total_room_count"] = room_count + + stats["daily_active_users"] = await hs.get_datastore().count_daily_users() + stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() + stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() + stats["daily_messages"] = await hs.get_datastore().count_daily_messages() + + r30_results = await hs.get_datastore().count_r30_users() + for name, count in r30_results.items(): + stats["r30_users_" + name] = count + + daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages + stats["cache_factor"] = hs.config.caches.global_factor + stats["event_cache_size"] = hs.config.caches.event_cache_size + + # + # Database version + # + + # This only reports info about the *main* database. + stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version + + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) + try: + await hs.get_proxied_http_client().put_json( + hs.config.report_stats_endpoint, stats + ) + except Exception as e: + logger.warning("Error reporting stats: %s", e) + + +def start_phone_stats_home(hs): + """ + Start the background tasks which report phone home stats. + """ + clock = hs.get_clock() + + stats = {} + + def performance_stats_init(): + _stats_process.clear() + _stats_process.append( + (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) + ) + + def start_phone_stats_home(): + return run_as_background_process( + "phone_stats_home", phone_stats_home, hs, stats + ) + + def generate_user_daily_visit_stats(): + return run_as_background_process( + "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits + ) + + # Rather than update on per session basis, batch up the requests. + # If you increase the loop period, the accuracy of user_daily_visits + # table will decrease + clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) + + # monthly active user limiting functionality + def reap_monthly_active_users(): + return run_as_background_process( + "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users + ) + + clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) + reap_monthly_active_users() + + async def generate_monthly_active_users(): + current_mau_count = 0 + current_mau_count_by_service = {} + reserved_users = () + store = hs.get_datastore() + if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + current_mau_count = await store.get_monthly_active_count() + current_mau_count_by_service = ( + await store.get_monthly_active_count_by_service() + ) + reserved_users = await store.get_registered_reserved_users() + current_mau_gauge.set(float(current_mau_count)) + + for app_service, count in current_mau_count_by_service.items(): + current_mau_by_service_gauge.labels(app_service).set(float(count)) + + registered_reserved_users_mau_gauge.set(float(len(reserved_users))) + max_mau_gauge.set(float(hs.config.max_mau_value)) + + def start_generate_monthly_active_users(): + return run_as_background_process( + "generate_monthly_active_users", generate_monthly_active_users + ) + + if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + start_generate_monthly_active_users() + clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000) + # End of monthly active user settings + + if hs.config.report_stats: + logger.info("Scheduling stats reporting for 3 hour intervals") + clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) + + # We need to defer this init for the cases that we daemonize + # otherwise the process ID we get is that of the non-daemon process + clock.call_later(0, performance_stats_init) + + # We wait 5 minutes to send the first set of stats as the server can + # be quite busy the first few minutes + clock.call_later(5 * 60, start_phone_stats_home) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 05a66841c3..85f65da4d9 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -242,12 +242,11 @@ class Config: env = jinja2.Environment(loader=loader, autoescape=autoescape) # Update the environment with our custom filters - env.filters.update( - { - "format_ts": _format_ts_filter, - "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), - } - ) + env.filters.update({"format_ts": _format_ts_filter}) + if self.public_baseurl: + env.filters.update( + {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)} + ) for filename in filenames: # Load the template diff --git a/synapse/config/_util.py b/synapse/config/_util.py index cd31b1c3c9..c74969a977 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, Iterable import jsonschema @@ -20,7 +20,9 @@ from synapse.config._base import ConfigError from synapse.types import JsonDict -def validate_config(json_schema: JsonDict, config: Any, config_path: List[str]) -> None: +def validate_config( + json_schema: JsonDict, config: Any, config_path: Iterable[str] +) -> None: """Validates a config setting against a JsonSchema definition This can be used to validate a section of the config file against a schema diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index 82f04d7966..cb00958165 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -28,6 +28,9 @@ class CaptchaConfig(Config): "recaptcha_siteverify_api", "https://www.recaptcha.net/recaptcha/api/siteverify", ) + self.recaptcha_template = self.read_templates( + ["recaptcha.html"], autoescape=True + )[0] def generate_config_section(self, **kwargs): return """\ diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index fbddebeeab..6efa59b110 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -89,6 +89,8 @@ class ConsentConfig(Config): def read_config(self, config, **kwargs): consent_config = config.get("user_consent") + self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0] + if consent_config is None: return self.user_consent_version = str(consent_config["version"]) diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 2c77d8f85b..ffd8fca54e 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -17,7 +17,8 @@ from typing import Optional from netaddr import IPSet -from ._base import Config, ConfigError +from synapse.config._base import Config, ConfigError +from synapse.config._util import validate_config class FederationConfig(Config): @@ -52,8 +53,18 @@ class FederationConfig(Config): "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e ) + federation_metrics_domains = config.get("federation_metrics_domains") or [] + validate_config( + _METRICS_FOR_DOMAINS_SCHEMA, + federation_metrics_domains, + ("federation_metrics_domains",), + ) + self.federation_metrics_domains = set(federation_metrics_domains) + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ + ## Federation ## + # Restrict federation to the following whitelist of domains. # N.B. we recommend also firewalling your federation listener to limit # inbound federation traffic as early as possible, rather than relying @@ -85,4 +96,18 @@ class FederationConfig(Config): - '::1/128' - 'fe80::/64' - 'fc00::/7' + + # Report prometheus metrics on the age of PDUs being sent to and received from + # the following domains. This can be used to give an idea of "delay" on inbound + # and outbound federation, though be aware that any delay can be due to problems + # at either end or with the intermediate network. + # + # By default, no domains are monitored in this way. + # + #federation_metrics_domains: + # - matrix.org + # - example.com """ + + +_METRICS_FOR_DOMAINS_SCHEMA = {"type": "array", "items": {"type": "string"}} diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 556e291495..be65554524 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -92,5 +92,4 @@ class HomeServerConfig(RootConfig): TracerConfig, WorkerConfig, RedisConfig, - FederationConfig, ] diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index e0939bce84..7597fbc864 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -56,6 +56,8 @@ class OIDCConfig(Config): self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") self.oidc_jwks_uri = oidc_config.get("jwks_uri") self.oidc_skip_verification = oidc_config.get("skip_verification", False) + self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto") + self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False) ump_config = oidc_config.get("user_mapping_provider", {}) ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) @@ -158,6 +160,19 @@ class OIDCConfig(Config): # #skip_verification: true + # Whether to fetch the user profile from the userinfo endpoint. Valid + # values are: "auto" or "userinfo_endpoint". + # + # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included + # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. + # + #user_profile_method: "userinfo_endpoint" + + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead + # of failing. This could be used if switching from password logins to OIDC. Defaults to false. + # + #allow_existing_users: true + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. # @@ -198,6 +213,14 @@ class OIDCConfig(Config): # If unset, no displayname will be set. # #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" + + # Jinja2 templates for extra attributes to send back to the client during + # login. + # + # Note that these are non-standard and clients will ignore them without modifications. + # + #extra_attributes: + #birthdate: "{{{{ user.birthdate }}}}" """.format( mapping_provider=DEFAULT_USER_MAPPING_PROVIDER ) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 5ffbb934fe..d7e3690a32 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -187,6 +187,11 @@ class RegistrationConfig(Config): session_lifetime = self.parse_duration(session_lifetime) self.session_lifetime = session_lifetime + # The success template used during fallback auth. + self.fallback_success_template = self.read_templates( + ["auth_success.html"], autoescape=True + )[0] + def generate_config_section(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( diff --git a/synapse/config/server.py b/synapse/config/server.py index 532b910470..ef6d70e3f8 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -641,10 +641,23 @@ class ServerConfig(Config): """\ ## Server ## - # The domain name of the server, with optional explicit port. - # This is used by remote servers to connect to this server, - # e.g. matrix.org, localhost:8080, etc. - # This is also the last part of your UserID. + # The public-facing domain of the server + # + # The server_name name will appear at the end of usernames and room addresses + # created on this server. For example if the server_name was example.com, + # usernames on this server would be in the format @user:example.com + # + # In most cases you should avoid using a matrix specific subdomain such as + # matrix.example.com or synapse.example.com as the server_name for the same + # reasons you wouldn't use user@email.example.com as your email address. + # See https://github.com/matrix-org/synapse/blob/master/docs/delegate.md + # for information on how to host Synapse on a subdomain while preserving + # a clean server_name. + # + # The server_name cannot be changed later so it is important to + # configure this correctly before you start Synapse. It should be all + # lowercase and may contain an explicit port. + # Examples: matrix.org, localhost:8080 # server_name: "%(server_name)s" diff --git a/synapse/config/tls.py b/synapse/config/tls.py index e368ea564d..ad37b93c02 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -18,7 +18,7 @@ import os import warnings from datetime import datetime from hashlib import sha256 -from typing import List +from typing import List, Optional from unpaddedbase64 import encode_base64 @@ -177,8 +177,8 @@ class TlsConfig(Config): "use_insecure_ssl_client_just_for_testing_do_not_use" ) - self.tls_certificate = None - self.tls_private_key = None + self.tls_certificate = None # type: Optional[crypto.X509] + self.tls_private_key = None # type: Optional[crypto.PKey] def is_disk_cert_valid(self, allow_self_signed=True): """ @@ -226,12 +226,12 @@ class TlsConfig(Config): days_remaining = (expires_on - now).days return days_remaining - def read_certificate_from_disk(self, require_cert_and_key): + def read_certificate_from_disk(self, require_cert_and_key: bool): """ Read the certificates and private key from disk. Args: - require_cert_and_key (bool): set to True to throw an error if the certificate + require_cert_and_key: set to True to throw an error if the certificate and key file are not given """ if require_cert_and_key: @@ -471,7 +471,6 @@ class TlsConfig(Config): # or by checking matrix.org/federationtester/api/report?server_name=$host # #tls_fingerprints: [{"sha256": ""}] - """ # Lowercase the string representation of boolean values % { @@ -480,13 +479,13 @@ class TlsConfig(Config): } ) - def read_tls_certificate(self): + def read_tls_certificate(self) -> crypto.X509: """Reads the TLS certificate from the configured file, and returns it Also checks if it is self-signed, and warns if so Returns: - OpenSSL.crypto.X509: the certificate + The certificate """ cert_path = self.tls_certificate_file logger.info("Loading TLS certificate from %s", cert_path) @@ -505,11 +504,11 @@ class TlsConfig(Config): return cert - def read_tls_private_key(self): + def read_tls_private_key(self) -> crypto.PKey: """Reads the TLS private key from the configured file, and returns it Returns: - OpenSSL.crypto.PKey: the private key + The private key """ private_key_path = self.tls_private_key_file logger.info("Loading TLS key from %s", private_key_path) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index f23e42cdf9..57ab097eba 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -132,6 +132,19 @@ class WorkerConfig(Config): self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) + # Whether this worker should run background tasks or not. + # + # As a note for developers, the background tasks guarded by this should + # be able to run on only a single instance (meaning that they don't + # depend on any in-memory state of a particular worker). + # + # No effort is made to ensure only a single instance of these tasks is + # running. + background_tasks_instance = config.get("run_background_tasks_on") or "master" + self.run_background_tasks = ( + self.worker_name is None and background_tasks_instance == "master" + ) or self.worker_name == background_tasks_instance + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ ## Workers ## @@ -167,6 +180,11 @@ class WorkerConfig(Config): #stream_writers: # events: worker1 # typing: worker1 + + # The worker that is used to run background tasks (e.g. cleaning up expired + # data). If not provided this defaults to the main process. + # + #run_background_tasks_on: worker1 """ def read_arguments(self, args): diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 42e4087a92..c04ad77cf9 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -42,7 +42,6 @@ from synapse.api.errors import ( ) from synapse.logging.context import ( PreserveLoggingContext, - current_context, make_deferred_yieldable, preserve_fn, run_in_background, @@ -233,8 +232,6 @@ class Keyring: """ try: - ctx = current_context() - # map from server name to a set of outstanding request ids server_to_request_ids = {} @@ -265,12 +262,8 @@ class Keyring: # if there are no more requests for this server, we can drop the lock. if not server_requests: - with PreserveLoggingContext(ctx): - logger.debug("Releasing key lookup lock on %s", server_name) - - # ... but not immediately, as that can cause stack explosions if - # we get a long queue of lookups. - self.clock.call_later(0, drop_server_lock, server_name) + logger.debug("Releasing key lookup lock on %s", server_name) + drop_server_lock(server_name) return res @@ -335,20 +328,32 @@ class Keyring: ) # look for any requests which weren't satisfied - with PreserveLoggingContext(): - for verify_request in remaining_requests: - verify_request.key_ready.errback( - SynapseError( - 401, - "No key for %s with ids in %s (min_validity %i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ), - Codes.UNAUTHORIZED, - ) + while remaining_requests: + verify_request = remaining_requests.pop() + rq_str = ( + "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)" + % ( + verify_request.server_name, + verify_request.key_ids, + verify_request.minimum_valid_until_ts, ) + ) + + # If we run the errback immediately, it may cancel our + # loggingcontext while we are still in it, so instead we + # schedule it for the next time round the reactor. + # + # (this also ensures that we don't get a stack overflow if we + # has a massive queue of lookups waiting for this server). + self.clock.call_later( + 0, + verify_request.key_ready.errback, + SynapseError( + 401, + "Failed to find any key to satisfy %s" % (rq_str,), + Codes.UNAUTHORIZED, + ), + ) except Exception as err: # we don't really expect to get here, because any errors should already # have been caught and logged. But if we do, let's log the error and make @@ -410,10 +415,23 @@ class Keyring: # key was not valid at this point continue - with PreserveLoggingContext(): - verify_request.key_ready.callback( - (server_name, key_id, fetch_key_result.verify_key) - ) + # we have a valid key for this request. If we run the callback + # immediately, it may cancel our loggingcontext while we are still in + # it, so instead we schedule it for the next time round the reactor. + # + # (this also ensures that we don't get a stack overflow if we had + # a massive queue of lookups waiting for this server). + logger.debug( + "Found key %s:%s for %s", + server_name, + key_id, + verify_request.request_name, + ) + self.clock.call_later( + 0, + verify_request.key_ready.callback, + (server_name, key_id, fetch_key_result.verify_key), + ) completed.append(verify_request) break diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 8c907ad596..56f8dc9caf 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -446,6 +446,8 @@ def check_redaction( if room_version_obj.event_format == EventFormatVersions.V1: redacter_domain = get_domain_from_id(event.event_id) + if not isinstance(event.redacts, str): + return False redactee_domain = get_domain_from_id(event.redacts) if redacter_domain == redactee_domain: return True diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index bf800a3852..7a51d0a22f 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -23,7 +23,7 @@ from typing import Dict, Optional, Tuple, Type from unpaddedbase64 import encode_base64 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.types import JsonDict +from synapse.types import JsonDict, RoomStreamToken from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -97,13 +97,16 @@ class DefaultDictProperty(DictProperty): class _EventInternalMetadata: - __slots__ = ["_dict"] + __slots__ = ["_dict", "stream_ordering"] def __init__(self, internal_metadata_dict: JsonDict): # we have to copy the dict, because it turns out that the same dict is # reused. TODO: fix that self._dict = dict(internal_metadata_dict) + # the stream ordering of this event. None, until it has been persisted. + self.stream_ordering = None # type: Optional[int] + outlier = DictProperty("outlier") # type: bool out_of_band_membership = DictProperty("out_of_band_membership") # type: bool send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str @@ -113,13 +116,12 @@ class _EventInternalMetadata: redacted = DictProperty("redacted") # type: bool txn_id = DictProperty("txn_id") # type: str token_id = DictProperty("token_id") # type: str - stream_ordering = DictProperty("stream_ordering") # type: int # XXX: These are set by StreamWorkerStore._set_before_and_after. # I'm pretty sure that these are never persisted to the database, so shouldn't # be here - before = DictProperty("before") # type: str - after = DictProperty("after") # type: str + before = DictProperty("before") # type: RoomStreamToken + after = DictProperty("after") # type: RoomStreamToken order = DictProperty("order") # type: Tuple[int, int] def get_dict(self) -> JsonDict: diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 9d5310851c..fed459198a 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -12,10 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import Requester +from synapse.module_api import ModuleApi +from synapse.types import Requester, StateMap class ThirdPartyEventRules: @@ -38,7 +40,7 @@ class ThirdPartyEventRules: if module is not None: self.third_party_rules = module( - config=config, http_client=hs.get_simple_http_client() + config=config, module_api=ModuleApi(hs, hs.get_auth_handler()), ) async def check_event_allowed( @@ -106,6 +108,46 @@ class ThirdPartyEventRules: if self.third_party_rules is None: return True + state_events = await self._get_state_map_for_room(room_id) + + ret = await self.third_party_rules.check_threepid_can_be_invited( + medium, address, state_events + ) + return ret + + async def check_visibility_can_be_modified( + self, room_id: str, new_visibility: str + ) -> bool: + """Check if a room is allowed to be published to, or removed from, the public room + list. + + Args: + room_id: The ID of the room. + new_visibility: The new visibility state. Either "public" or "private". + + Returns: + True if the room's visibility can be modified, False if not. + """ + if self.third_party_rules is None: + return True + + check_func = getattr(self.third_party_rules, "check_visibility_can_be_modified") + if not check_func or not isinstance(check_func, Callable): + return True + + state_events = await self._get_state_map_for_room(room_id) + + return await check_func(room_id, state_events, new_visibility) + + async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: + """Given a room ID, return the state events of that room. + + Args: + room_id: The ID of the room. + + Returns: + A dict mapping (event type, state key) to state event. + """ state_ids = await self.store.get_filtered_current_state_ids(room_id) room_state_events = await self.store.get_events(state_ids.values()) @@ -113,7 +155,4 @@ class ThirdPartyEventRules: for key, event_id in state_ids.items(): state_events[key] = room_state_events[event_id] - ret = await self.third_party_rules.check_threepid_can_be_invited( - medium, address, state_events - ) - return ret + return state_events diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 32c73d3413..355cbe05f1 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -49,6 +49,11 @@ def prune_event(event: EventBase) -> EventBase: pruned_event_dict, event.room_version, event.internal_metadata.get_dict() ) + # copy the internal fields + pruned_event.internal_metadata.stream_ordering = ( + event.internal_metadata.stream_ordering + ) + # Mark the event as redacted pruned_event.internal_metadata.redacted = True diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 688d43fffb..302b2f69bc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -24,10 +24,12 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Sequence, Tuple, TypeVar, + Union, ) from prometheus_client import Counter @@ -501,7 +503,7 @@ class FederationClient(FederationBase): user_id: str, membership: str, content: dict, - params: Dict[str, str], + params: Optional[Mapping[str, Union[str, Iterable[str]]]], ) -> Tuple[str, EventBase, RoomVersion]: """ Creates an m.room.member event, with context, without participating in the room. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2dcd081cbc..02f11e1209 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -22,13 +22,12 @@ from typing import ( Callable, Dict, List, - Match, Optional, Tuple, Union, ) -from prometheus_client import Counter, Histogram +from prometheus_client import Counter, Gauge, Histogram from twisted.internet import defer from twisted.internet.abstract import isIPAddress @@ -88,6 +87,13 @@ pdu_process_time = Histogram( ) +last_pdu_age_metric = Gauge( + "synapse_federation_last_received_pdu_age", + "The age (in seconds) of the last PDU successfully received from the given domain", + labelnames=("server_name",), +) + + class FederationServer(FederationBase): def __init__(self, hs): super().__init__(hs) @@ -118,6 +124,10 @@ class FederationServer(FederationBase): hs, "state_ids_resp", timeout_ms=30000 ) + self._federation_metrics_domains = ( + hs.get_config().federation.federation_metrics_domains + ) + async def on_backfill_request( self, origin: str, room_id: str, versions: List[str], limit: int ) -> Tuple[int, Dict[str, Any]]: @@ -262,7 +272,11 @@ class FederationServer(FederationBase): pdus_by_room = {} # type: Dict[str, List[EventBase]] + newest_pdu_ts = 0 + for p in transaction.pdus: # type: ignore + # FIXME (richardv): I don't think this works: + # https://github.com/matrix-org/synapse/issues/8429 if "unsigned" in p: unsigned = p["unsigned"] if "age" in unsigned: @@ -300,6 +314,9 @@ class FederationServer(FederationBase): event = event_from_pdu_json(p, room_version) pdus_by_room.setdefault(room_id, []).append(event) + if event.origin_server_ts > newest_pdu_ts: + newest_pdu_ts = event.origin_server_ts + pdu_results = {} # we can process different rooms in parallel (which is useful if they @@ -340,6 +357,10 @@ class FederationServer(FederationBase): process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) + if newest_pdu_ts and origin in self._federation_metrics_domains: + newest_pdu_age = self._clock.time_msec() - newest_pdu_ts + last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000) + return pdu_results async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): @@ -803,14 +824,14 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: return False -def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: +def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool: if not isinstance(acl_entry, str): logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) return False regex = glob_to_regex(acl_entry) - return regex.match(server_name) + return bool(regex.match(server_name)) class FederationHandlerRegistry: diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 8bb17b3a05..e33b29a42c 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -297,6 +297,8 @@ class FederationSender: sent_pdus_destination_dist_total.inc(len(destinations)) sent_pdus_destination_dist_count.inc() + assert pdu.internal_metadata.stream_ordering + # track the fact that we have a PDU for these destinations, # to allow us to perform catch-up later on if the remote is unreachable # for a while. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 2657767fd1..db8e456fe8 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -158,6 +158,7 @@ class PerDestinationQueue: # yet know if we have anything to catch up (None) self._pending_pdus.append(pdu) else: + assert pdu.internal_metadata.stream_ordering self._catchup_last_skipped = pdu.internal_metadata.stream_ordering self.attempt_new_transaction() @@ -361,6 +362,7 @@ class PerDestinationQueue: last_successful_stream_ordering = ( final_pdu.internal_metadata.stream_ordering ) + assert last_successful_stream_ordering await self._store.set_destination_last_successful_stream_ordering( self._destination, last_successful_stream_ordering ) @@ -490,7 +492,7 @@ class PerDestinationQueue: ) if logger.isEnabledFor(logging.INFO): - rooms = (p.room_id for p in catchup_pdus) + rooms = [p.room_id for p in catchup_pdus] logger.info("Catching up rooms to %s: %r", self._destination, rooms) success = await self._transaction_manager.send_new_transaction( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index c84072ab73..3e07f925e0 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -15,6 +15,8 @@ import logging from typing import TYPE_CHECKING, List +from prometheus_client import Gauge + from synapse.api.errors import HttpResponseException from synapse.events import EventBase from synapse.federation.persistence import TransactionActions @@ -34,6 +36,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +last_pdu_age_metric = Gauge( + "synapse_federation_last_sent_pdu_age", + "The age (in seconds) of the last PDU successfully sent to the given domain", + labelnames=("server_name",), +) + class TransactionManager: """Helper class which handles building and sending transactions @@ -48,6 +56,10 @@ class TransactionManager: self._transaction_actions = TransactionActions(self._store) self._transport_layer = hs.get_federation_transport_client() + self._federation_metrics_domains = ( + hs.get_config().federation.federation_metrics_domains + ) + # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) @@ -119,6 +131,9 @@ class TransactionManager: # FIXME (erikj): This is a bit of a hack to make the Pdu age # keys work + # FIXME (richardv): I also believe it no longer works. We (now?) store + # "age_ts" in "unsigned" rather than at the top level. See + # https://github.com/matrix-org/synapse/issues/8429. def json_data_cb(): data = transaction.get_dict() now = int(self.clock.time_msec()) @@ -167,5 +182,12 @@ class TransactionManager: ) success = False + if success and pdus and destination in self._federation_metrics_domains: + last_pdu = pdus[-1] + last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts + last_pdu_age_metric.labels(server_name=destination).set( + last_pdu_age / 1000 + ) + set_tag(tags.ERROR, not success) return success diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index dd981c597e..1ce2091b46 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -153,7 +153,7 @@ class AdminHandler(BaseHandler): if not events: break - from_key = RoomStreamToken.parse(events[-1].internal_metadata.after) + from_key = events[-1].internal_metadata.after events = await filter_events_for_client(self.storage, user_id, events) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0322b60cfc..7c4b716b28 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -137,6 +137,15 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]: } +@attr.s(slots=True) +class SsoLoginExtraAttributes: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib(type=int) + extra_attributes = attr.ib(type=JsonDict) + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 @@ -203,7 +212,7 @@ class AuthHandler(BaseHandler): self._clock = self.hs.get_clock() # Expire old UI auth sessions after a period of time. - if hs.config.worker_app is None: + if hs.config.run_background_tasks: self._clock.looping_call( run_as_background_process, 5 * 60 * 1000, @@ -239,6 +248,10 @@ class AuthHandler(BaseHandler): # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) + # A mapping of user ID to extra attributes to include in the login + # response. + self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes] + async def validate_user_via_ui_auth( self, requester: Requester, @@ -1165,6 +1178,7 @@ class AuthHandler(BaseHandler): registered_user_id: str, request: SynapseRequest, client_redirect_url: str, + extra_attributes: Optional[JsonDict] = None, ): """Having figured out a mxid for this user, complete the HTTP request @@ -1173,6 +1187,8 @@ class AuthHandler(BaseHandler): request: The request to complete. client_redirect_url: The URL to which to redirect the user at the end of the process. + extra_attributes: Extra attributes which will be passed to the client + during successful login. Must be JSON serializable. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1181,19 +1197,30 @@ class AuthHandler(BaseHandler): respond_with_html(request, 403, self._sso_account_deactivated_template) return - self._complete_sso_login(registered_user_id, request, client_redirect_url) + self._complete_sso_login( + registered_user_id, request, client_redirect_url, extra_attributes + ) def _complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str, + extra_attributes: Optional[JsonDict] = None, ): """ The synchronous portion of complete_sso_login. This exists purely for backwards compatibility of synapse.module_api.ModuleApi. """ + # Store any extra attributes which will be passed in the login response. + # Note that this is per-user so it may overwrite a previous value, this + # is considered OK since the newest SSO attributes should be most valid. + if extra_attributes: + self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes( + self._clock.time_msec(), extra_attributes, + ) + # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( registered_user_id @@ -1226,6 +1253,37 @@ class AuthHandler(BaseHandler): ) respond_with_html(request, 200, html) + async def _sso_login_callback(self, login_result: JsonDict) -> None: + """ + A login callback which might add additional attributes to the login response. + + Args: + login_result: The data to be sent to the client. Includes the user + ID and access token. + """ + # Expire attributes before processing. Note that there shouldn't be any + # valid logins that still have extra attributes. + self._expire_sso_extra_attributes() + + extra_attributes = self._extra_attributes.get(login_result["user_id"]) + if extra_attributes: + login_result.update(extra_attributes.extra_attributes) + + def _expire_sso_extra_attributes(self) -> None: + """ + Iterate through the mapping of user IDs to extra attributes and remove any that are no longer valid. + """ + # TODO This should match the amount of time the macaroon is valid for. + LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000 + expire_before = self._clock.time_msec() - LOGIN_TOKEN_EXPIRATION_TIME + to_expire = set() + for user_id, data in self._extra_attributes.items(): + if data.creation_time < expire_before: + to_expire.add(user_id) + for user_id in to_expire: + logger.debug("Expiring extra attributes for user %s", user_id) + del self._extra_attributes[user_id] + @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): url_parts = list(urllib.parse.urlparse(url)) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index e97fbeed6c..028fc0cd6f 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -40,7 +40,6 @@ from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( JsonDict, - RoomStreamToken, StreamToken, UserID, get_domain_from_id, @@ -130,8 +129,7 @@ class DeviceWorkerHandler(BaseHandler): set_tag("user_id", user_id) set_tag("from_token", from_token) - now_room_id = self.store.get_room_max_stream_ordering() - now_room_key = RoomStreamToken(None, now_room_id) + now_room_key = self.store.get_room_max_token() room_ids = await self.store.get_rooms_for_user(user_id) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 62aa9a2da8..ad5683d251 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -46,6 +46,7 @@ class DirectoryHandler(BaseHandler): self.config = hs.config self.enable_room_list_search = hs.config.enable_room_list_search self.require_membership = hs.config.require_membership_for_aliases + self.third_party_event_rules = hs.get_third_party_event_rules() self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -383,7 +384,7 @@ class DirectoryHandler(BaseHandler): """ creator = await self.store.get_room_alias_creator(alias.to_string()) - if creator is not None and creator == user_id: + if creator == user_id: return True # Resolve the alias to the corresponding room. @@ -454,6 +455,15 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to publish room") + # Check if publishing is blocked by a third party module + allowed_by_third_party_rules = await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) + ) + if not allowed_by_third_party_rules: + raise SynapseError(403, "Not allowed to publish room") + await self.store.set_room_is_public(room_id, making_public) async def edit_published_appservice_room_list( diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 0875b74ea8..539b4fc32e 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -133,8 +133,8 @@ class EventStreamHandler(BaseHandler): chunk = { "chunk": chunks, - "start": tokens[0].to_string(), - "end": tokens[1].to_string(), + "start": await tokens[0].to_string(self.store), + "end": await tokens[1].to_string(self.store), } return chunk diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 9f773aefa7..5ac2fc5656 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -21,7 +21,7 @@ import itertools import logging from collections.abc import Container from http import HTTPStatus -from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union import attr from signedjson.key import decode_verify_key_bytes @@ -69,7 +69,7 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ReplicationStoreRoomOnInviteRestServlet, ) -from synapse.state import StateResolutionStore, resolve_events_with_store +from synapse.state import StateResolutionStore from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( JsonDict, @@ -85,6 +85,9 @@ from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -116,7 +119,7 @@ class FederationHandler(BaseHandler): rooms. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs @@ -126,6 +129,7 @@ class FederationHandler(BaseHandler): self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() + self._state_resolution_handler = hs.get_state_resolution_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() self.action_generator = hs.get_action_generator() @@ -155,8 +159,9 @@ class FederationHandler(BaseHandler): self._device_list_updater = hs.get_device_handler().device_list_updater self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite - # When joining a room we need to queue any events for that room up - self.room_queues = {} + # When joining a room we need to queue any events for that room up. + # For each room, a list of (pdu, origin) tuples. + self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]] self._room_pdu_linearizer = Linearizer("fed_room_pdu") self.third_party_event_rules = hs.get_third_party_event_rules() @@ -281,7 +286,7 @@ class FederationHandler(BaseHandler): raise Exception( "Error fetching missing prev_events for %s: %s" % (event_id, e) - ) + ) from e # Update the set of things we've seen after trying to # fetch the missing stuff @@ -380,8 +385,7 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x room_version = await self.store.get_room_version_id(room_id) - state_map = await resolve_events_with_store( - self.clock, + state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, room_version, state_maps, @@ -814,6 +818,9 @@ class FederationHandler(BaseHandler): dest, room_id, limit=limit, extremities=extremities ) + if not events: + return [] + # ideally we'd sanity check the events here for excess prev_events etc, # but it's hard to reject events at this point without completely # breaking backfill in the same way that it is currently broken by @@ -2164,10 +2171,10 @@ class FederationHandler(BaseHandler): # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets = await self.state_store.get_state_groups( + state_sets_d = await self.state_store.get_state_groups( event.room_id, extrem_ids ) - state_sets = list(state_sets.values()) + state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]] state_sets.append(state) current_states = await self.state_handler.resolve_events( room_version, state_sets, event @@ -2958,6 +2965,7 @@ class FederationHandler(BaseHandler): ) return result["max_stream_id"] else: + assert self.storage.persistence max_stream_token = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) @@ -3000,6 +3008,9 @@ class FederationHandler(BaseHandler): elif event.internal_metadata.is_outlier(): return + # the event has been persisted so it should have a stream ordering. + assert event.internal_metadata.stream_ordering + event_pos = PersistedEventPosition( self._instance_name, event.internal_metadata.stream_ordering ) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index ab15570f7a..bc3e9607ca 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -21,8 +21,6 @@ import logging import urllib.parse from typing import Awaitable, Callable, Dict, List, Optional, Tuple -from twisted.internet.error import TimeoutError - from synapse.api.errors import ( CodeMessageException, Codes, @@ -30,6 +28,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.config.emailconfig import ThreepidBehaviour +from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict, Requester from synapse.util import json_decoder @@ -93,7 +92,7 @@ class IdentityHandler(BaseHandler): try: data = await self.http_client.get_json(url, query_params) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.info( @@ -173,7 +172,7 @@ class IdentityHandler(BaseHandler): if e.code != 404 or not use_v2: logger.error("3PID bind failed with Matrix error: %r", e) raise e.to_synapse_error() - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except CodeMessageException as e: data = json_decoder.decode(e.msg) # XXX WAT? @@ -273,7 +272,7 @@ class IdentityHandler(BaseHandler): else: logger.error("Failed to unbind threepid on identity server: %s", e) raise SynapseError(500, "Failed to contact identity server") - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") await self.store.remove_user_bound_threepid( @@ -419,7 +418,7 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) raise e.to_synapse_error() - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") async def requestMsisdnToken( @@ -471,7 +470,7 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) raise e.to_synapse_error() - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") assert self.hs.config.public_baseurl @@ -553,7 +552,7 @@ class IdentityHandler(BaseHandler): id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken", body, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) @@ -627,7 +626,7 @@ class IdentityHandler(BaseHandler): # require or validate it. See the following for context: # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950 return data["mxid"] - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except IOError as e: logger.warning("Error from v1 identity server lookup: %s" % (e,)) @@ -655,7 +654,7 @@ class IdentityHandler(BaseHandler): "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server), {"access_token": id_access_token}, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") if not isinstance(hash_details, dict): @@ -727,7 +726,7 @@ class IdentityHandler(BaseHandler): }, headers=headers, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except Exception as e: logger.warning("Error when performing a v2 3pid lookup: %s", e) @@ -823,7 +822,7 @@ class IdentityHandler(BaseHandler): invite_config, {"Authorization": create_id_access_token_header(id_access_token)}, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: if e.code != 404: @@ -841,7 +840,7 @@ class IdentityHandler(BaseHandler): data = await self.blacklisting_http_client.post_json_get_json( url, invite_config ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.warning( diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 8cd7eb22a3..39a85801c1 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -203,8 +203,8 @@ class InitialSyncHandler(BaseHandler): messages, time_now=time_now, as_client_event=as_client_event ) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), } d["state"] = await self._event_serializer.serialize_events( @@ -249,7 +249,7 @@ class InitialSyncHandler(BaseHandler): ], "account_data": account_data_events, "receipts": receipt, - "end": now_token.to_string(), + "end": await now_token.to_string(self.store), } return ret @@ -325,7 +325,8 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = await self.store.get_stream_token_for_event(member_event_id) + leave_position = await self.store.get_position_for_event(member_event_id) + stream_token = leave_position.to_room_stream_token() messages, token = await self.store.get_recent_events_for_room( room_id, limit=limit, end_token=stream_token @@ -347,8 +348,8 @@ class InitialSyncHandler(BaseHandler): "chunk": ( await self._event_serializer.serialize_events(messages, time_now) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), }, "state": ( await self._event_serializer.serialize_events( @@ -446,8 +447,8 @@ class InitialSyncHandler(BaseHandler): "chunk": ( await self._event_serializer.serialize_events(messages, time_now) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), }, "state": state, "presence": presence, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ee271e85e5..00513fbf37 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -682,7 +682,9 @@ class EventCreationHandler: event.event_id, prev_event.event_id, ) - return await self.store.get_stream_id_for_event(prev_event.event_id) + # we know it was persisted, so must have a stream ordering + assert prev_event.internal_metadata.stream_ordering + return prev_event.internal_metadata.stream_ordering return await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 4230dbaf99..05ac86e697 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -37,7 +37,7 @@ from synapse.config import ConfigError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import json_decoder if TYPE_CHECKING: @@ -96,6 +96,7 @@ class OidcHandler: self.hs = hs self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] + self._user_profile_method = hs.config.oidc_user_profile_method # type: str self._client_auth = ClientAuth( hs.config.oidc_client_id, hs.config.oidc_client_secret, @@ -114,6 +115,7 @@ class OidcHandler: hs.config.oidc_user_mapping_provider_config ) # type: OidcMappingProvider self._skip_verification = hs.config.oidc_skip_verification # type: bool + self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._http_client = hs.get_proxied_http_client() self._auth_handler = hs.get_auth_handler() @@ -195,11 +197,11 @@ class OidcHandler: % (m["response_types_supported"],) ) - # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos + # Ensure there's a userinfo endpoint to fetch from if it is required. if self._uses_userinfo: if m.get("userinfo_endpoint") is None: raise ValueError( - 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested' + 'provider has no "userinfo_endpoint", even though it is required' ) else: # If we're not using userinfo, we need a valid jwks to validate the ID token @@ -219,8 +221,10 @@ class OidcHandler: ``access_token`` with the ``userinfo_endpoint``. """ - # Maybe that should be user-configurable and not inferred? - return "openid" not in self._scopes + return ( + "openid" not in self._scopes + or self._user_profile_method == "userinfo_endpoint" + ) async def load_metadata(self) -> OpenIDProviderMetadata: """Load and validate the provider metadata. @@ -706,6 +710,15 @@ class OidcHandler: self._render_error(request, "mapping_error", str(e)) return + # Mapping providers might not have get_extra_attributes: only call this + # method if it exists. + extra_attributes = None + get_extra_attributes = getattr( + self._user_mapping_provider, "get_extra_attributes", None + ) + if get_extra_attributes: + extra_attributes = await get_extra_attributes(userinfo, token) + # and finally complete the login if ui_auth_session_id: await self._auth_handler.complete_sso_ui_auth( @@ -713,7 +726,7 @@ class OidcHandler: ) else: await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url + user_id, request, client_redirect_url, extra_attributes ) def _generate_oidc_session_token( @@ -849,7 +862,8 @@ class OidcHandler: If we don't find the user that way, we should register the user, mapping the localpart and the display name from the UserInfo. - If a user already exists with the mxid we've mapped, raise an exception. + If a user already exists with the mxid we've mapped and allow_existing_users + is disabled, raise an exception. Args: userinfo: an object representing the user @@ -905,21 +919,31 @@ class OidcHandler: localpart = map_username_to_mxid_localpart(attributes["localpart"]) - user_id = UserID(localpart, self._hostname) - if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()): - # This mxid is taken - raise MappingException( - "mxid '{}' is already taken".format(user_id.to_string()) + user_id = UserID(localpart, self._hostname).to_string() + users = await self._datastore.get_users_by_id_case_insensitive(user_id) + if users: + if self._allow_existing_users: + if len(users) == 1: + registered_user_id = next(iter(users)) + elif user_id in users: + registered_user_id = user_id + else: + raise MappingException( + "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( + user_id, list(users.keys()) + ) + ) + else: + # This mxid is taken + raise MappingException("mxid '{}' is already taken".format(user_id)) + else: + # It's the first time this user is logging in and the mapped mxid was + # not taken, register the user + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) - - # It's the first time this user is logging in and the mapped mxid was - # not taken, register the user - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=attributes["display_name"], - user_agent_ips=(user_agent, ip_address), - ) - await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id, ) @@ -972,7 +996,7 @@ class OidcMappingProvider(Generic[C]): async def map_user_attributes( self, userinfo: UserInfo, token: Token ) -> UserAttribute: - """Map a ``UserInfo`` objects into user attributes. + """Map a `UserInfo` object into user attributes. Args: userinfo: An object representing the user given by the OIDC provider @@ -983,6 +1007,18 @@ class OidcMappingProvider(Generic[C]): """ raise NotImplementedError() + async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: + """Map a `UserInfo` object into additional attributes passed to the client during login. + + Args: + userinfo: An object representing the user given by the OIDC provider + token: A dict with the tokens returned by the provider + + Returns: + A dict containing additional attributes. Must be JSON serializable. + """ + return {} + # Used to clear out "None" values in templates def jinja_finalize(thing): @@ -997,6 +1033,7 @@ class JinjaOidcMappingConfig: subject_claim = attr.ib() # type: str localpart_template = attr.ib() # type: Template display_name_template = attr.ib() # type: Optional[Template] + extra_attributes = attr.ib() # type: Dict[str, Template] class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @@ -1035,10 +1072,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): % (e,) ) + extra_attributes = {} # type Dict[str, Template] + if "extra_attributes" in config: + extra_attributes_config = config.get("extra_attributes") or {} + if not isinstance(extra_attributes_config, dict): + raise ConfigError( + "oidc_config.user_mapping_provider.config.extra_attributes must be a dict" + ) + + for key, value in extra_attributes_config.items(): + try: + extra_attributes[key] = env.from_string(value) + except Exception as e: + raise ConfigError( + "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r" + % (key, e) + ) + return JinjaOidcMappingConfig( subject_claim=subject_claim, localpart_template=localpart_template, display_name_template=display_name_template, + extra_attributes=extra_attributes, ) def get_remote_user_id(self, userinfo: UserInfo) -> str: @@ -1059,3 +1114,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): display_name = None return UserAttribute(localpart=localpart, display_name=display_name) + + async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: + extras = {} # type: Dict[str, str] + for key, template in self._config.extra_attributes.items(): + try: + extras[key] = template.render(user=userinfo).strip() + except Exception as e: + # Log an error and skip this value (don't break login for this). + logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e)) + return extras diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a0b3bdb5e0..2c2a633938 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -25,7 +25,7 @@ from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import Requester, RoomStreamToken +from synapse.types import Requester from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client @@ -373,10 +373,9 @@ class PaginationHandler: # case "JOIN" would have been returned. assert member_event_id - leave_token_str = await self.store.get_topological_token_for_event( + leave_token = await self.store.get_topological_token_for_event( member_event_id ) - leave_token = RoomStreamToken.parse(leave_token_str) assert leave_token.topological is not None if leave_token.topological < curr_topo: @@ -414,8 +413,8 @@ class PaginationHandler: if not events: return { "chunk": [], - "start": from_token.to_string(), - "end": next_token.to_string(), + "start": await from_token.to_string(self.store), + "end": await next_token.to_string(self.store), } state = None @@ -443,8 +442,8 @@ class PaginationHandler: events, time_now, as_client_event=as_client_event ) ), - "start": from_token.to_string(), - "end": next_token.to_string(), + "start": await from_token.to_string(self.store), + "end": await next_token.to_string(self.store), } if state: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 11bf146bed..f14f791586 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -681,6 +681,15 @@ class RoomCreationHandler(BaseHandler): creator_id=user_id, is_public=is_public, room_version=room_version, ) + # Check whether this visibility value is blocked by a third party module + allowed_by_third_party_rules = await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) + ) + if not allowed_by_third_party_rules: + raise SynapseError(403, "Room visibility value not allowed.") + directory_handler = self.hs.get_handlers().directory_handler if room_alias: await directory_handler.create_association( @@ -962,8 +971,6 @@ class RoomCreationHandler(BaseHandler): try: random_string = stringutils.random_string(18) gen_room_id = RoomID(random_string, self.hs.hostname).to_string() - if isinstance(gen_room_id, bytes): - gen_room_id = gen_room_id.decode("utf-8") await self.store.store_room( room_id=gen_room_id, room_creator_user_id=creator_id, @@ -1077,11 +1084,13 @@ class RoomContextHandler: # the token, which we replace. token = StreamToken.START - results["start"] = token.copy_and_replace( + results["start"] = await token.copy_and_replace( "room_key", results["start"] - ).to_string() + ).to_string(self.store) - results["end"] = token.copy_and_replace("room_key", results["end"]).to_string() + results["end"] = await token.copy_and_replace( + "room_key", results["end"] + ).to_string(self.store) return results @@ -1134,14 +1143,14 @@ class RoomEventSource: events[:] = events[:limit] if events: - end_key = RoomStreamToken.parse(events[-1].internal_metadata.after) + end_key = events[-1].internal_metadata.after else: end_key = to_key return (events, end_key) def get_current_key(self) -> RoomStreamToken: - return RoomStreamToken(None, self.store.get_room_max_stream_ordering()) + return self.store.get_room_max_token() def get_current_key_for_room(self, room_id: str) -> Awaitable[str]: return self.store.get_room_events_max_id(room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8feba8c90a..13b749b7cb 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union from unpaddedbase64 import encode_base64 from synapse import types -from synapse.api.constants import MAX_DEPTH, EventTypes, Membership +from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -194,8 +194,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) if duplicate is not None: # Discard the new event since this membership change is a no-op. - _, stream_id = await self.store.get_event_ordering(duplicate.event_id) - return duplicate.event_id, stream_id + # we know it was persisted, so must have a stream ordering. + assert duplicate.internal_metadata.stream_ordering + return duplicate.event_id, duplicate.internal_metadata.stream_ordering prev_state_ids = await context.get_prev_state_ids() @@ -247,7 +248,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): user_account_data, _ = await self.store.get_account_data_for_user(user_id) # Copy direct message state if applicable - direct_rooms = user_account_data.get("m.direct", {}) + direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {}) # Check which key this room is under if isinstance(direct_rooms, dict): @@ -258,7 +259,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Save back to user's m.direct account data await self.store.add_account_data_for_user( - user_id, "m.direct", direct_rooms + user_id, AccountDataTypes.DIRECT, direct_rooms ) break @@ -441,12 +442,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): same_membership = old_membership == effective_membership_state same_sender = requester.user.to_string() == old_state.sender if same_sender and same_membership and same_content: - _, stream_id = await self.store.get_event_ordering( - old_state.event_id - ) + # duplicate event. + # we know it was persisted, so must have a stream ordering. + assert old_state.internal_metadata.stream_ordering return ( old_state.event_id, - stream_id, + old_state.internal_metadata.stream_ordering, ) if old_membership in ["ban", "leave"] and action == "kick": @@ -642,7 +643,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def send_membership_event( self, - requester: Requester, + requester: Optional[Requester], event: EventBase, context: EventContext, ratelimit: bool = True, diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 6a76c20d79..e9402e6e2e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -362,13 +362,13 @@ class SearchHandler(BaseHandler): self.storage, user.to_string(), res["events_after"] ) - res["start"] = now_token.copy_and_replace( + res["start"] = await now_token.copy_and_replace( "room_key", res["start"] - ).to_string() + ).to_string(self.store) - res["end"] = now_token.copy_and_replace( + res["end"] = await now_token.copy_and_replace( "room_key", res["end"] - ).to_string() + ).to_string(self.store) if include_profile: senders = { diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 249ffe2a55..dc62b21c06 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -49,7 +49,7 @@ class StatsHandler: # Guard to ensure we only process deltas one at a time self._is_processing = False - if hs.config.stats_enabled: + if self.stats_enabled and hs.config.run_background_tasks: self.notifier.add_replication_callback(self.notify_new_event) # We kick this off so that we don't have to wait for a change before diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e948efef2e..a998e6b7f6 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup import attr from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.events import EventBase from synapse.logging.context import current_context @@ -87,7 +87,7 @@ class SyncConfig: class TimelineBatch: prev_batch = attr.ib(type=StreamToken) events = attr.ib(type=List[EventBase]) - limited = attr.ib(bool) + limited = attr.ib(type=bool) def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -519,7 +519,7 @@ class SyncHandler: if len(recents) > timeline_limit: limited = True recents = recents[-timeline_limit:] - room_key = RoomStreamToken.parse(recents[0].internal_metadata.before) + room_key = recents[0].internal_metadata.before prev_batch_token = now_token.copy_and_replace("room_key", room_key) @@ -1378,13 +1378,16 @@ class SyncHandler: return set(), set(), set(), set() ignored_account_data = await self.store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id=user_id + AccountDataTypes.IGNORED_USER_LIST, user_id=user_id ) + # If there is ignored users account data and it matches the proper type, + # then use it. + ignored_users = frozenset() # type: FrozenSet[str] if ignored_account_data: - ignored_users = ignored_account_data.get("ignored_users", {}).keys() - else: - ignored_users = frozenset() + ignored_users_data = ignored_account_data.get("ignored_users", {}) + if isinstance(ignored_users_data, dict): + ignored_users = frozenset(ignored_users_data.keys()) if since_token: room_changes = await self._get_rooms_changed( @@ -1478,7 +1481,7 @@ class SyncHandler: return False async def _get_rooms_changed( - self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str] + self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: """Gets the the changes that have happened since the last sync. """ @@ -1595,16 +1598,24 @@ class SyncHandler: if leave_events: leave_event = leave_events[-1] - leave_stream_token = await self.store.get_stream_token_for_event( + leave_position = await self.store.get_position_for_event( leave_event.event_id ) - leave_token = since_token.copy_and_replace( - "room_key", leave_stream_token - ) - if since_token and since_token.is_after(leave_token): + # If the leave event happened before the since token then we + # bail. + if since_token and not leave_position.persisted_after( + since_token.room_key + ): continue + # We can safely convert the position of the leave event into a + # stream token as it'll only be used in the context of this + # room. (c.f. the docstring of `to_room_stream_token`). + leave_token = since_token.copy_and_replace( + "room_key", leave_position.to_room_stream_token() + ) + # If this is an out of band message, like a remote invite # rejection, we include it in the recents batch. Otherwise, we # let _load_filtered_recents handle fetching the correct @@ -1682,7 +1693,7 @@ class SyncHandler: return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms) async def _get_all_rooms( - self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str] + self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: """Returns entries for all rooms for the user. @@ -1756,7 +1767,7 @@ class SyncHandler: async def _generate_room_entry( self, sync_result_builder: "SyncResultBuilder", - ignored_users: Set[str], + ignored_users: FrozenSet[str], room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], tags: Optional[Dict[str, Dict[str, Any]]], diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 8eb3638591..59b01b812c 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -16,8 +16,6 @@ import re from twisted.internet import task -from twisted.internet.defer import CancelledError -from twisted.python import failure from twisted.web.client import FileBodyProducer from synapse.api.errors import SynapseError @@ -26,19 +24,8 @@ from synapse.api.errors import SynapseError class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" - def __init__(self): - super().__init__(504, "Timed out") - - -def cancelled_to_request_timed_out_error(value, timeout): - """Turns CancelledErrors into RequestTimedOutErrors. - - For use with async.add_timeout_to_deferred - """ - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise RequestTimedOutError() - return value + def __init__(self, msg): + super().__init__(504, msg) ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") diff --git a/synapse/http/client.py b/synapse/http/client.py index 4694adc400..8324632cb6 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import urllib from io import BytesIO @@ -38,7 +37,7 @@ from zope.interface import implementer, provider from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE -from twisted.internet import defer, protocol, ssl +from twisted.internet import defer, error as twisted_error, protocol, ssl from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IResolutionReceiver, @@ -46,17 +45,18 @@ from twisted.internet.interfaces import ( from twisted.internet.task import Cooperator from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, HTTPConnectionPool, readBody +from twisted.web.client import ( + Agent, + HTTPConnectionPool, + ResponseNeverReceived, + readBody, +) from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers from twisted.web.iweb import IResponse from synapse.api.errors import Codes, HttpResponseException, SynapseError -from synapse.http import ( - QuieterFileBodyProducer, - cancelled_to_request_timed_out_error, - redact_uri, -) +from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags @@ -332,8 +332,6 @@ class SimpleHttpClient: RequestTimedOutError if the request times out before the headers are read """ - # A small wrapper around self.agent.request() so we can easily attach - # counters to it outgoing_requests_counter.labels(method).inc() # log request but strip `access_token` (AS requests for example include this) @@ -362,15 +360,17 @@ class SimpleHttpClient: data=body_producer, headers=headers, **self._extra_treq_args - ) + ) # type: defer.Deferred + # we use our own timeout mechanism rather than treq's as a workaround # for https://twistedmatrix.com/trac/ticket/9534. request_deferred = timeout_deferred( - request_deferred, - 60, - self.hs.get_reactor(), - cancelled_to_request_timed_out_error, + request_deferred, 60, self.hs.get_reactor(), ) + + # turn timeouts into RequestTimedOutErrors + request_deferred.addErrback(_timeout_to_request_timed_out_error) + response = await make_deferred_yieldable(request_deferred) incoming_responses_counter.labels(method, response.code).inc() @@ -410,7 +410,7 @@ class SimpleHttpClient: parsed json Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -461,7 +461,7 @@ class SimpleHttpClient: parsed json Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -506,7 +506,7 @@ class SimpleHttpClient: Returns: Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON. Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -538,7 +538,7 @@ class SimpleHttpClient: Returns: Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON. Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -586,7 +586,7 @@ class SimpleHttpClient: Succeeds when we get a 2xx HTTP response, with the HTTP body as bytes. Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -631,7 +631,7 @@ class SimpleHttpClient: headers, absolute URI of the response and HTTP response code. Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -684,6 +684,18 @@ class SimpleHttpClient: ) +def _timeout_to_request_timed_out_error(f: Failure): + if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError): + # The TCP connection has its own timeout (set by the 'connectTimeout' param + # on the Agent), which raises twisted_error.TimeoutError exception. + raise RequestTimedOutError("Timeout connecting to remote server") + elif f.check(defer.TimeoutError, ResponseNeverReceived): + # this one means that we hit our overall timeout on the request + raise RequestTimedOutError("Timeout waiting for response from remote server") + + return f + + # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3c86cbc546..c23a4d7c0c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -171,7 +171,7 @@ async def _handle_json_response( d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) body = await make_deferred_yieldable(d) - except TimeoutError as e: + except defer.TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response - %s %s", request.txn_id, @@ -473,8 +473,6 @@ class MatrixFederationHttpClient: ) response = await request_deferred - except TimeoutError as e: - raise RequestSendFailed(e, can_retry=True) from e except DNSLookupError as e: raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e except Exception as e: @@ -657,10 +655,14 @@ class MatrixFederationHttpClient: long_retries (bool): whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. backoff_on_404 (bool): True if we should count a 404 response as @@ -706,8 +708,13 @@ class MatrixFederationHttpClient: timeout=timeout, ) + if timeout is not None: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + body = await _handle_json_response( - self.reactor, self.default_timeout, request, response, start_ms + self.reactor, _sec_timeout, request, response, start_ms ) return body @@ -736,10 +743,14 @@ class MatrixFederationHttpClient: long_retries (bool): whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. @@ -803,10 +814,14 @@ class MatrixFederationHttpClient: args (dict|None): A dictionary used to create query strings, defaults to None. - timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. @@ -842,8 +857,13 @@ class MatrixFederationHttpClient: timeout=timeout, ) + if timeout is not None: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + body = await _handle_json_response( - self.reactor, self.default_timeout, request, response, start_ms + self.reactor, _sec_timeout, request, response, start_ms ) return body @@ -867,10 +887,14 @@ class MatrixFederationHttpClient: long_retries (bool): whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. @@ -902,8 +926,13 @@ class MatrixFederationHttpClient: ignore_backoff=ignore_backoff, ) + if timeout is not None: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + body = await _handle_json_response( - self.reactor, self.default_timeout, request, response, start_ms + self.reactor, _sec_timeout, request, response, start_ms ) return body diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 332da02a8d..e32d3f43e0 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -44,8 +44,11 @@ class ProxyAgent(_AgentBase): `BrowserLikePolicyForHTTPS`, so unless you have special requirements you can leave this as-is. - connectTimeout (float): The amount of time that this Agent will wait - for the peer to accept a connection. + connectTimeout (Optional[float]): The amount of time that this Agent will wait + for the peer to accept a connection, in seconds. If 'None', + HostnameEndpoint's default (30s) will be used. + + This is used for connections to both proxies and destination servers. bindAddress (bytes): The local address for client sockets to bind to. @@ -108,6 +111,15 @@ class ProxyAgent(_AgentBase): Returns: Deferred[IResponse]: completes when the header of the response has been received (regardless of the response status code). + + Can fail with: + SchemeNotSupported: if the uri is not http or https + + twisted.internet.error.TimeoutError if the server we are connecting + to (proxy or destination) does not accept a connection before + connectTimeout. + + ... other things too. """ uri = uri.strip() if not _VALID_URI.match(uri): diff --git a/synapse/http/server.py b/synapse/http/server.py index 996a31a9ec..09ed74f6ce 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -257,7 +257,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): callback_return = await raw_callback_return else: - callback_return = raw_callback_return + callback_return = raw_callback_return # type: ignore return callback_return @@ -406,7 +406,7 @@ class JsonResource(DirectServeJsonResource): if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): callback_return = await raw_callback_return else: - callback_return = raw_callback_return + callback_return = raw_callback_return # type: ignore return callback_return diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 144506c8f2..0fc2ea609e 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import os.path import sys @@ -89,14 +88,7 @@ class LogContextObserver: context = current_context() # Copy the context information to the log event. - if context is not None: - context.copy_to_twisted_log_entry(event) - else: - # If there's no logging context, not even the root one, we might be - # starting up or it might be from non-Synapse code. Log it as if it - # came from the root logger. - event["request"] = None - event["scope"] = None + context.copy_to_twisted_log_entry(event) self.observer(event) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 2e282d9d67..ca0c774cc5 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -65,6 +65,11 @@ except Exception: return None +# a hook which can be set during testing to assert that we aren't abusing logcontexts. +def logcontext_error(msg: str): + logger.warning(msg) + + # get an id for the current thread. # # threading.get_ident doesn't actually return an OS-level tid, and annoyingly, @@ -330,10 +335,9 @@ class LoggingContext: """Enters this logging context into thread local storage""" old_context = set_current_context(self) if self.previous_context != old_context: - logger.warning( - "Expected previous context %r, found %r", - self.previous_context, - old_context, + logcontext_error( + "Expected previous context %r, found %r" + % (self.previous_context, old_context,) ) return self @@ -346,10 +350,10 @@ class LoggingContext: current = set_current_context(self.previous_context) if current is not self: if current is SENTINEL_CONTEXT: - logger.warning("Expected logging context %s was lost", self) + logcontext_error("Expected logging context %s was lost" % (self,)) else: - logger.warning( - "Expected logging context %s but found %s", self, current + logcontext_error( + "Expected logging context %s but found %s" % (self, current) ) # the fact that we are here suggests that the caller thinks that everything @@ -387,16 +391,16 @@ class LoggingContext: support getrusuage. """ if get_thread_id() != self.main_thread: - logger.warning("Started logcontext %s on different thread", self) + logcontext_error("Started logcontext %s on different thread" % (self,)) return if self.finished: - logger.warning("Re-starting finished log context %s", self) + logcontext_error("Re-starting finished log context %s" % (self,)) # If we haven't already started record the thread resource usage so # far if self.usage_start: - logger.warning("Re-starting already-active log context %s", self) + logcontext_error("Re-starting already-active log context %s" % (self,)) else: self.usage_start = rusage @@ -414,7 +418,7 @@ class LoggingContext: try: if get_thread_id() != self.main_thread: - logger.warning("Stopped logcontext %s on different thread", self) + logcontext_error("Stopped logcontext %s on different thread" % (self,)) return if not rusage: @@ -422,9 +426,9 @@ class LoggingContext: # Record the cpu used since we started if not self.usage_start: - logger.warning( - "Called stop on logcontext %s without recording a start rusage", - self, + logcontext_error( + "Called stop on logcontext %s without recording a start rusage" + % (self,) ) return @@ -584,14 +588,13 @@ class PreserveLoggingContext: if context != self._new_context: if not context: - logger.warning( - "Expected logging context %s was lost", self._new_context + logcontext_error( + "Expected logging context %s was lost" % (self._new_context,) ) else: - logger.warning( - "Expected logging context %s but found %s", - self._new_context, - context, + logcontext_error( + "Expected logging context %s but found %s" + % (self._new_context, context,) ) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index a1f7ca3449..b8d2a8e8a9 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -15,6 +15,7 @@ import functools import gc +import itertools import logging import os import platform @@ -27,8 +28,8 @@ from prometheus_client import Counter, Gauge, Histogram from prometheus_client.core import ( REGISTRY, CounterMetricFamily, + GaugeHistogramMetricFamily, GaugeMetricFamily, - HistogramMetricFamily, ) from twisted.internet import reactor @@ -46,7 +47,7 @@ logger = logging.getLogger(__name__) METRICS_PREFIX = "/_synapse/metrics" running_on_pypy = platform.python_implementation() == "PyPy" -all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge, BucketCollector]] +all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge]] HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") @@ -205,63 +206,83 @@ class InFlightGauge: all_gauges[self.name] = self -@attr.s(slots=True, hash=True) -class BucketCollector: - """ - Like a Histogram, but allows buckets to be point-in-time instead of - incrementally added to. +class GaugeBucketCollector: + """Like a Histogram, but the buckets are Gauges which are updated atomically. - Args: - name (str): Base name of metric to be exported to Prometheus. - data_collector (callable -> dict): A synchronous callable that - returns a dict mapping bucket to number of items in the - bucket. If these buckets are not the same as the buckets - given to this class, they will be remapped into them. - buckets (list[float]): List of floats/ints of the buckets to - give to Prometheus. +Inf is ignored, if given. + The data is updated by calling `update_data` with an iterable of measurements. + We assume that the data is updated less frequently than it is reported to + Prometheus, and optimise for that case. """ - name = attr.ib() - data_collector = attr.ib() - buckets = attr.ib() + __slots__ = ("_name", "_documentation", "_bucket_bounds", "_metric") + + def __init__( + self, + name: str, + documentation: str, + buckets: Iterable[float], + registry=REGISTRY, + ): + """ + Args: + name: base name of metric to be exported to Prometheus. (a _bucket suffix + will be added.) + documentation: help text for the metric + buckets: The top bounds of the buckets to report + registry: metric registry to register with + """ + self._name = name + self._documentation = documentation + + # the tops of the buckets + self._bucket_bounds = [float(b) for b in buckets] + if self._bucket_bounds != sorted(self._bucket_bounds): + raise ValueError("Buckets not in sorted order") + + if self._bucket_bounds[-1] != float("inf"): + self._bucket_bounds.append(float("inf")) + + self._metric = self._values_to_metric([]) + registry.register(self) def collect(self): + yield self._metric - # Fetch the data -- this must be synchronous! - data = self.data_collector() + def update_data(self, values: Iterable[float]): + """Update the data to be reported by the metric - buckets = {} # type: Dict[float, int] + The existing data is cleared, and each measurement in the input is assigned + to the relevant bucket. + """ + self._metric = self._values_to_metric(values) - res = [] - for x in data.keys(): - for i, bound in enumerate(self.buckets): - if x <= bound: - buckets[bound] = buckets.get(bound, 0) + data[x] + def _values_to_metric(self, values: Iterable[float]) -> GaugeHistogramMetricFamily: + total = 0.0 + bucket_values = [0 for _ in self._bucket_bounds] - for i in self.buckets: - res.append([str(i), buckets.get(i, 0)]) + for v in values: + # assign each value to a bucket + for i, bound in enumerate(self._bucket_bounds): + if v <= bound: + bucket_values[i] += 1 + break - res.append(["+Inf", sum(data.values())]) + # ... and increment the sum + total += v - metric = HistogramMetricFamily( - self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items()) + # now, aggregate the bucket values so that they count the number of entries in + # that bucket or below. + accumulated_values = itertools.accumulate(bucket_values) + + return GaugeHistogramMetricFamily( + self._name, + self._documentation, + buckets=list( + zip((str(b) for b in self._bucket_bounds), accumulated_values) + ), + gsum_value=total, ) - yield metric - - def __attrs_post_init__(self): - self.buckets = [float(x) for x in self.buckets if x != "+Inf"] - if self.buckets != sorted(self.buckets): - raise ValueError("Buckets not sorted") - - self.buckets = tuple(self.buckets) - - if self.name in all_gauges.keys(): - logger.warning("%s already registered, reregistering" % (self.name,)) - REGISTRY.unregister(all_gauges.pop(self.name)) - - REGISTRY.register(self) - all_gauges[self.name] = self # diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index 4304c60d56..734271e765 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -24,9 +24,9 @@ expect, and the newer "best practice" version of the up-to-date official client. import math import threading -from collections import namedtuple from http.server import BaseHTTPRequestHandler, HTTPServer from socketserver import ThreadingMixIn +from typing import Dict, List from urllib.parse import parse_qs, urlparse from prometheus_client import REGISTRY @@ -35,14 +35,6 @@ from twisted.web.resource import Resource from synapse.util import caches -try: - from prometheus_client.samples import Sample -except ImportError: - Sample = namedtuple( # type: ignore[no-redef] # noqa - "Sample", ["name", "labels", "value", "timestamp", "exemplar"] - ) - - CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") @@ -93,17 +85,6 @@ def sample_line(line, name): ) -def nameify_sample(sample): - """ - If we get a prometheus_client<0.4.0 sample as a tuple, transform it into a - namedtuple which has the names we expect. - """ - if not isinstance(sample, Sample): - sample = Sample(*sample, None, None) - - return sample - - def generate_latest(registry, emit_help=False): # Trigger the cache metrics to be rescraped, which updates the common @@ -144,16 +125,33 @@ def generate_latest(registry, emit_help=False): ) ) output.append("# TYPE {0} {1}\n".format(mname, mtype)) - for sample in map(nameify_sample, metric.samples): - # Get rid of the OpenMetrics specific samples + + om_samples = {} # type: Dict[str, List[str]] + for s in metric.samples: for suffix in ["_created", "_gsum", "_gcount"]: - if sample.name.endswith(suffix): + if s.name == metric.name + suffix: + # OpenMetrics specific sample, put in a gauge at the end. + # (these come from gaugehistograms which don't get renamed, + # so no need to faff with mnewname) + om_samples.setdefault(suffix, []).append(sample_line(s, s.name)) break else: - newname = sample.name.replace(mnewname, mname) + newname = s.name.replace(mnewname, mname) if ":" in newname and newname.endswith("_total"): newname = newname[: -len("_total")] - output.append(sample_line(sample, newname)) + output.append(sample_line(s, newname)) + + for suffix, lines in sorted(om_samples.items()): + if emit_help: + output.append( + "# HELP {0}{1} {2}\n".format( + metric.name, + suffix, + metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), + ) + ) + output.append("# TYPE {0}{1} gauge\n".format(metric.name, suffix)) + output.extend(lines) # Get rid of the weird colon things while we're at it if mtype == "counter": @@ -172,16 +170,16 @@ def generate_latest(registry, emit_help=False): ) ) output.append("# TYPE {0} {1}\n".format(mnewname, mtype)) - for sample in map(nameify_sample, metric.samples): - # Get rid of the OpenMetrics specific samples + + for s in metric.samples: + # Get rid of the OpenMetrics specific samples (we should already have + # dealt with them above anyway.) for suffix in ["_created", "_gsum", "_gcount"]: - if sample.name.endswith(suffix): + if s.name == metric.name + suffix: break else: output.append( - sample_line( - sample, sample.name.replace(":total", "").replace(":", "_") - ) + sample_line(s, s.name.replace(":total", "").replace(":", "_")) ) return "".join(output).encode("utf-8") diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index fcbd5378c4..646f09d2bc 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -14,13 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING from twisted.internet import defer +from synapse.http.client import SimpleHttpClient from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import UserID +if TYPE_CHECKING: + from synapse.server import HomeServer + """ This package defines the 'stable' API which can be used by extension modules which are loaded into Synapse. @@ -43,6 +48,27 @@ class ModuleApi: self._auth = hs.get_auth() self._auth_handler = auth_handler + # We expose these as properties below in order to attach a helpful docstring. + self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient + self._public_room_list_manager = PublicRoomListManager(hs) + + @property + def http_client(self): + """Allows making outbound HTTP requests to remote resources. + + An instance of synapse.http.client.SimpleHttpClient + """ + return self._http_client + + @property + def public_room_list_manager(self): + """Allows adding to, removing from and checking the status of rooms in the + public room list. + + An instance of synapse.module_api.PublicRoomListManager + """ + return self._public_room_list_manager + def get_user_by_req(self, req, allow_guest=False): """Check the access_token provided for a request @@ -266,3 +292,44 @@ class ModuleApi: await self._auth_handler.complete_sso_login( registered_user_id, request, client_redirect_url, ) + + +class PublicRoomListManager: + """Contains methods for adding to, removing from and querying whether a room + is in the public room list. + """ + + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastore() + + async def room_is_in_public_room_list(self, room_id: str) -> bool: + """Checks whether a room is in the public room list. + + Args: + room_id: The ID of the room. + + Returns: + Whether the room is in the public room list. Returns False if the room does + not exist. + """ + room = await self._store.get_room(room_id) + if not room: + return False + + return room.get("is_public", False) + + async def add_room_to_public_room_list(self, room_id: str) -> None: + """Publishes a room to the public room list. + + Args: + room_id: The ID of the room. + """ + await self._store.set_room_is_public(room_id, True) + + async def remove_room_from_public_room_list(self, room_id: str) -> None: + """Removes a room from the public room list. + + Args: + room_id: The ID of the room. + """ + await self._store.set_room_is_public(room_id, False) diff --git a/synapse/notifier.py b/synapse/notifier.py index 441b3d15e2..59415f6f88 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -163,7 +163,7 @@ class _NotifierUserStream: """ # Immediately wake up stream if something has already since happened # since their last token. - if self.last_notified_token.is_after(token): + if self.last_notified_token != token: return _NotificationListener(defer.succeed(self.current_token)) else: return _NotificationListener(self.notify_deferred.observe()) @@ -470,7 +470,7 @@ class Notifier: async def check_for_updates( before_token: StreamToken, after_token: StreamToken ) -> EventStreamResult: - if not after_token.is_after(before_token): + if after_token == before_token: return EventStreamResult([], (from_token, from_token)) events = [] # type: List[EventBase] diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 709ace01e5..3a68ce636f 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,7 +16,7 @@ import logging import re -from typing import Any, Dict, List, Pattern, Union +from typing import Any, Dict, List, Optional, Pattern, Union from synapse.events import EventBase from synapse.types import UserID @@ -181,7 +181,7 @@ class PushRuleEvaluatorForEvent: return r.search(body) - def _get_value(self, dotted_key: str) -> str: + def _get_value(self, dotted_key: str) -> Optional[str]: return self._value_cache.get(dotted_key, None) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 288631477e..0ddead8a0f 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -68,7 +68,11 @@ REQUIREMENTS = [ "pymacaroons>=0.13.0", "msgpack>=0.5.2", "phonenumbers>=8.2.0", - "prometheus_client>=0.0.18,<0.9.0", + # we use GaugeHistogramMetric, which was added in prom-client 0.4.0. + # prom-client has a history of breaking backwards compatibility between + # minor versions (https://github.com/prometheus/client_python/issues/317), + # so we also pin the minor version. + "prometheus_client>=0.4.0,<0.9.0", # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note: # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33 # is out in November.) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index b448da6710..64edadb624 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -20,18 +20,28 @@ import urllib from inspect import signature from typing import Dict, List, Tuple -from synapse.api.errors import ( - CodeMessageException, - HttpResponseException, - RequestSendFailed, - SynapseError, -) +from prometheus_client import Counter, Gauge + +from synapse.api.errors import HttpResponseException, SynapseError +from synapse.http import RequestTimedOutError from synapse.logging.opentracing import inject_active_span_byte_dict, trace from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) +_pending_outgoing_requests = Gauge( + "synapse_pending_outgoing_replication_requests", + "Number of active outgoing replication requests, by replication method name", + ["name"], +) + +_outgoing_request_counter = Counter( + "synapse_outgoing_replication_requests", + "Number of outgoing replication requests, by replication method name and result", + ["name", "code"], +) + class ReplicationEndpoint(metaclass=abc.ABCMeta): """Helper base class for defining new replication HTTP endpoints. @@ -138,7 +148,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): instance_map = hs.config.worker.instance_map + outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME) + @trace(opname="outgoing_replication_request") + @outgoing_gauge.track_inprogress() async def send_request(instance_name="master", **kwargs): if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") @@ -193,23 +206,26 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): try: result = await request_func(uri, data, headers=headers) break - except CodeMessageException as e: - if e.code != 504 or not cls.RETRY_ON_TIMEOUT: + except RequestTimedOutError: + if not cls.RETRY_ON_TIMEOUT: raise - logger.warning("%s request timed out", cls.NAME) + logger.warning("%s request timed out; retrying", cls.NAME) # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. await clock.sleep(1) except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError - # on the master process that we should send to the client. (And + # on the main process that we should send to the client. (And # importantly, not stack traces everywhere) + _outgoing_request_counter.labels(cls.NAME, e.code).inc() raise e.to_synapse_error() - except RequestSendFailed as e: - raise SynapseError(502, "Failed to talk to master") from e + except Exception as e: + _outgoing_request_counter.labels(cls.NAME, "ERR").inc() + raise SynapseError(502, "Failed to talk to main process") from e + _outgoing_request_counter.labels(cls.NAME, 200).inc() return result return send_request diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 55af3d41ea..e165429cad 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, EventsStreamRow, ) -from synapse.types import PersistedEventPosition, RoomStreamToken, UserID +from synapse.types import PersistedEventPosition, UserID from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -152,9 +152,7 @@ class ReplicationDataHandler: if event.type == EventTypes.Member: extra_users = (UserID.from_string(event.state_key),) - max_token = RoomStreamToken( - None, self.store.get_room_max_stream_ordering() - ) + max_token = self.store.get_room_max_token() event_pos = PersistedEventPosition(instance_name, token) self.notifier.on_new_room_event( event, event_pos, max_token, extra_users diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b323841f73..e92da7b263 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -251,10 +251,9 @@ class ReplicationCommandHandler: using TCP. """ if hs.config.redis.redis_enabled: - import txredisapi - from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, + lazyConnection, ) logger.info( @@ -271,7 +270,8 @@ class ReplicationCommandHandler: # connection after SUBSCRIBE is called). # First create the connection for sending commands. - outbound_redis_connection = txredisapi.lazyConnection( + outbound_redis_connection = lazyConnection( + reactor=hs.get_reactor(), host=hs.config.redis_host, port=hs.config.redis_port, password=hs.config.redis.redis_password, diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 0b0d204e64..a509e599c2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -51,10 +51,11 @@ import fcntl import logging import struct from inspect import isawaitable -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from prometheus_client import Counter +from twisted.internet import task from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure @@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 - self.time_we_closed = None # When we requested the connection be closed + # When we requested the connection be closed + self.time_we_closed = None # type: Optional[int] - self.received_ping = False # Have we reecived a ping from the other side + self.received_ping = False # Have we received a ping from the other side self.state = ConnectionStates.CONNECTING @@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.pending_commands = [] # type: List[Command] # The LoopingCall for sending pings. - self._send_ping_loop = None + self._send_ping_loop = None # type: Optional[task.LoopingCall] # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index f225e533de..de19705c1f 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -15,7 +15,7 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import txredisapi @@ -228,3 +228,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): p.password = self.password return p + + +def lazyConnection( + reactor, + host: str = "localhost", + port: int = 6379, + dbid: Optional[int] = None, + reconnect: bool = True, + charset: str = "utf-8", + password: Optional[str] = None, + connectTimeout: Optional[int] = None, + replyTimeout: Optional[int] = None, + convertNumbers: bool = True, +) -> txredisapi.RedisProtocol: + """Equivalent to `txredisapi.lazyConnection`, except allows specifying a + reactor. + """ + + isLazy = True + poolsize = 1 + + uuid = "%s:%d" % (host, port) + factory = txredisapi.RedisFactory( + uuid, + dbid, + poolsize, + isLazy, + txredisapi.ConnectionHandler, + charset, + password, + replyTimeout, + convertNumbers, + ) + factory.continueTrying = reconnect + for x in range(poolsize): + reactor.connectTCP(host, port, factory, connectTimeout) + + return factory.handler diff --git a/synapse/res/templates/auth_success.html b/synapse/res/templates/auth_success.html new file mode 100644 index 0000000000..baf4633142 --- /dev/null +++ b/synapse/res/templates/auth_success.html @@ -0,0 +1,21 @@ + + +Success! + + + + + +
+

Thank you

+

You may now close this window and return to the application

+
+ + diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html new file mode 100644 index 0000000000..63944dc608 --- /dev/null +++ b/synapse/res/templates/recaptcha.html @@ -0,0 +1,38 @@ + + +Authentication + + + + + + + +
+
+

+ Hello! We need to prevent computer programs and other automated + things from creating accounts on this server. +

+

+ Please verify that you're not a robot. +

+ +
+
+ +
+ +
+ + diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html index af8459719a..944bc9c9ca 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html @@ -12,7 +12,7 @@

There was an error during authentication:

-
{{ error_description }}
+
{{ error_description | e }}

If you are seeing this page after clicking a link sent to you via email, make sure you only click the confirmation link once, and that you open the diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html new file mode 100644 index 0000000000..dfef9897ee --- /dev/null +++ b/synapse/res/templates/terms.html @@ -0,0 +1,20 @@ + + +Authentication + + + + +

+
+

+ Please click the button below if you agree to the + privacy policy of this homeserver. +

+ + +
+
+ + diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 5c5f00b213..789431ef25 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -57,6 +57,7 @@ from synapse.rest.admin.users import ( UsersRestServletV2, WhoisRestServlet, ) +from synapse.types import RoomStreamToken from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) @@ -109,7 +110,10 @@ class PurgeHistoryRestServlet(RestServlet): if event.room_id != room_id: raise SynapseError(400, "Event is for wrong room.") - token = await self.store.get_topological_token_for_event(event_id) + room_token = RoomStreamToken( + event.depth, event.internal_metadata.stream_ordering + ) + token = await room_token.to_string(self.store) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 985d994f6b..1ecb77aa26 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -33,6 +33,7 @@ class EventStreamRestServlet(RestServlet): super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet): if b"room_id" in request.args: room_id = request.args[b"room_id"][0].decode("ascii") - pagin_config = PaginationConfig.from_request(request) + pagin_config = await PaginationConfig.from_request(self.store, request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if b"timeout" in request.args: try: diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index d7042786ce..91da0ee573 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -27,11 +27,12 @@ class InitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request) as_client_event = b"raw" not in request.args - pagination_config = PaginationConfig.from_request(request) + pagination_config = await PaginationConfig.from_request(self.store, request) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 250b03a025..3d1693d7ac 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -111,8 +111,6 @@ class LoginRestServlet(RestServlet): ({"type": t} for t in self.auth_handler.get_supported_login_types()) ) - flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) - return 200, {"flows": flows} def on_OPTIONS(self, request: SynapseRequest): @@ -284,9 +282,7 @@ class LoginRestServlet(RestServlet): self, user_id: str, login_submission: JsonDict, - callback: Optional[ - Callable[[Dict[str, str]], Awaitable[Dict[str, str]]] - ] = None, + callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, create_non_existent_users: bool = False, ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to @@ -299,12 +295,12 @@ class LoginRestServlet(RestServlet): Args: user_id: ID of the user to register. login_submission: Dictionary of login information. - callback: Callback function to run after registration. + callback: Callback function to run after login. create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. Returns: - result: Dictionary of account information after successful registration. + result: Dictionary of account information after successful login. """ # Before we actually log them in we check if they've already logged in @@ -339,14 +335,24 @@ class LoginRestServlet(RestServlet): return result async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: + """ + Handle the final stage of SSO login. + + Args: + login_submission: The JSON request body. + + Returns: + The body of the JSON response. + """ token = login_submission["token"] auth_handler = self.auth_handler user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( token ) - result = await self._complete_login(user_id, login_submission) - return result + return await self._complete_login( + user_id, login_submission, self.auth_handler._sso_login_callback + ) async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: token = login_submission.get("token", None) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 7e64a2e0fe..b63389e5fe 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -451,6 +451,7 @@ class RoomMemberListRestServlet(RestServlet): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) @@ -465,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet): if at_token_string is None: at_token = None else: - at_token = StreamToken.from_string(at_token_string) + at_token = await StreamToken.from_string(self.store, at_token_string) # let you filter down on particular memberships. # XXX: this may not be the best shape for this API - we could pass in a filter @@ -521,10 +522,13 @@ class RoomMessageListRestServlet(RestServlet): super().__init__() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request(request, default_limit=10) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) as_client_event = b"raw" not in request.args filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: @@ -580,10 +584,11 @@ class RoomInitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request(request) + pagination_config = await PaginationConfig.from_request(self.store, request) content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c3ce0f6259..ab5815e7f7 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -96,15 +96,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param - if not check_3pid_allowed(self.hs, "email", email): - raise SynapseError( - 403, - "Your email domain is not authorized on this server", - Codes.THREEPID_DENIED, - ) - - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to @@ -379,8 +373,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) existing_user_id = await self.store.get_user_id_by_threepid("email", email) @@ -453,8 +448,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 097538f968..5fbfae5991 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -25,94 +25,6 @@ from ._base import client_patterns logger = logging.getLogger(__name__) -RECAPTCHA_TEMPLATE = """ - - -Authentication - - - - - - - -
-
-

- Hello! We need to prevent computer programs and other automated - things from creating accounts on this server. -

-

- Please verify that you're not a robot. -

- -
-
- -
- -
- - -""" - -TERMS_TEMPLATE = """ - - -Authentication - - - - -
-
-

- Please click the button below if you agree to the - privacy policy of this homeserver. -

- - -
-
- - -""" - -SUCCESS_TEMPLATE = """ - - -Success! - - - - - -
-

Thank you

-

You may now close this window and return to the application

-
- - -""" - class AuthRestServlet(RestServlet): """ @@ -145,26 +57,30 @@ class AuthRestServlet(RestServlet): self._cas_server_url = hs.config.cas_server_url self._cas_service_url = hs.config.cas_service_url + self.recaptcha_template = hs.config.recaptcha_template + self.terms_template = hs.config.terms_template + self.success_template = hs.config.fallback_success_template + async def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") if stagetype == LoginType.RECAPTCHA: - html = RECAPTCHA_TEMPLATE % { - "session": session, - "myurl": "%s/r0/auth/%s/fallback/web" + html = self.recaptcha_template.render( + session=session, + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - "sitekey": self.hs.config.recaptcha_public_key, - } + sitekey=self.hs.config.recaptcha_public_key, + ) elif stagetype == LoginType.TERMS: - html = TERMS_TEMPLATE % { - "session": session, - "terms_url": "%s_matrix/consent?v=%s" + html = self.terms_template.render( + session=session, + terms_url="%s_matrix/consent?v=%s" % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), - "myurl": "%s/r0/auth/%s/fallback/web" + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), - } + ) elif stagetype == LoginType.SSO: # Display a confirmation page which prompts the user to @@ -222,14 +138,14 @@ class AuthRestServlet(RestServlet): ) if success: - html = SUCCESS_TEMPLATE + html = self.success_template.render() else: - html = RECAPTCHA_TEMPLATE % { - "session": session, - "myurl": "%s/r0/auth/%s/fallback/web" + html = self.recaptcha_template.render( + session=session, + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - "sitekey": self.hs.config.recaptcha_public_key, - } + sitekey=self.hs.config.recaptcha_public_key, + ) elif stagetype == LoginType.TERMS: authdict = {"session": session} @@ -238,18 +154,18 @@ class AuthRestServlet(RestServlet): ) if success: - html = SUCCESS_TEMPLATE + html = self.success_template.render() else: - html = TERMS_TEMPLATE % { - "session": session, - "terms_url": "%s_matrix/consent?v=%s" + html = self.terms_template.render( + session=session, + terms_url="%s_matrix/consent?v=%s" % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), - "myurl": "%s/r0/auth/%s/fallback/web" + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), - } + ) elif stagetype == LoginType.SSO: # The SSO fallback workflow should not post here, raise SynapseError(404, "Fallback SSO auth does not support POST requests.") diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 7abd6ff333..55c4606569 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -180,6 +180,7 @@ class KeyChangesServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -191,7 +192,7 @@ class KeyChangesServlet(RestServlet): # changes after the "to" as well as before. set_tag("to", parse_string(request, "to")) - from_token = StreamToken.from_string(from_token_string) + from_token = await StreamToken.from_string(self.store, from_token_string) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 51e395cc64..6779df952f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -77,6 +77,7 @@ class SyncRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() + self.store = hs.get_datastore() self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() self.filtering = hs.get_filtering() @@ -151,10 +152,9 @@ class SyncRestServlet(RestServlet): device_id=device_id, ) + since_token = None if since is not None: - since_token = StreamToken.from_string(since) - else: - since_token = None + since_token = await StreamToken.from_string(self.store, since) # send any outstanding server notices to the user. await self._server_notices_sender.on_user_syncing(user.to_string()) @@ -236,7 +236,7 @@ class SyncRestServlet(RestServlet): "leave": sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "next_batch": sync_result.next_batch.to_string(), + "next_batch": await sync_result.next_batch.to_string(self.store), } @staticmethod @@ -413,7 +413,7 @@ class SyncRestServlet(RestServlet): result = { "timeline": { "events": serialized_timeline, - "prev_batch": room.timeline.prev_batch.to_string(), + "prev_batch": await room.timeline.prev_batch.to_string(self.store), "limited": room.timeline.limited, }, "state": {"events": serialized_state}, diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 69f353d46f..e1192b47cd 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -139,7 +139,7 @@ class MediaRepository: async def create_content( self, media_type: str, - upload_name: str, + upload_name: Optional[str], content: IO, content_length: int, auth_user: str, @@ -147,8 +147,8 @@ class MediaRepository: """Store uploaded content for a local user and return the mxc URL Args: - media_type: The content type of the file - upload_name: The name of the file + media_type: The content type of the file. + upload_name: The name of the file, if provided. content: A file like object that is the content to store content_length: The length of the content auth_user: The user_id of the uploader @@ -156,6 +156,7 @@ class MediaRepository: Returns: The mxc url of the stored content """ + media_id = random_string(24) file_info = FileInfo(server_name=None, file_id=media_id) @@ -636,7 +637,7 @@ class MediaRepository: thumbnailer = Thumbnailer(input_path) except ThumbnailError as e: logger.warning( - "Unable to generate thumbnails for remote media %s from %s using a method of %s and type of %s: %s", + "Unable to generate thumbnails for remote media %s from %s of type %s: %s", media_id, server_name, media_type, diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 5681677fc9..a9586fb0b7 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -141,31 +141,34 @@ class MediaStorage: Returns: Returns a Responder if the file was found, otherwise None. """ + paths = [self._file_info_to_path(file_info)] - path = self._file_info_to_path(file_info) - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - return FileResponder(open(local_path, "rb")) - - # Fallback for paths without method names - # Should be removed in the future + # fallback for remote thumbnails with no method in the filename if file_info.thumbnail and file_info.server_name: - legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail_width, - height=file_info.thumbnail_height, - content_type=file_info.thumbnail_type, + paths.append( + self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + ) ) - legacy_local_path = os.path.join(self.local_media_directory, legacy_path) - if os.path.exists(legacy_local_path): - return FileResponder(open(legacy_local_path, "rb")) + + for path in paths: + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + logger.debug("responding with local file %s", local_path) + return FileResponder(open(local_path, "rb")) + logger.debug("local file %s did not exist", local_path) for provider in self.storage_providers: - res = await provider.fetch(path, file_info) # type: Any - if res: - logger.debug("Streaming %s from %s", path, provider) - return res + for path in paths: + res = await provider.fetch(path, file_info) # type: Any + if res: + logger.debug("Streaming %s from %s", path, provider) + return res + logger.debug("%s not found on %s", path, provider) return None diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 3ebf7a68e6..d76f7389e1 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -63,6 +63,10 @@ class UploadResource(DirectServeJsonResource): msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 ) + # If the name is falsey (e.g. an empty byte string) ensure it is None. + else: + upload_name = None + headers = request.requestHeaders if headers.hasHeader(b"Content-Type"): diff --git a/synapse/server.py b/synapse/server.py index 5e3752c333..aa2273955c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -185,7 +185,10 @@ class HomeServer(metaclass=abc.ABCMeta): we are listening on to provide HTTP services. """ - REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] + REQUIRED_ON_BACKGROUND_TASK_STARTUP = [ + "auth", + "stats", + ] # This is overridden in derived application classes # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be @@ -251,14 +254,20 @@ class HomeServer(metaclass=abc.ABCMeta): self.datastores = Databases(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") - def setup_master(self) -> None: + # Register background tasks required by this server. This must be done + # somewhat manually due to the background tasks not being registered + # unless handlers are instantiated. + if self.config.run_background_tasks: + self.setup_background_tasks() + + def setup_background_tasks(self) -> None: """ Some handlers have side effects on instantiation (like registering background updates). This function causes them to be fetched, and therefore instantiated, to run those side effects. """ - for i in self.REQUIRED_ON_MASTER_STARTUP: - getattr(self, "get_" + i)() + for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP: + getattr(self, "get_" + i + "_handler")() def get_reactor(self) -> twisted.internet.base.ReactorBase: """ diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 5a5ea39e01..5b0900aa3c 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -13,42 +13,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import heapq import logging -from collections import namedtuple +from collections import defaultdict, namedtuple from typing import ( + Any, Awaitable, + Callable, + DefaultDict, Dict, Iterable, List, Optional, Sequence, Set, + Tuple, Union, overload, ) import attr from frozendict import frozendict -from prometheus_client import Histogram +from prometheus_client import Counter, Histogram from typing_extensions import Literal from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.logging.context import ContextResourceUsage from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo from synapse.types import Collection, StateMap -from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func logger = logging.getLogger(__name__) - +metrics_logger = logging.getLogger("synapse.state.metrics") # Metrics for number of state groups involved in a resolution. state_groups_histogram = Histogram( @@ -448,19 +452,44 @@ class StateHandler: state_map = {ev.event_id: ev for st in state_sets for ev in st} - with Measure(self.clock, "state._resolve_events"): - new_state = await resolve_events_with_store( - self.clock, - event.room_id, - room_version, - state_set_ids, - event_map=state_map, - state_res_store=StateResolutionStore(self.store), - ) + new_state = await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version, + state_set_ids, + event_map=state_map, + state_res_store=StateResolutionStore(self.store), + ) return {key: state_map[ev_id] for key, ev_id in new_state.items()} +@attr.s(slots=True) +class _StateResMetrics: + """Keeps track of some usage metrics about state res.""" + + # System and User CPU time, in seconds + cpu_time = attr.ib(type=float, default=0.0) + + # time spent on database transactions (excluding scheduling time). This roughly + # corresponds to the amount of work done on the db server, excluding event fetches. + db_time = attr.ib(type=float, default=0.0) + + # number of events fetched from the db. + db_events = attr.ib(type=int, default=0) + + +_biggest_room_by_cpu_counter = Counter( + "synapse_state_res_cpu_for_biggest_room_seconds", + "CPU time spent performing state resolution for the single most expensive " + "room for state resolution", +) +_biggest_room_by_db_counter = Counter( + "synapse_state_res_db_for_biggest_room_seconds", + "Database time spent performing state resolution for the single most " + "expensive room for state resolution", +) + + class StateResolutionHandler: """Responsible for doing state conflict resolution. @@ -483,6 +512,17 @@ class StateResolutionHandler: reset_expiry_on_get=True, ) + # + # stuff for tracking time spent on state-res by room + # + + # tracks the amount of work done on state res per room + self._state_res_metrics = defaultdict( + _StateResMetrics + ) # type: DefaultDict[str, _StateResMetrics] + + self.clock.looping_call(self._report_metrics, 120 * 1000) + @log_function async def resolve_state_groups( self, @@ -530,15 +570,13 @@ class StateResolutionHandler: state_groups_histogram.observe(len(state_groups_ids)) - with Measure(self.clock, "state._resolve_events"): - new_state = await resolve_events_with_store( - self.clock, - room_id, - room_version, - list(state_groups_ids.values()), - event_map=event_map, - state_res_store=state_res_store, - ) + new_state = await self.resolve_events_with_store( + room_id, + room_version, + list(state_groups_ids.values()), + event_map=event_map, + state_res_store=state_res_store, + ) # if the new state matches any of the input state groups, we can # use that state group again. Otherwise we will generate a state_id @@ -552,6 +590,114 @@ class StateResolutionHandler: return cache + async def resolve_events_with_store( + self, + room_id: str, + room_version: str, + state_sets: Sequence[StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", + ) -> StateMap[str]: + """ + Args: + room_id: the room we are working in + + room_version: Version of the room + + state_sets: List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map: + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_res_store. + + state_res_store: a place to fetch events from + + Returns: + a map from (type, state_key) to event_id. + """ + try: + with Measure(self.clock, "state._resolve_events") as m: + v = KNOWN_ROOM_VERSIONS[room_version] + if v.state_res == StateResolutionVersions.V1: + return await v1.resolve_events_with_store( + room_id, state_sets, event_map, state_res_store.get_events + ) + else: + return await v2.resolve_events_with_store( + self.clock, + room_id, + room_version, + state_sets, + event_map, + state_res_store, + ) + finally: + self._record_state_res_metrics(room_id, m.get_resource_usage()) + + def _record_state_res_metrics(self, room_id: str, rusage: ContextResourceUsage): + room_metrics = self._state_res_metrics[room_id] + room_metrics.cpu_time += rusage.ru_utime + rusage.ru_stime + room_metrics.db_time += rusage.db_txn_duration_sec + room_metrics.db_events += rusage.evt_db_fetch_count + + def _report_metrics(self): + if not self._state_res_metrics: + # no state res has happened since the last iteration: don't bother logging. + return + + self._report_biggest( + lambda i: i.cpu_time, "CPU time", _biggest_room_by_cpu_counter, + ) + + self._report_biggest( + lambda i: i.db_time, "DB time", _biggest_room_by_db_counter, + ) + + self._state_res_metrics.clear() + + def _report_biggest( + self, + extract_key: Callable[[_StateResMetrics], Any], + metric_name: str, + prometheus_counter_metric: Counter, + ) -> None: + """Report metrics on the biggest rooms for state res + + Args: + extract_key: a callable which, given a _StateResMetrics, extracts a single + metric to sort by. + metric_name: the name of the metric we have extracted, for the log line + prometheus_counter_metric: a prometheus metric recording the sum of the + the extracted metric + """ + n_to_log = 10 + if not metrics_logger.isEnabledFor(logging.DEBUG): + # only need the most expensive if we don't have debug logging, which + # allows nlargest() to degrade to max() + n_to_log = 1 + + items = self._state_res_metrics.items() + + # log the N biggest rooms + biggest = heapq.nlargest( + n_to_log, items, key=lambda i: extract_key(i[1]) + ) # type: List[Tuple[str, _StateResMetrics]] + metrics_logger.debug( + "%i biggest rooms for state-res by %s: %s", + len(biggest), + metric_name, + ["%s (%gs)" % (r, extract_key(m)) for (r, m) in biggest], + ) + + # report info on the single biggest to prometheus + _, biggest_metrics = biggest[0] + prometheus_counter_metric.inc(extract_key(biggest_metrics)) + def _make_state_cache_entry( new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] @@ -592,7 +738,7 @@ def _make_state_cache_entry( # failing that, look for the closest match. prev_group = None - delta_ids = None + delta_ids = None # type: Optional[StateMap[str]] for old_group, old_state in state_groups_ids.items(): n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v} @@ -605,47 +751,6 @@ def _make_state_cache_entry( ) -def resolve_events_with_store( - clock: Clock, - room_id: str, - room_version: str, - state_sets: Sequence[StateMap[str]], - event_map: Optional[Dict[str, EventBase]], - state_res_store: "StateResolutionStore", -) -> Awaitable[StateMap[str]]: - """ - Args: - room_id: the room we are working in - - room_version: Version of the room - - state_sets: List of dicts of (type, state_key) -> event_id, - which are the different state groups to resolve. - - event_map: - a dict from event_id to event, for any events that we happen to - have in flight (eg, those currently being persisted). This will be - used as a starting point fof finding the state we need; any missing - events will be requested via state_map_factory. - - If None, all events will be fetched via state_res_store. - - state_res_store: a place to fetch events from - - Returns: - a map from (type, state_key) to event_id. - """ - v = KNOWN_ROOM_VERSIONS[room_version] - if v.state_res == StateResolutionVersions.V1: - return v1.resolve_events_with_store( - room_id, state_sets, event_map, state_res_store.get_events - ) - else: - return v2.resolve_events_with_store( - clock, room_id, room_version, state_sets, event_map, state_res_store - ) - - @attr.s(slots=True) class StateResolutionStore: """Interface that allows state resolution algorithms to access the database diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 79ec8f119d..0d9d9b7cc0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -32,6 +32,7 @@ from typing import ( overload, ) +import attr from prometheus_client import Histogram from typing_extensions import Literal @@ -90,13 +91,17 @@ def make_pool( return adbapi.ConnectionPool( db_config.config["name"], cp_reactor=reactor, - cp_openfun=engine.on_new_connection, + cp_openfun=lambda conn: engine.on_new_connection( + LoggingDatabaseConnection(conn, engine, "on_new_connection") + ), **db_config.config.get("args", {}) ) def make_conn( - db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + db_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, + default_txn_name: str, ) -> Connection: """Make a new connection to the database and return it. @@ -109,11 +114,60 @@ def make_conn( for k, v in db_config.config.get("args", {}).items() if not k.startswith("cp_") } - db_conn = engine.module.connect(**db_params) + native_db_conn = engine.module.connect(**db_params) + db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name) + engine.on_new_connection(db_conn) return db_conn +@attr.s(slots=True) +class LoggingDatabaseConnection: + """A wrapper around a database connection that returns `LoggingTransaction` + as its cursor class. + + This is mainly used on startup to ensure that queries get logged correctly + """ + + conn = attr.ib(type=Connection) + engine = attr.ib(type=BaseDatabaseEngine) + default_txn_name = attr.ib(type=str) + + def cursor( + self, *, txn_name=None, after_callbacks=None, exception_callbacks=None + ) -> "LoggingTransaction": + if not txn_name: + txn_name = self.default_txn_name + + return LoggingTransaction( + self.conn.cursor(), + name=txn_name, + database_engine=self.engine, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, + ) + + def close(self) -> None: + self.conn.close() + + def commit(self) -> None: + self.conn.commit() + + def rollback(self, *args, **kwargs) -> None: + self.conn.rollback(*args, **kwargs) + + def __enter__(self) -> "Connection": + self.conn.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + return self.conn.__exit__(exc_type, exc_value, traceback) + + # Proxy through any unknown lookups to the DB conn class. + def __getattr__(self, name): + return getattr(self.conn, name) + + # The type of entry which goes on our after_callbacks and exception_callbacks lists. # # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so @@ -247,6 +301,12 @@ class LoggingTransaction: def close(self) -> None: self.txn.close() + def __enter__(self) -> "LoggingTransaction": + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + class PerformanceCounters: def __init__(self): @@ -395,7 +455,7 @@ class DatabasePool: def new_transaction( self, - conn: Connection, + conn: LoggingDatabaseConnection, desc: str, after_callbacks: List[_CallbackListEntry], exception_callbacks: List[_CallbackListEntry], @@ -418,12 +478,10 @@ class DatabasePool: i = 0 N = 5 while True: - cursor = LoggingTransaction( - conn.cursor(), - name, - self.engine, - after_callbacks, - exception_callbacks, + cursor = conn.cursor( + txn_name=name, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, ) try: r = func(cursor, *args, **kwargs) @@ -584,7 +642,10 @@ class DatabasePool: logger.debug("Reconnecting closed database connection") conn.reconnect() - return func(conn, *args, **kwargs) + db_conn = LoggingDatabaseConnection( + conn, self.engine, "runWithConnection" + ) + return func(db_conn, *args, **kwargs) return await make_deferred_yieldable( self._db_pool.runWithConnection(inner_func, *args, **kwargs) @@ -1621,7 +1682,7 @@ class DatabasePool: def get_cache_dict( self, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, table: str, entity_column: str, stream_column: str, @@ -1642,9 +1703,7 @@ class DatabasePool: "limit": limit, } - sql = self.engine.convert_param_style(sql) - - txn = db_conn.cursor() + txn = db_conn.cursor(txn_name="get_cache_dict") txn.execute(sql, (int(max_value),)) cache = {row[0]: int(row[1]) for row in txn} diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index aa5d490624..0c24325011 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -46,7 +46,7 @@ class Databases: db_name = database_config.name engine = create_engine(database_config.config) - with make_conn(database_config, engine) as db_conn: + with make_conn(database_config, engine, "startup") as db_conn: logger.info("[database config %r]: Checking database server", db_name) engine.check_database(db_conn) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0cb12f4c61..9b16f45f3e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import calendar import logging -import time from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import PresenceState @@ -268,9 +266,6 @@ class DataStore( self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() - # Used in _generate_user_daily_visits to keep track of progress - self._last_user_visit_update = self._get_start_of_day() - def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @@ -289,7 +284,6 @@ class DataStore( " last_user_sync_ts, status_msg, currently_active FROM presence_stream" " WHERE state != ?" ) - sql = self.database_engine.convert_param_style(sql) txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) @@ -301,192 +295,6 @@ class DataStore( return [UserPresenceState(**row) for row in rows] - async def count_daily_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 24 hours. - """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return await self.db_pool.runInteraction( - "count_daily_users", self._count_users, yesterday - ) - - async def count_monthly_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 30 days. - Note this method is intended for phonehome metrics only and is different - from the mau figure in synapse.storage.monthly_active_users which, - amongst other things, includes a 3 day grace period before a user counts. - """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return await self.db_pool.runInteraction( - "count_monthly_users", self._count_users, thirty_days_ago - ) - - def _count_users(self, txn, time_from): - """ - Returns number of users seen in the past time_from period - """ - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - txn.execute(sql, (time_from,)) - (count,) = txn.fetchone() - return count - - async def count_r30_users(self) -> Dict[str, int]: - """ - Counts the number of 30 day retained users, defined as:- - * Users who have created their accounts more than 30 days ago - * Where last seen at most 30 days ago - * Where account creation and last_seen are > 30 days apart - - Returns: - A mapping of counts globally as well as broken out by platform. - """ - - def _count_r30_users(txn): - thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) - thirty_days_ago_in_secs = now - thirty_days_in_secs - - sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( - SELECT - users.name, platform, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen, - CASE - WHEN user_agent LIKE '%%Android%%' THEN 'android' - WHEN user_agent LIKE '%%iOS%%' THEN 'ios' - WHEN user_agent LIKE '%%Electron%%' THEN 'electron' - WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' - WHEN user_agent LIKE '%%Gecko%%' THEN 'web' - ELSE 'unknown' - END - AS platform - FROM user_ips - ) uip - ON users.name = uip.user_id - AND users.appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, platform, users.creation_ts - ) u GROUP BY platform - """ - - results = {} - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - for row in txn: - if row[0] == "unknown": - pass - results[row[0]] = row[1] - - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT users.name, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen - FROM user_ips - ) uip - ON users.name = uip.user_id - AND appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, users.creation_ts - ) u - """ - - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - (count,) = txn.fetchone() - results["all"] = count - - return results - - return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) - - def _get_start_of_day(self): - """ - Returns millisecond unixtime for start of UTC day. - """ - now = time.gmtime() - today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) - return today_start * 1000 - - async def generate_user_daily_visits(self) -> None: - """ - Generates daily visit data for use in cohort/ retention analysis - """ - - def _generate_user_daily_visits(txn): - logger.info("Calling _generate_user_daily_visits") - today_start = self._get_start_of_day() - a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self.clock.time_msec() - - sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? - FROM user_ips AS u - LEFT JOIN ( - SELECT user_id, device_id, timestamp FROM user_daily_visits - WHERE timestamp = ? - ) udv - ON u.user_id = udv.user_id AND u.device_id=udv.device_id - INNER JOIN users ON users.name=u.user_id - WHERE last_seen > ? AND last_seen <= ? - AND udv.timestamp IS NULL AND users.is_guest=0 - AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id - """ - - # This means that the day has rolled over but there could still - # be entries from the previous day. There is an edge case - # where if the user logs in at 23:59 and overwrites their - # last_seen at 00:01 then they will not be counted in the - # previous day's stats - it is important that the query is run - # often to minimise this case. - if today_start > self._last_user_visit_update: - yesterday_start = today_start - a_day_in_milliseconds - txn.execute( - sql, - ( - yesterday_start, - yesterday_start, - self._last_user_visit_update, - today_start, - ), - ) - self._last_user_visit_update = today_start - - txn.execute( - sql, (today_start, today_start, self._last_user_visit_update, now) - ) - # Update _last_user_visit_update to now. The reason to do this - # rather just clamping to the beginning of the day is to limit - # the size of the join - meaning that the query can be run more - # frequently - self._last_user_visit_update = now - - await self.db_pool.runInteraction( - "generate_user_daily_visits", _generate_user_daily_visits - ) - async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index ef81d73573..49ee23470d 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -18,6 +18,7 @@ import abc import logging from typing import Dict, List, Optional, Tuple +from synapse.api.constants import AccountDataTypes from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator @@ -291,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext ) -> bool: ignored_account_data = await self.get_global_account_data_by_type_for_user( - "m.ignored_user_list", + AccountDataTypes.IGNORED_USER_LIST, ignorer_user_id, on_invalidate=cache_context.invalidate, ) if not ignored_account_data: return False - return ignored_user_id in ignored_account_data.get("ignored_users", {}) + try: + return ignored_user_id in ignored_account_data.get("ignored_users", {}) + except TypeError: + # The type of the ignored_users field is invalid. + return False class AccountDataStore(AccountDataWorkerStore): diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index f211ddbaf8..4bb2b9c28c 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -21,8 +21,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore -from synapse.storage.databases.main.events import encode_json from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util.frozenutils import frozendict_json_encoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = encode_json( + pruned_json = frozendict_json_encoder.encode( prune_event_dict( original_event.room_version, original_event.get_dict() ) @@ -171,7 +171,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase return # Prune the event's dict then convert it to JSON. - pruned_json = encode_json( + pruned_json = frozendict_json_encoder.encode( prune_event_dict(event.room_version, event.get_dict()) ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 62f1738732..80f3b4d740 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple, Union import attr from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -74,11 +74,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): self.stream_ordering_month_ago = None self.stream_ordering_day_ago = None - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - ) + cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn") self._find_stream_orderings_for_times_txn(cur) cur.close() diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 18def01f50..b4abd961b9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -52,16 +52,6 @@ event_counter = Counter( ) -def encode_json(json_object): - """ - Encode a Python object as JSON and return it in a Unicode string. - """ - out = frozendict_json_encoder.encode(json_object) - if isinstance(out, bytes): - out = out.decode("utf8") - return out - - _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) @@ -341,6 +331,10 @@ class PersistEventsStore: min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + # stream orderings should have been assigned by now + assert min_stream_order + assert max_stream_order + self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremeties, @@ -743,7 +737,9 @@ class PersistEventsStore: logger.exception("") raise - metadata_json = encode_json(event.internal_metadata.get_dict()) + metadata_json = frozendict_json_encoder.encode( + event.internal_metadata.get_dict() + ) sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" txn.execute(sql, (metadata_json, event.event_id)) @@ -797,10 +793,10 @@ class PersistEventsStore: { "event_id": event.event_id, "room_id": event.room_id, - "internal_metadata": encode_json( + "internal_metadata": frozendict_json_encoder.encode( event.internal_metadata.get_dict() ), - "json": encode_json(event_dict(event)), + "json": frozendict_json_encoder.encode(event_dict(event)), "format_version": event.format_version, } for event, _ in events_and_contexts diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index f95679ebc4..b7ed8ca6ab 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -74,6 +74,13 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): + # Whether to use dedicated DB threads for event fetching. This is only used + # if there are multiple DB threads available. When used will lock the DB + # thread for periods of time (so unit tests want to disable this when they + # run DB transactions on the main thread). See EVENT_QUEUE_* for more + # options controlling this. + USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -522,7 +529,11 @@ class EventsWorkerStore(SQLBaseStore): if not event_list: single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): self._event_fetch_ongoing -= 1 return else: @@ -712,6 +723,7 @@ class EventsWorkerStore(SQLBaseStore): internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) + original_ev.internal_metadata.stream_ordering = row["stream_ordering"] event_map[event_id] = original_ev @@ -779,6 +791,8 @@ class EventsWorkerStore(SQLBaseStore): * event_id (str) + * stream_ordering (int): stream ordering for this event + * json (str): json-encoded event structure * internal_metadata (str): json-encoded internal metadata dict @@ -811,13 +825,15 @@ class EventsWorkerStore(SQLBaseStore): sql = """\ SELECT e.event_id, - e.internal_metadata, - e.json, - e.format_version, + e.stream_ordering, + ej.internal_metadata, + ej.json, + ej.format_version, r.room_version, rej.reason - FROM event_json as e - LEFT JOIN rooms r USING (room_id) + FROM events AS e + JOIN event_json AS ej USING (event_id) + LEFT JOIN rooms r ON r.room_id = e.room_id LEFT JOIN rejections as rej USING (event_id) WHERE """ @@ -831,11 +847,12 @@ class EventsWorkerStore(SQLBaseStore): event_id = row[0] event_dict[event_id] = { "event_id": event_id, - "internal_metadata": row[1], - "json": row[2], - "format_version": row[3], - "room_version_id": row[4], - "rejected_reason": row[5], + "stream_ordering": row[1], + "internal_metadata": row[2], + "json": row[3], + "format_version": row[4], + "room_version_id": row[5], + "rejected_reason": row[6], "redactions": [], } diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 686052bd83..2c5a4fdbf6 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -12,10 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import typing -from collections import Counter +import calendar +import logging +import time +from typing import Dict -from synapse.metrics import BucketCollector +from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -23,6 +25,28 @@ from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +logger = logging.getLogger(__name__) + +# Collect metrics on the number of forward extremities that exist. +_extremities_collecter = GaugeBucketCollector( + "synapse_forward_extremities", + "Number of rooms on the server with the given number of forward extremities" + " or fewer", + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500], +) + +# we also expose metrics on the "number of excess extremity events", which is +# (E-1)*N, where E is the number of extremities and N is the number of state +# events in the room. This is an approximation to the number of state events +# we could remove from state resolution by reducing the graph to a single +# forward extremity. +_excess_state_events_collecter = GaugeBucketCollector( + "synapse_excess_extremity_events", + "Number of rooms on the server with the given number of excess extremity " + "events, or fewer", + buckets=[0] + [1 << n for n in range(12)], +) + class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """Functions to pull various metrics from the DB, for e.g. phone home @@ -32,18 +56,6 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - # Collect metrics on the number of forward extremities that exist. - # Counter of number of extremities to count - self._current_forward_extremities_amount = ( - Counter() - ) # type: typing.Counter[int] - - BucketCollector( - "synapse_forward_extremities", - lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], - ) - # Read the extrems every 60 minutes def read_forward_extremities(): # run as a background process to make sure that the database transactions @@ -54,18 +66,32 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + # Used in _generate_user_daily_visits to keep track of progress + self._last_user_visit_update = self._get_start_of_day() + async def _read_forward_extremities(self): def fetch(txn): txn.execute( """ - select count(*) c from event_forward_extremities - group by room_id + SELECT t1.c, t2.c + FROM ( + SELECT room_id, COUNT(*) c FROM event_forward_extremities + GROUP BY room_id + ) t1 LEFT JOIN ( + SELECT room_id, COUNT(*) c FROM current_state_events + GROUP BY room_id + ) t2 ON t1.room_id = t2.room_id """ ) return txn.fetchall() res = await self.db_pool.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = Counter([x[0] for x in res]) + + _extremities_collecter.update_data(x[0] for x in res) + + _excess_state_events_collecter.update_data( + (x[0] - 1) * x[1] for x in res if x[1] + ) async def count_daily_messages(self): """ @@ -120,3 +146,189 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return count return await self.db_pool.runInteraction("count_daily_active_rooms", _count) + + async def count_daily_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + return await self.db_pool.runInteraction( + "count_daily_users", self._count_users, yesterday + ) + + async def count_monthly_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 30 days. + Note this method is intended for phonehome metrics only and is different + from the mau figure in synapse.storage.monthly_active_users which, + amongst other things, includes a 3 day grace period before a user counts. + """ + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + return await self.db_pool.runInteraction( + "count_monthly_users", self._count_users, thirty_days_ago + ) + + def _count_users(self, txn, time_from): + """ + Returns number of users seen in the past time_from period + """ + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + txn.execute(sql, (time_from,)) + (count,) = txn.fetchone() + return count + + async def count_r30_users(self) -> Dict[str, int]: + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns: + A mapping of counts globally as well as broken out by platform. + """ + + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + for row in txn: + if row[0] == "unknown": + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + (count,) = txn.fetchone() + results["all"] = count + + return results + + return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) + + def _get_start_of_day(self): + """ + Returns millisecond unixtime for start of UTC day. + """ + now = time.gmtime() + today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) + return today_start * 1000 + + async def generate_user_daily_visits(self) -> None: + """ + Generates daily visit data for use in cohort/ retention analysis + """ + + def _generate_user_daily_visits(txn): + logger.info("Calling _generate_user_daily_visits") + today_start = self._get_start_of_day() + a_day_in_milliseconds = 24 * 60 * 60 * 1000 + now = self._clock.time_msec() + + sql = """ + INSERT INTO user_daily_visits (user_id, device_id, timestamp) + SELECT u.user_id, u.device_id, ? + FROM user_ips AS u + LEFT JOIN ( + SELECT user_id, device_id, timestamp FROM user_daily_visits + WHERE timestamp = ? + ) udv + ON u.user_id = udv.user_id AND u.device_id=udv.device_id + INNER JOIN users ON users.name=u.user_id + WHERE last_seen > ? AND last_seen <= ? + AND udv.timestamp IS NULL AND users.is_guest=0 + AND users.appservice_id IS NULL + GROUP BY u.user_id, u.device_id + """ + + # This means that the day has rolled over but there could still + # be entries from the previous day. There is an edge case + # where if the user logs in at 23:59 and overwrites their + # last_seen at 00:01 then they will not be counted in the + # previous day's stats - it is important that the query is run + # often to minimise this case. + if today_start > self._last_user_visit_update: + yesterday_start = today_start - a_day_in_milliseconds + txn.execute( + sql, + ( + yesterday_start, + yesterday_start, + self._last_user_visit_update, + today_start, + ), + ) + self._last_user_visit_update = today_start + + txn.execute( + sql, (today_start, today_start, self._last_user_visit_update, now) + ) + # Update _last_user_visit_update to now. The reason to do this + # rather just clamping to the beginning of the day is to limit + # the size of the join - meaning that the query can be run more + # frequently + self._last_user_visit_update = now + + await self.db_pool.runInteraction( + "generate_user_daily_visits", _generate_user_daily_visits + ) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e0cedd1aac..c66f558567 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -32,6 +32,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._max_mau_value = hs.config.max_mau_value + @cached(num_args=0) async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users @@ -41,7 +44,14 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ def _count_users(txn): - sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" + # Exclude app service users + sql = """ + SELECT COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users + ON monthly_active_users.user_id=users.name + WHERE (users.appservice_id IS NULL OR users.appservice_id = ''); + """ txn.execute(sql) (count,) = txn.fetchone() return count @@ -117,60 +127,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): desc="user_last_seen_monthly_active", ) - -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self._limit_usage_by_mau = hs.config.limit_usage_by_mau - self._mau_stats_only = hs.config.mau_stats_only - self._max_mau_value = hs.config.max_mau_value - - # Do not add more reserved users than the total allowable number - # cur = LoggingTransaction( - self.db_pool.new_transaction( - db_conn, - "initialise_mau_threepids", - [], - [], - self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self._max_mau_value], - ) - - def _initialise_reserved_users(self, txn, threepids): - """Ensures that reserved threepids are accounted for in the MAU table, should - be called on start up. - - Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve - """ - - # XXX what is this function trying to achieve? It upserts into - # monthly_active_users for each *registered* reserved mau user, but why? - # - # - shouldn't there already be an entry for each reserved user (at least - # if they have been active recently)? - # - # - if it's important that the timestamp is kept up to date, why do we only - # run this at startup? - - for tp in threepids: - user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) - - if user_id: - is_support = self.is_support_user_txn(txn, user_id) - if not is_support: - # We do this manually here to avoid hitting #6791 - self.db_pool.simple_upsert_txn( - txn, - table="monthly_active_users", - keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, - ) - else: - logger.warning("mau limit reserved threepid %s not found in db" % tp) - async def reap_monthly_active_users(self): """Cleans out monthly active user table to ensure that no stale entries exist. @@ -250,6 +206,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "reap_monthly_active_users", _reap_users, reserved_users ) + +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._mau_stats_only = hs.config.mau_stats_only + + # Do not add more reserved users than the total allowable number + self.db_pool.new_transaction( + db_conn, + "initialise_mau_threepids", + [], + [], + self._initialise_reserved_users, + hs.config.mau_limits_reserved_threepids[: self._max_mau_value], + ) + + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ + + # XXX what is this function trying to achieve? It upserts into + # monthly_active_users for each *registered* reserved mau user, but why? + # + # - shouldn't there already be an entry for each reserved user (at least + # if they have been active recently)? + # + # - if it's important that the timestamp is kept up to date, why do we only + # run this at startup? + + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) + + if user_id: + is_support = self.is_support_user_txn(txn, user_id) + if not is_support: + # We do this manually here to avoid hitting #6791 + self.db_pool.simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) + else: + logger.warning("mau limit reserved threepid %s not found in db" % tp) + async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index d7a03cbf7d..ecfc6717b3 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): The set of state groups that are referenced by deleted events. """ + parsed_token = await RoomStreamToken.parse(self, token) + return await self.db_pool.runInteraction( "purge_history", self._purge_history_txn, room_id, - token, + parsed_token, delete_local_events, ) - def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): - token = RoomStreamToken.parse(token_str) - + def _purge_history_txn(self, txn, room_id, token, delete_local_events): # Tables that should be pruned: # event_auth # event_backward_extremities diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 33825e8949..a83df7759d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -41,6 +41,9 @@ class RegistrationWorkerStore(SQLBaseStore): self.config = hs.config self.clock = hs.get_clock() + # Note: we don't check this sequence for consistency as we'd have to + # call `find_max_generated_user_id_localpart` each time, which is + # expensive if there are many entries. self._user_id_seq = build_sequence_generator( database.engine, find_max_generated_user_id_localpart, "user_id_seq", ) @@ -393,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore): async def get_user_by_external_id( self, auth_provider: str, external_id: str - ) -> str: + ) -> Optional[str]: """Look up a user by their external auth id Args: @@ -401,7 +404,7 @@ class RegistrationWorkerStore(SQLBaseStore): external_id: id on that system Returns: - str|None: the mxid of the user, or None if they are not known + the mxid of the user, or None if they are not known """ return await self.db_pool.simple_select_one_onecol( table="user_external_ids", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 3c7630857f..c0f2af0785 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore): "count_public_rooms", _count_public_rooms_txn ) + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. + """ + + def f(txn): + sql = "SELECT count(*) FROM rooms" + txn.execute(sql) + row = txn.fetchone() + return row[0] or 0 + + return await self.db_pool.runInteraction("get_rooms", f) + async def get_largest_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], @@ -1292,18 +1304,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - async def get_room_count(self) -> int: - """Retrieve the total number of rooms. - """ - - def f(txn): - sql = "SELECT count(*) FROM rooms" - txn.execute(sql) - row = txn.fetchone() - return row[0] or 0 - - return await self.db_pool.runInteraction("get_rooms", f) - async def add_event_report( self, room_id: str, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 86ffe2479e..bae1bd22d3 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -21,12 +21,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - LoggingTransaction, - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, -) +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine @@ -60,10 +55,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): # background update still running? self._current_state_events_membership_up_to_date = False - txn = LoggingTransaction( - db_conn.cursor(), - name="_check_safe_current_state_events_membership_updated", - database_engine=self.database_engine, + txn = db_conn.cursor( + txn_name="_check_safe_current_state_events_membership_updated" ) self._check_safe_current_state_events_membership_updated_txn(txn) txn.close() diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py index 3edfcfd783..45b846e6a7 100644 --- a/synapse/storage/databases/main/schema/delta/20/pushers.py +++ b/synapse/storage/databases/main/schema/delta/20/pushers.py @@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs): row[8] = bytes(row[8]).decode("utf-8") row[11] = bytes(row[11]).decode("utf-8") cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py index ee675e71ff..21f57825d4 100644 --- a/synapse/storage/databases/main/schema/delta/25/fts.py +++ b/synapse/storage/databases/main/schema/delta/25/fts.py @@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py index b7972cfa8e..1c6058063f 100644 --- a/synapse/storage/databases/main/schema/delta/27/ts.py +++ b/synapse/storage/databases/main/schema/delta/27/ts.py @@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_origin_server_ts", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py index b42c02710a..7f08fabe9f 100644 --- a/synapse/storage/databases/main/schema/delta/30/as_users.py +++ b/synapse/storage/databases/main/schema/delta/30/as_users.py @@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( - database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name IN (%s)" - % (",".join("?" for _ in chunk),) - ), + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" + % (",".join("?" for _ in chunk),), [as_id] + chunk, ) diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py index 9bb504aad5..5be81c806a 100644 --- a/synapse/storage/databases/main/schema/delta/31/pushers.py +++ b/synapse/storage/databases/main/schema/delta/31/pushers.py @@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs): row = list(row) row[12] = token_to_stream_ordering(row[12]) cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_stream_ordering, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_stream_ordering, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py index 63b757ade6..b84c844e3a 100644 --- a/synapse/storage/databases/main/schema/delta/31/search_update.py +++ b/synapse/storage/databases/main/schema/delta/31/search_update.py @@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search_order", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py index a3e81eeac7..e928c66a8f 100644 --- a/synapse/storage/databases/main/schema/delta/33/event_fields.py +++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py @@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_fields_sender_url", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py index a26057dfb6..ad875c733a 100644 --- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py @@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute( - database_engine.convert_param_style( - "UPDATE remote_media_cache SET last_access_ts = ?" - ), - (int(time.time() * 1000),), + "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),), ) diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql index 5e29c1da19..ccf287971c 100644 --- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql @@ -13,7 +13,7 @@ * limitations under the License. */ --- room_id and topoligical_ordering are denormalised from the events table in order to +-- room_id and topological_ordering are denormalised from the events table in order to -- make the index work. CREATE TABLE IF NOT EXISTS event_labels ( event_id TEXT, diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py index 1de8b54961..bb7296852a 100644 --- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py +++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py @@ -1,6 +1,8 @@ import logging +from io import StringIO from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import execute_statements_from_stream logger = logging.getLogger(__name__) @@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs): select_clause, ) - if isinstance(database_engine, PostgresEngine): - cur.execute(sql) - else: - cur.executescript(sql) + execute_statements_from_stream(cur, StringIO(sql)) diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py index 63b5acdcf7..44917f0a2e 100644 --- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py +++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py @@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): INNER JOIN room_memberships AS r USING (event_id) WHERE type = 'm.room.member' AND state_key LIKE ? """ - sql = database_engine.convert_param_style(sql) cur.execute(sql, ("%:" + config.server_name,)) cur.execute( diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres index 97c1e6a0c5..c31f9af82a 100644 --- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', ( CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; +-- If the server has never backfilled a room then doing `-MIN(...)` will give +-- a negative result, hence why we do `GREATEST(...)` SELECT setval('events_backfill_stream_seq', ( - SELECT COALESCE(-MIN(stream_ordering), 1) FROM events + SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events )); diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 92e96468b4..a94bec1ac5 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -35,7 +35,6 @@ what sort order was used: - topological tokems: "t%d-%d", where the integers map to the topological and stream ordering columns respectively. """ - import abc import logging from collections import namedtuple @@ -54,7 +53,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.types import Collection, RoomStreamToken +from synapse.types import Collection, PersistedEventPosition, RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache if TYPE_CHECKING: @@ -305,6 +304,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): def get_room_min_stream_ordering(self) -> int: raise NotImplementedError() + def get_room_max_token(self) -> RoomStreamToken: + return RoomStreamToken(None, self.get_room_max_stream_ordering()) + async def get_room_events_stream_for_rooms( self, room_ids: Collection[str], @@ -544,7 +546,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_room_event_before_stream_ordering( self, room_id: str, stream_ordering: int - ) -> Tuple[int, int, str]: + ) -> Optional[Tuple[int, int, str]]: """Gets details of the first event in a room at or before a stream ordering Args: @@ -587,19 +589,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): ) return "t%d-%d" % (topo, token) - async def get_stream_id_for_event(self, event_id: str) -> int: - """The stream ID for an event - Args: - event_id: The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A stream ID. - """ - return await self.db_pool.runInteraction( - "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id, - ) - def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, allow_none=False, ) -> int: @@ -611,26 +600,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): allow_none=allow_none, ) - async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken: - """The stream token for an event - Args: - event_id: The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A stream token. + async def get_position_for_event(self, event_id: str) -> PersistedEventPosition: + """Get the persisted position for an event """ - stream_id = await self.get_stream_id_for_event(event_id) - return RoomStreamToken(None, stream_id) + row = await self.db_pool.simple_select_one( + table="events", + keyvalues={"event_id": event_id}, + retcols=("stream_ordering", "instance_name"), + desc="get_position_for_event", + ) - async def get_topological_token_for_event(self, event_id: str) -> str: + return PersistedEventPosition( + row["instance_name"] or "master", row["stream_ordering"] + ) + + async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: """The stream token for an event Args: event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A "t%d-%d" topological token. + A `RoomStreamToken` topological token. """ row = await self.db_pool.simple_select_one( table="events", @@ -638,7 +629,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ) - return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) + return RoomStreamToken(row["topological_ordering"], row["stream_ordering"]) async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: """Gets the topological token in a room after or at the given stream @@ -687,8 +678,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): else: topo = None internal = event.internal_metadata - internal.before = str(RoomStreamToken(topo, stream - 1)) - internal.after = str(RoomStreamToken(topo, stream)) + internal.before = RoomStreamToken(topo, stream - 1) + internal.after = RoomStreamToken(topo, stream) internal.order = (int(topo) if topo else 0, int(stream)) async def get_events_around( diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 3b9211a6d2..79b7ece330 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore): ) return [(row["user_agent"], row["ip"]) for row in rows] - -class UIAuthStore(UIAuthWorkerStore): async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore): iterable=session_ids, keyvalues={}, ) + + +class UIAuthStore(UIAuthWorkerStore): + pass diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index bec3780a32..0e31cc811a 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -24,7 +24,7 @@ from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStor from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -99,6 +99,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_seq_gen = build_sequence_generator( self.database_engine, get_max_state_group_txn, "state_group_id_seq" ) + self._state_group_seq_gen.check_consistency( + db_conn, table="state_groups", id_column="id" + ) @cached(max_entries=10000, iterable=True) async def get_state_group_delta(self, state_group): @@ -205,7 +208,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): async def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ) -> Dict[int, StateMap[str]]: + ) -> Dict[int, MutableStateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 603cd7d825..4d2d88d1f0 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -197,7 +197,7 @@ class EventsPersistenceStorage: async def persist_events( self, - events_and_contexts: List[Tuple[EventBase, EventContext]], + events_and_contexts: Iterable[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> RoomStreamToken: """ @@ -229,7 +229,7 @@ class EventsPersistenceStorage: defer.gatherResults(deferreds, consumeErrors=True) ) - return RoomStreamToken(None, self.main_store.get_current_events_token()) + return self.main_store.get_room_max_token() async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False @@ -247,11 +247,12 @@ class EventsPersistenceStorage: await make_deferred_yieldable(deferred) - max_persisted_id = self.main_store.get_current_events_token() event_stream_id = event.internal_metadata.stream_ordering + # stream ordering should have been assigned by now + assert event_stream_id pos = PersistedEventPosition(self._instance_name, event_stream_id) - return pos, RoomStreamToken(None, max_persisted_id) + return pos, self.main_store.get_room_max_token() def _maybe_start_persisting(self, room_id: str): async def persisting_queue(item): diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 4957e77f4c..459754feab 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import imp import logging import os @@ -24,9 +23,10 @@ from typing import Optional, TextIO import attr from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines.postgres import PostgresEngine -from synapse.storage.types import Connection, Cursor +from synapse.storage.types import Cursor from synapse.types import Collection logger = logging.getLogger(__name__) @@ -67,7 +67,7 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = ( def prepare_database( - db_conn: Connection, + db_conn: LoggingDatabaseConnection, database_engine: BaseDatabaseEngine, config: Optional[HomeServerConfig], databases: Collection[str] = ["main", "state"], @@ -89,7 +89,7 @@ def prepare_database( """ try: - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="prepare_database") # sqlite does not automatically start transactions for DDL / SELECT statements, # so we start one before running anything. This ensures that any upgrades @@ -258,9 +258,7 @@ def _setup_new_database(cur, database_engine, databases): executescript(cur, entry.absolute_path) cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" - ), + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", (max_current_ver, False), ) @@ -486,17 +484,13 @@ def _upgrade_existing_database( # Mark as done. cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)" - ), + "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)", (v, relative_path), ) cur.execute("DELETE FROM schema_version") cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" - ), + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", (v, True), ) @@ -532,10 +526,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) schemas to be applied """ cur.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_module_schemas WHERE module_name = ?" - ), - (modname,), + "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), ) applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: @@ -553,9 +544,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) # Mark as done. cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)" - ), + "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)", (modname, name), ) @@ -627,9 +616,7 @@ def _get_or_create_schema_state(txn, database_engine): if current_version: txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), + "SELECT file FROM applied_schema_deltas WHERE version >= ?", (current_version,), ) applied_deltas = [d for d, in txn] diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 8f68d968f0..08a69f2f96 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,7 +20,7 @@ import attr from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap logger = logging.getLogger(__name__) @@ -349,7 +349,7 @@ class StateGroupStorage: async def get_state_groups_ids( self, _room_id: str, event_ids: Iterable[str] - ) -> Dict[int, StateMap[str]]: + ) -> Dict[int, MutableStateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: @@ -532,7 +532,7 @@ class StateGroupStorage: def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ) -> Awaitable[Dict[int, StateMap[str]]]: + ) -> Awaitable[Dict[int, MutableStateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 2d2b560e74..970bb1b9da 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -61,3 +61,9 @@ class Connection(Protocol): def rollback(self, *args, **kwargs) -> None: ... + + def __enter__(self) -> "Connection": + ... + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + ... diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 727fcc521c..51f680d05d 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -54,7 +54,7 @@ def _load_current_id(db_conn, table, column, step=1): """ # debug logging for https://github.com/matrix-org/synapse/issues/7968 logger.info("initialising stream generator for %s(%s)", table, column) - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="_load_current_id") if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: @@ -258,22 +258,37 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # We check that the table and sequence haven't diverged. + self._sequence_gen.check_consistency( + db_conn, table=table, id_column=id_column, positive=positive + ) + # This goes and fills out the above state from the database. self._load_current_ids(db_conn, table, instance_column, id_column) def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ): - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="_load_current_ids") # Load the current positions of all writers for the stream. if self._writers: + # We delete any stale entries in the positions table. This is + # important if we add back a writer after a long time; we want to + # consider that a "new" writer, rather than using the old stale + # entry here. + sql = """ + DELETE FROM stream_positions + WHERE + stream_name = ? + AND instance_name != ALL(?) + """ + cur.execute(sql, (self._stream_name, self._writers)) + sql = """ SELECT instance_name, stream_id FROM stream_positions WHERE stream_name = ? """ - sql = self._db.engine.convert_param_style(sql) - cur.execute(sql, (self._stream_name,)) self._current_positions = { @@ -287,8 +302,12 @@ class MultiWriterIdGenerator: min_stream_id = min(self._current_positions.values(), default=None) if min_stream_id is None: + # We add a GREATEST here to ensure that the result is always + # positive. (This can be a problem for e.g. backfill streams where + # the server has never backfilled). sql = """ - SELECT COALESCE(%(agg)s(%(id)s), 1) FROM %(table)s + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s """ % { "id": id_column, "table": table, @@ -318,8 +337,7 @@ class MultiWriterIdGenerator: "instance": instance_column, "cmp": "<=" if self._positive else ">=", } - sql = self._db.engine.convert_param_style(sql) - cur.execute(sql, (min_stream_id,)) + cur.execute(sql, (min_stream_id * self._return_factor,)) self._persisted_upto_position = min_stream_id @@ -399,7 +417,7 @@ class MultiWriterIdGenerator: self._unfinished_ids.discard(next_id) self._finished_ids.add(next_id) - new_cur = None + new_cur = None # type: Optional[int] if self._unfinished_ids: # If there are unfinished IDs then the new position will be the @@ -444,11 +462,22 @@ class MultiWriterIdGenerator: """Returns the position of the given writer. """ + # If we don't have an entry for the given instance name, we assume it's a + # new writer. + # + # For new writers we assume their initial position to be the current + # persisted up to position. This stops Synapse from doing a full table + # scan when a new writer announces itself over replication. with self._lock: - return self._return_factor * self._current_positions.get(instance_name, 0) + return self._return_factor * self._current_positions.get( + instance_name, self._persisted_upto_position + ) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. + + Note that this won't necessarily include all configured writers if some + writers haven't written anything yet. """ with self._lock: diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index ffc1894748..ff2d038ad2 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -13,11 +13,35 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import logging import threading from typing import Callable, List, Optional -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.storage.types import Cursor +from synapse.storage.database import LoggingDatabaseConnection +from synapse.storage.engines import ( + BaseDatabaseEngine, + IncorrectDatabaseSetup, + PostgresEngine, +) +from synapse.storage.types import Connection, Cursor + +logger = logging.getLogger(__name__) + + +_INCONSISTENT_SEQUENCE_ERROR = """ +Postgres sequence '%(seq)s' is inconsistent with associated +table '%(table)s'. This can happen if Synapse has been downgraded and +then upgraded again, or due to a bad migration. + +To fix this error, shut down Synapse (including any and all workers) +and run the following SQL: + + SELECT setval('%(seq)s', ( + %(max_id_sql)s + )); + +See docs/postgres.md for more information. +""" class SequenceGenerator(metaclass=abc.ABCMeta): @@ -28,6 +52,23 @@ class SequenceGenerator(metaclass=abc.ABCMeta): """Gets the next ID in the sequence""" ... + @abc.abstractmethod + def check_consistency( + self, + db_conn: LoggingDatabaseConnection, + table: str, + id_column: str, + positive: bool = True, + ): + """Should be called during start up to test that the current value of + the sequence is greater than or equal to the maximum ID in the table. + + This is to handle various cases where the sequence value can get out + of sync with the table, e.g. if Synapse gets rolled back to a previous + version and the rolled forwards again. + """ + ... + class PostgresSequenceGenerator(SequenceGenerator): """An implementation of SequenceGenerator which uses a postgres sequence""" @@ -45,6 +86,54 @@ class PostgresSequenceGenerator(SequenceGenerator): ) return [i for (i,) in txn] + def check_consistency( + self, + db_conn: LoggingDatabaseConnection, + table: str, + id_column: str, + positive: bool = True, + ): + txn = db_conn.cursor(txn_name="sequence.check_consistency") + + # First we get the current max ID from the table. + table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % { + "id": id_column, + "table": table, + "agg": "MAX" if positive else "-MIN", + } + + txn.execute(table_sql) + row = txn.fetchone() + if not row: + # Table is empty, so nothing to do. + txn.close() + return + + # Now we fetch the current value from the sequence and compare with the + # above. + max_stream_id = row[0] + txn.execute( + "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} + ) + last_value, is_called = txn.fetchone() + txn.close() + + # If `is_called` is False then `last_value` is actually the value that + # will be generated next, so we decrement to get the true "last value". + if not is_called: + last_value -= 1 + + if max_stream_id > last_value: + logger.warning( + "Postgres sequence %s is behind table %s: %d < %d", + last_value, + max_stream_id, + ) + raise IncorrectDatabaseSetup( + _INCONSISTENT_SEQUENCE_ERROR + % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql} + ) + GetFirstCallbackType = Callable[[Cursor], int] @@ -81,6 +170,12 @@ class LocalSequenceGenerator(SequenceGenerator): self._current_max_id += 1 return self._current_max_id + def check_consistency( + self, db_conn: Connection, table: str, id_column: str, positive: bool = True + ): + # There is nothing to do for in memory sequences + pass + def build_sequence_generator( database_engine: BaseDatabaseEngine, diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 0bdf846edf..fdda21d165 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Optional @@ -21,6 +20,7 @@ import attr from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest +from synapse.storage.databases.main import DataStore from synapse.types import StreamToken logger = logging.getLogger(__name__) @@ -39,8 +39,9 @@ class PaginationConfig: limit = attr.ib(type=Optional[int]) @classmethod - def from_request( + async def from_request( cls, + store: "DataStore", request: SynapseRequest, raise_invalid_params: bool = True, default_limit: Optional[int] = None, @@ -54,13 +55,13 @@ class PaginationConfig: if from_tok == "END": from_tok = None # For backwards compat. elif from_tok: - from_tok = StreamToken.from_string(from_tok) + from_tok = await StreamToken.from_string(store, from_tok) except Exception: raise SynapseError(400, "'from' parameter is invalid") try: if to_tok: - to_tok = StreamToken.from_string(to_tok) + to_tok = await StreamToken.from_string(store, to_tok) except Exception: raise SynapseError(400, "'to' parameter is invalid") diff --git a/synapse/types.py b/synapse/types.py index ec39f9e1e8..bd271f9f16 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,17 @@ import re import string import sys from collections import namedtuple -from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Mapping, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, +) import attr from signedjson.key import decode_verify_key_bytes @@ -26,6 +36,9 @@ from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + # define a version of typing.Collection that works on python 3.5 if sys.version_info[:3] >= (3, 6, 0): from typing import Collection @@ -393,7 +406,7 @@ class RoomStreamToken: stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) @classmethod - def parse(cls, string: str) -> "RoomStreamToken": + async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -413,10 +426,22 @@ class RoomStreamToken: pass raise SynapseError(400, "Invalid token %r" % (string,)) + def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": + """Return a new token such that if an event is after both this token and + the other token, then its after the returned token too. + """ + + if self.topological or other.topological: + raise Exception("Can't advance topological tokens") + + max_stream = max(self.stream, other.stream) + + return RoomStreamToken(None, max_stream) + def as_tuple(self) -> Tuple[Optional[int], int]: return (self.topological, self.stream) - def __str__(self) -> str: + async def to_string(self, store: "DataStore") -> str: if self.topological is not None: return "t%d-%d" % (self.topological, self.stream) else: @@ -441,48 +466,51 @@ class StreamToken: START = None # type: StreamToken @classmethod - def from_string(cls, string): + async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": try: keys = string.split(cls._SEPARATOR) while len(keys) < len(attr.fields(cls)): # i.e. old token from before receipt_key keys.append("0") - return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:])) + return cls( + await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) + ) except Exception: raise SynapseError(400, "Invalid Token") - def to_string(self): - return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)]) + async def to_string(self, store: "DataStore") -> str: + return self._SEPARATOR.join( + [ + await self.room_key.to_string(store), + str(self.presence_key), + str(self.typing_key), + str(self.receipt_key), + str(self.account_data_key), + str(self.push_rules_key), + str(self.to_device_key), + str(self.device_list_key), + str(self.groups_key), + ] + ) @property def room_stream_id(self): return self.room_key.stream - def is_after(self, other): - """Does this token contain events that the other doesn't?""" - return ( - (other.room_stream_id < self.room_stream_id) - or (int(other.presence_key) < int(self.presence_key)) - or (int(other.typing_key) < int(self.typing_key)) - or (int(other.receipt_key) < int(self.receipt_key)) - or (int(other.account_data_key) < int(self.account_data_key)) - or (int(other.push_rules_key) < int(self.push_rules_key)) - or (int(other.to_device_key) < int(self.to_device_key)) - or (int(other.device_list_key) < int(self.device_list_key)) - or (int(other.groups_key) < int(self.groups_key)) - ) - def copy_and_advance(self, key, new_value) -> "StreamToken": """Advance the given key in the token to a new value if and only if the new value is after the old value. """ - new_token = self.copy_and_replace(key, new_value) if key == "room_key": - new_id = new_token.room_stream_id - old_id = self.room_stream_id - else: - new_id = int(getattr(new_token, key)) - old_id = int(getattr(self, key)) + new_token = self.copy_and_replace( + "room_key", self.room_key.copy_and_advance(new_value) + ) + return new_token + + new_token = self.copy_and_replace(key, new_value) + new_id = int(getattr(new_token, key)) + old_id = int(getattr(self, key)) + if old_id < new_id: return new_token else: @@ -492,7 +520,7 @@ class StreamToken: return attr.evolve(self, **{key: new_value}) -StreamToken.START = StreamToken.from_string("s0_0") +StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) @attr.s(slots=True, frozen=True) @@ -509,6 +537,18 @@ class PersistedEventPosition: def persisted_after(self, token: RoomStreamToken) -> bool: return token.stream < self.stream + def to_room_stream_token(self) -> RoomStreamToken: + """Converts the position to a room stream token such that events + persisted in the same room after this position will be after the + returned `RoomStreamToken`. + + Note: no guarentees are made about ordering w.r.t. events in other + rooms. + """ + # Doing the naive thing satisfies the desired properties described in + # the docstring. + return RoomStreamToken(None, self.stream) + class ThirdPartyInstanceID( namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 67ce9a5f39..382f0cf3f0 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -449,18 +449,8 @@ class ReadWriteLock: R = TypeVar("R") -def _cancelled_to_timed_out_error(value: R, timeout: float) -> R: - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise defer.TimeoutError(timeout, "Deferred") - return value - - def timeout_deferred( - deferred: defer.Deferred, - timeout: float, - reactor: IReactorTime, - on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None, + deferred: defer.Deferred, timeout: float, reactor: IReactorTime, ) -> defer.Deferred: """The in built twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new @@ -469,27 +459,21 @@ def timeout_deferred( (See https://twistedmatrix.com/trac/ticket/9534) - NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred + NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred. + + NOTE: the TimeoutError raised by the resultant deferred is + twisted.internet.defer.TimeoutError, which is *different* to the built-in + TimeoutError, as well as various other TimeoutErrors you might have imported. Args: deferred: The Deferred to potentially timeout. timeout: Timeout in seconds reactor: The twisted reactor to use - on_timeout_cancel: A callable which is called immediately - after the deferred times out, and not if this deferred is - otherwise cancelled before the timeout. - It takes an arbitrary value, which is the value of the deferred at - that exact point in time (probably a CancelledError Failure), and - the timeout. - - The default callable (if none is provided) will translate a - CancelledError Failure into a defer.TimeoutError. Returns: - A new Deferred. + A new Deferred, which will errback with defer.TimeoutError on timeout. """ - new_d = defer.Deferred() timed_out = [False] @@ -502,18 +486,23 @@ def timeout_deferred( except: # noqa: E722, if we throw any exception it'll break time outs logger.exception("Canceller failed during timeout") + # the cancel() call should have set off a chain of errbacks which + # will have errbacked new_d, but in case it hasn't, errback it now. + if not new_d.called: - new_d.errback(defer.TimeoutError(timeout, "Deferred")) + new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,))) delayed_call = reactor.callLater(timeout, time_it_out) - def convert_cancelled(value): - if timed_out[0]: - to_call = on_timeout_cancel or _cancelled_to_timed_out_error - return to_call(value, timeout) + def convert_cancelled(value: failure.Failure): + # if the orgininal deferred was cancelled, and our timeout has fired, then + # the reason it was cancelled was due to our timeout. Turn the CancelledError + # into a TimeoutError. + if timed_out[0] and value.check(CancelledError): + raise defer.TimeoutError("Timed out after %gs" % (timeout,)) return value - deferred.addBoth(convert_cancelled) + deferred.addErrback(convert_cancelled) def cancel_timeout(result): # stop the pending call to cancel the deferred if it's been fired diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 6e57c1ee72..ffdea0de8d 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -19,7 +19,11 @@ from typing import Any, Callable, Optional, TypeVar, cast from prometheus_client import Counter -from synapse.logging.context import LoggingContext, current_context +from synapse.logging.context import ( + ContextResourceUsage, + LoggingContext, + current_context, +) from synapse.metrics import InFlightGauge logger = logging.getLogger(__name__) @@ -104,27 +108,27 @@ class Measure: def __init__(self, clock, name): self.clock = clock self.name = name - self._logging_context = None - self.start = None - - def __enter__(self): - if self._logging_context: - raise RuntimeError("Measure() objects cannot be re-used") - - self.start = self.clock.time() parent_context = current_context() self._logging_context = LoggingContext( "Measure[%s]" % (self.name,), parent_context ) + self.start = None + + def __enter__(self) -> "Measure": + if self.start is not None: + raise RuntimeError("Measure() objects cannot be re-used") + + self.start = self.clock.time() self._logging_context.__enter__() in_flight.register((self.name,), self._update_in_flight) + return self def __exit__(self, exc_type, exc_val, exc_tb): - if not self._logging_context: + if self.start is None: raise RuntimeError("Measure() block exited without being entered") duration = self.clock.time() - self.start - usage = self._logging_context.get_resource_usage() + usage = self.get_resource_usage() in_flight.unregister((self.name,), self._update_in_flight) self._logging_context.__exit__(exc_type, exc_val, exc_tb) @@ -140,6 +144,13 @@ class Measure: except ValueError: logger.warning("Failed to save metrics! Usage: %s", usage) + def get_resource_usage(self) -> ContextResourceUsage: + """Get the resources used within this Measure block + + If the Measure block is still active, returns the resource usage so far. + """ + return self._logging_context.get_resource_usage() + def _update_in_flight(self, metrics): """Gets called when processing in flight metrics """ diff --git a/synapse/visibility.py b/synapse/visibility.py index e3da7744d2..527365498e 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -16,7 +16,7 @@ import logging import operator -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.events.utils import prune_event from synapse.storage import Storage from synapse.storage.state import StateFilter @@ -77,15 +77,14 @@ async def filter_events_for_client( ) ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id + AccountDataTypes.IGNORED_USER_LIST, user_id ) - # FIXME: This will explode if people upload something incorrect. - ignore_list = frozenset( - ignore_dict_content.get("ignored_users", {}).keys() - if ignore_dict_content - else [] - ) + ignore_list = frozenset() + if ignore_dict_content: + ignored_users_dict = ignore_dict_content.get("ignored_users", {}) + if isinstance(ignored_users_dict, dict): + ignore_list = frozenset(ignored_users_dict.keys()) erased_senders = await storage.main.are_users_erased((e.sender for e in events)) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 2e6e7abf1f..8ff1460c0d 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -23,6 +23,7 @@ from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key from twisted.internet import defer +from twisted.internet.defer import Deferred, ensureDeferred from synapse.api.errors import SynapseError from synapse.crypto import keyring @@ -33,7 +34,6 @@ from synapse.crypto.keyring import ( ) from synapse.logging.context import ( LoggingContext, - PreserveLoggingContext, current_context, make_deferred_yieldable, ) @@ -41,6 +41,7 @@ from synapse.storage.keys import FetchKeyResult from tests import unittest from tests.test_utils import make_awaitable +from tests.unittest import logcontext_clean class MockPerspectiveServer: @@ -67,55 +68,42 @@ class MockPerspectiveServer: signedjson.sign.sign_json(res, self.server_name, self.key) +@logcontext_clean class KeyringTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - self.mock_perspective_server = MockPerspectiveServer() - self.http_client = Mock() - - config = self.default_config() - config["trusted_key_servers"] = [ - { - "server_name": self.mock_perspective_server.server_name, - "verify_keys": self.mock_perspective_server.get_verify_keys(), - } - ] - - return self.setup_test_homeserver( - handlers=None, http_client=self.http_client, config=config - ) - - def check_context(self, _, expected): + def check_context(self, val, expected): self.assertEquals(getattr(current_context(), "request", None), expected) + return val def test_verify_json_objects_for_server_awaits_previous_requests(self): - key1 = signedjson.key.generate_signing_key(1) + mock_fetcher = keyring.KeyFetcher() + mock_fetcher.get_keys = Mock() + kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) - kr = keyring.Keyring(self.hs) + # a signed object that we are going to try to validate + key1 = signedjson.key.generate_signing_key(1) json1 = {} signedjson.sign.sign_json(json1, "server10", key1) - persp_resp = { - "server_keys": [ - self.mock_perspective_server.get_signed_key( - "server10", signedjson.key.get_verify_key(key1) - ) - ] - } - persp_deferred = defer.Deferred() + # start off a first set of lookups. We make the mock fetcher block until this + # deferred completes. + first_lookup_deferred = Deferred() - async def get_perspectives(**kwargs): - self.assertEquals(current_context().request, "11") - with PreserveLoggingContext(): - await persp_deferred - return persp_resp + async def first_lookup_fetch(keys_to_fetch): + self.assertEquals(current_context().request, "context_11") + self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}}) - self.http_client.post_json.side_effect = get_perspectives + await make_deferred_yieldable(first_lookup_deferred) + return { + "server10": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) + } + } - # start off a first set of lookups - @defer.inlineCallbacks - def first_lookup(): - with LoggingContext("11") as context_11: - context_11.request = "11" + mock_fetcher.get_keys.side_effect = first_lookup_fetch + + async def first_lookup(): + with LoggingContext("context_11") as context_11: + context_11.request = "context_11" res_deferreds = kr.verify_json_objects_for_server( [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] @@ -124,7 +112,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # the unsigned json should be rejected pretty quickly self.assertTrue(res_deferreds[1].called) try: - yield res_deferreds[1] + await res_deferreds[1] self.assertFalse("unsigned json didn't cause a failure") except SynapseError: pass @@ -132,45 +120,51 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.assertFalse(res_deferreds[0].called) res_deferreds[0].addBoth(self.check_context, None) - yield make_deferred_yieldable(res_deferreds[0]) + await make_deferred_yieldable(res_deferreds[0]) - # let verify_json_objects_for_server finish its work before we kill the - # logcontext - yield self.clock.sleep(0) + d0 = ensureDeferred(first_lookup()) - d0 = first_lookup() - - # wait a tick for it to send the request to the perspectives server - # (it first tries the datastore) - self.pump() - self.http_client.post_json.assert_called_once() + mock_fetcher.get_keys.assert_called_once() # a second request for a server with outstanding requests # should block rather than start a second call - @defer.inlineCallbacks - def second_lookup(): - with LoggingContext("12") as context_12: - context_12.request = "12" - self.http_client.post_json.reset_mock() - self.http_client.post_json.return_value = defer.Deferred() + + async def second_lookup_fetch(keys_to_fetch): + self.assertEquals(current_context().request, "context_12") + return { + "server10": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) + } + } + + mock_fetcher.get_keys.reset_mock() + mock_fetcher.get_keys.side_effect = second_lookup_fetch + second_lookup_state = [0] + + async def second_lookup(): + with LoggingContext("context_12") as context_12: + context_12.request = "context_12" res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1, 0, "test")] ) res_deferreds_2[0].addBoth(self.check_context, None) - yield make_deferred_yieldable(res_deferreds_2[0]) + second_lookup_state[0] = 1 + await make_deferred_yieldable(res_deferreds_2[0]) + second_lookup_state[0] = 2 - # let verify_json_objects_for_server finish its work before we kill the - # logcontext - yield self.clock.sleep(0) - - d2 = second_lookup() + d2 = ensureDeferred(second_lookup()) self.pump() - self.http_client.post_json.assert_not_called() + # the second request should be pending, but the fetcher should not yet have been + # called + self.assertEqual(second_lookup_state[0], 1) + mock_fetcher.get_keys.assert_not_called() # complete the first request - persp_deferred.callback(persp_resp) + first_lookup_deferred.callback(None) + + # and now both verifications should succeed. self.get_success(d0) self.get_success(d2) @@ -317,6 +311,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): mock_fetcher2.get_keys.assert_called_once() +@logcontext_clean class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.http_client = Mock() diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 89ec5fcb31..b6f436c016 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -21,7 +21,6 @@ from mock import Mock, patch import attr import pymacaroons -from twisted.internet import defer from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone @@ -87,6 +86,13 @@ class TestMappingProvider(OidcMappingProvider): async def map_user_attributes(self, userinfo, token): return {"localpart": userinfo["username"], "display_name": None} + # Do not include get_extra_attributes to test backwards compatibility paths. + + +class TestMappingProviderExtra(TestMappingProvider): + async def get_extra_attributes(self, userinfo, token): + return {"phone": userinfo["phone"]} + def simple_async_mock(return_value=None, raises=None): # AsyncMock is not available in python3.5, this mimics part of its behaviour @@ -126,7 +132,7 @@ class OidcHandlerTestCase(HomeserverTestCase): config = self.default_config() config["public_baseurl"] = BASE_URL - oidc_config = config.get("oidc_config", {}) + oidc_config = {} oidc_config["enabled"] = True oidc_config["client_id"] = CLIENT_ID oidc_config["client_secret"] = CLIENT_SECRET @@ -135,6 +141,10 @@ class OidcHandlerTestCase(HomeserverTestCase): oidc_config["user_mapping_provider"] = { "module": __name__ + ".TestMappingProvider", } + + # Update this config with what's in the default config so that + # override_config works as expected. + oidc_config.update(config.get("oidc_config", {})) config["oidc_config"] = oidc_config hs = self.setup_test_homeserver( @@ -165,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) @override_config({"oidc_config": {"discover": True}}) - @defer.inlineCallbacks def test_discovery(self): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid - metadata = yield defer.ensureDeferred(self.handler.load_metadata()) + metadata = self.get_success(self.handler.load_metadata()) self.http_client.get_json.assert_called_once_with(WELL_KNOWN) self.assertEqual(metadata.issuer, ISSUER) @@ -181,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase): # subsequent calls should be cached self.http_client.reset_mock() - yield defer.ensureDeferred(self.handler.load_metadata()) + self.get_success(self.handler.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) - @defer.inlineCallbacks def test_no_discovery(self): """When discovery is disabled, it should not try to load from discovery document.""" - yield defer.ensureDeferred(self.handler.load_metadata()) + self.get_success(self.handler.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) - @defer.inlineCallbacks def test_load_jwks(self): """JWKS loading is done once (then cached) if used.""" - jwks = yield defer.ensureDeferred(self.handler.load_jwks()) + jwks = self.get_success(self.handler.load_jwks()) self.http_client.get_json.assert_called_once_with(JWKS_URI) self.assertEqual(jwks, {"keys": []}) # subsequent calls should be cached… self.http_client.reset_mock() - yield defer.ensureDeferred(self.handler.load_jwks()) + self.get_success(self.handler.load_jwks()) self.http_client.get_json.assert_not_called() # …unless forced self.http_client.reset_mock() - yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.get_success(self.handler.load_jwks(force=True)) self.http_client.get_json.assert_called_once_with(JWKS_URI) # Throw if the JWKS uri is missing with self.metadata_edit({"jwks_uri": None}): - with self.assertRaises(RuntimeError): - yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.get_failure(self.handler.load_jwks(force=True), RuntimeError) # Return empty key set if JWKS are not used self.handler._scopes = [] # not asking the openid scope self.http_client.get_json.reset_mock() - jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + jwks = self.get_success(self.handler.load_jwks(force=True)) self.http_client.get_json.assert_not_called() self.assertEqual(jwks, {"keys": []}) @@ -280,9 +286,15 @@ class OidcHandlerTestCase(HomeserverTestCase): h._validate_metadata, ) - # Tests for configs that the userinfo endpoint + # Tests for configs that require the userinfo endpoint self.assertFalse(h._uses_userinfo) - h._scopes = [] # do not request the openid scope + self.assertEqual(h._user_profile_method, "auto") + h._user_profile_method = "userinfo_endpoint" + self.assertTrue(h._uses_userinfo) + + # Revert the profile method and do not request the "openid" scope. + h._user_profile_method = "auto" + h._scopes = [] self.assertTrue(h._uses_userinfo) self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata) @@ -299,11 +311,10 @@ class OidcHandlerTestCase(HomeserverTestCase): # This should not throw self.handler._validate_metadata() - @defer.inlineCallbacks def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["addCookie"]) - url = yield defer.ensureDeferred( + url = self.get_success( self.handler.handle_redirect_request(req, b"http://client/redirect") ) url = urlparse(url) @@ -343,20 +354,18 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(params["nonce"], [nonce]) self.assertEqual(redirect, "http://client/redirect") - @defer.inlineCallbacks def test_callback_error(self): """Errors from the provider returned in the callback are displayed.""" self.handler._render_error = Mock() request = Mock(args={}) request.args[b"error"] = [b"invalid_client"] - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_client", "") request.args[b"error_description"] = [b"some description"] - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_client", "some description") - @defer.inlineCallbacks def test_callback(self): """Code callback works and display errors if something went wrong. @@ -377,7 +386,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "foo", "preferred_username": "bar", } - user_id = UserID("foo", "domain.org") + user_id = "@foo:domain.org" self.handler._render_error = Mock(return_value=None) self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._parse_id_token = simple_async_mock(return_value=userinfo) @@ -394,13 +403,12 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url = "http://client/redirect" user_agent = "Browser" ip_address = "10.0.0.1" - session = self.handler._generate_oidc_session_token( + request.getCookie.return_value = self.handler._generate_oidc_session_token( state=state, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - request.getCookie.return_value = session request.args = {} request.args[b"code"] = [code.encode("utf-8")] @@ -410,10 +418,10 @@ class OidcHandlerTestCase(HomeserverTestCase): request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] request.getClientIP.return_value = ip_address - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, + user_id, request, client_redirect_url, {}, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -427,13 +435,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._map_userinfo_to_user = simple_async_mock( raises=MappingException() ) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error") self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) # Handle ID token errors self.handler._parse_id_token = simple_async_mock(raises=Exception()) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") self.handler._auth_handler.complete_sso_login.reset_mock() @@ -444,10 +452,10 @@ class OidcHandlerTestCase(HomeserverTestCase): # With userinfo fetching self.handler._scopes = [] # do not ask the "openid" scope - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, + user_id, request, client_redirect_url, {}, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() @@ -459,17 +467,16 @@ class OidcHandlerTestCase(HomeserverTestCase): # Handle userinfo fetching error self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") # Handle code exchange failure self.handler._exchange_code = simple_async_mock( raises=OidcError("invalid_request") ) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") - @defer.inlineCallbacks def test_callback_session(self): """The callback verifies the session presence and validity""" self.handler._render_error = Mock(return_value=None) @@ -478,20 +485,20 @@ class OidcHandlerTestCase(HomeserverTestCase): # Missing cookie request.args = {} request.getCookie.return_value = None - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("missing_session", "No session cookie found") # Missing session parameter request.args = {} request.getCookie.return_value = "session" - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request", "State parameter is missing") # Invalid cookie request.args = {} request.args[b"state"] = [b"state"] request.getCookie.return_value = "session" - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_session") # Mismatching session @@ -504,18 +511,17 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args = {} request.args[b"state"] = [b"mismatching state"] request.getCookie.return_value = session - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mismatching_session") # Valid session request.args = {} request.args[b"state"] = [b"state"] request.getCookie.return_value = session - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) - @defer.inlineCallbacks def test_exchange_code(self): """Code exchange behaves correctly and handles various error scenarios.""" token = {"type": "bearer"} @@ -524,7 +530,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) ) code = "code" - ret = yield defer.ensureDeferred(self.handler._exchange_code(code)) + ret = self.get_success(self.handler._exchange_code(code)) kwargs = self.http_client.request.call_args[1] self.assertEqual(ret, token) @@ -546,10 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase): body=b'{"error": "foo", "error_description": "bar"}', ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "foo") - self.assertEqual(exc.exception.error_description, "bar") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "foo") + self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body self.http_client.request = simple_async_mock( @@ -557,9 +562,8 @@ class OidcHandlerTestCase(HomeserverTestCase): code=500, phrase=b"Internal Server Error", body=b"Not JSON", ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "server_error") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body self.http_client.request = simple_async_mock( @@ -569,17 +573,16 @@ class OidcHandlerTestCase(HomeserverTestCase): body=b'{"error": "internal_server_error"}', ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "internal_server_error") + + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field self.http_client.request = simple_async_mock( return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "server_error") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field self.http_client.request = simple_async_mock( @@ -587,9 +590,62 @@ class OidcHandlerTestCase(HomeserverTestCase): code=200, phrase=b"OK", body=b'{"error": "some_error"}', ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "some_error") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "some_error") + + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "module": __name__ + ".TestMappingProviderExtra" + } + } + } + ) + def test_extra_attributes(self): + """ + Login while using a mapping provider that implements get_extra_attributes. + """ + token = { + "type": "bearer", + "id_token": "id_token", + "access_token": "access_token", + } + userinfo = { + "sub": "foo", + "phone": "1234567", + } + user_id = "@foo:domain.org" + self.handler._exchange_code = simple_async_mock(return_value=token) + self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + self.handler._auth_handler.complete_sso_login = simple_async_mock() + request = Mock( + spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + ) + + state = "state" + client_redirect_url = "http://client/redirect" + request.getCookie.return_value = self.handler._generate_oidc_session_token( + state=state, + nonce="nonce", + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, + ) + + request.args = {} + request.args[b"code"] = [b"code"] + request.args[b"state"] = [state.encode("utf-8")] + + request.requestHeaders = Mock(spec=["getRawHeaders"]) + request.requestHeaders.getRawHeaders.return_value = [b"Browser"] + request.getClientIP.return_value = "10.0.0.1" + + self.get_success(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, {"phone": "1234567"}, + ) def test_map_userinfo_to_user(self): """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" @@ -617,3 +673,38 @@ class OidcHandlerTestCase(HomeserverTestCase): ) ) self.assertEqual(mxid, "@test_user_2:test") + + # Test if the mxid is already taken + store = self.hs.get_datastore() + user3 = UserID.from_string("@test_user_3:test") + self.get_success( + store.register_user(user_id=user3.to_string(), password_hash=None) + ) + userinfo = {"sub": "test3", "username": "test_user_3"} + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken") + + @override_config({"oidc_config": {"allow_existing_users": True}}) + def test_map_userinfo_to_existing_user(self): + """Existing users can log in with OpenID Connect when allow_existing_users is True.""" + store = self.hs.get_datastore() + user4 = UserID.from_string("@test_user_4:test") + self.get_success( + store.register_user(user_id=user4.to_string(), password_hash=None) + ) + userinfo = { + "sub": "test4", + "username": "test_user_4", + } + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user_4:test") diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 5604af3795..212484a7fe 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -318,14 +318,14 @@ class FederationClientTests(HomeserverTestCase): r = self.successResultOf(d) self.assertEqual(r.code, 200) - def test_client_headers_no_body(self): + @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"]) + def test_timeout_reading_body(self, method_name: str): """ If the HTTP request is connected, but gets no response before being - timed out, it'll give a ResponseNeverReceived. + timed out, it'll give a RequestSendFailed with can_retry. """ - d = defer.ensureDeferred( - self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) - ) + method = getattr(self.cl, method_name) + d = defer.ensureDeferred(method("testserv:8008", "foo/bar", timeout=10000)) self.pump() @@ -349,7 +349,9 @@ class FederationClientTests(HomeserverTestCase): self.reactor.advance(10.5) f = self.failureResultOf(d) - self.assertIsInstance(f.value, TimeoutError) + self.assertIsInstance(f.value, RequestSendFailed) + self.assertTrue(f.value.can_retry) + self.assertIsInstance(f.value.inner_exception, defer.TimeoutError) def test_client_requires_trailing_slashes(self): """ diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py new file mode 100644 index 0000000000..a1cf0862d4 --- /dev/null +++ b/tests/http/test_simple_client.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from netaddr import IPSet + +from twisted.internet import defer +from twisted.internet.error import DNSLookupError + +from synapse.http import RequestTimedOutError +from synapse.http.client import SimpleHttpClient +from synapse.server import HomeServer + +from tests.unittest import HomeserverTestCase + + +class SimpleHttpClientTests(HomeserverTestCase): + def prepare(self, reactor, clock, hs: "HomeServer"): + # Add a DNS entry for a test server + self.reactor.lookups["testserv"] = "1.2.3.4" + + self.cl = hs.get_simple_http_client() + + def test_dns_error(self): + """ + If the DNS lookup returns an error, it will bubble up. + """ + d = defer.ensureDeferred(self.cl.get_json("http://testserv2:8008/foo/bar")) + self.pump() + + f = self.failureResultOf(d) + self.assertIsInstance(f.value, DNSLookupError) + + def test_client_connection_refused(self): + d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar")) + + self.pump() + + # Nothing happened yet + self.assertNoResult(d) + + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8008) + e = Exception("go away") + factory.clientConnectionFailed(None, e) + self.pump(0.5) + + f = self.failureResultOf(d) + + self.assertIs(f.value, e) + + def test_client_never_connect(self): + """ + If the HTTP request is not connected and is timed out, it'll give a + ConnectingCancelledError or TimeoutError. + """ + d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar")) + + self.pump() + + # Nothing happened yet + self.assertNoResult(d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0][0], "1.2.3.4") + self.assertEqual(clients[0][1], 8008) + + # Deferred is still without a result + self.assertNoResult(d) + + # Push by enough to time it out + self.reactor.advance(120) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, RequestTimedOutError) + + def test_client_connect_no_response(self): + """ + If the HTTP request is connected, but gets no response before being + timed out, it'll give a ResponseNeverReceived. + """ + d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar")) + + self.pump() + + # Nothing happened yet + self.assertNoResult(d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0][0], "1.2.3.4") + self.assertEqual(clients[0][1], 8008) + + conn = Mock() + client = clients[0][2].buildProtocol(None) + client.makeConnection(conn) + + # Deferred is still without a result + self.assertNoResult(d) + + # Push by enough to time it out + self.reactor.advance(120) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, RequestTimedOutError) + + def test_client_ip_range_blacklist(self): + """Ensure that Synapse does not try to connect to blacklisted IPs""" + + # Add some DNS entries we'll blacklist + self.reactor.lookups["internal"] = "127.0.0.1" + self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337" + ip_blacklist = IPSet(["127.0.0.0/8", "fe80::/64"]) + + cl = SimpleHttpClient(self.hs, ip_blacklist=ip_blacklist) + + # Try making a GET request to a blacklisted IPv4 address + # ------------------------------------------------------ + # Make the request + d = defer.ensureDeferred(cl.get_json("http://internal:8008/foo/bar")) + self.pump(1) + + # Check that it was unable to resolve the address + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 0) + + self.failureResultOf(d, DNSLookupError) + + # Try making a POST request to a blacklisted IPv6 address + # ------------------------------------------------------- + # Make the request + d = defer.ensureDeferred( + cl.post_json_get_json("http://internalv6:8008/foo/bar", {}) + ) + + # Move the reactor forwards + self.pump(1) + + # Check that it was unable to resolve the address + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 0) + + # Check that it was due to a blacklisted DNS lookup + self.failureResultOf(d, DNSLookupError) + + # Try making a GET request to a non-blacklisted IPv4 address + # ---------------------------------------------------------- + # Make the request + d = defer.ensureDeferred(cl.get_json("http://testserv:8008/foo/bar")) + + # Nothing has happened yet + self.assertNoResult(d) + + # Move the reactor forwards + self.pump(1) + + # Check that it was able to resolve the address + clients = self.reactor.tcpClients + self.assertNotEqual(len(clients), 0) + + # Connection will still fail as this IP address does not resolve to anything + self.failureResultOf(d, RequestTimedOutError) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 04de0b9dbe..54600ad983 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -12,13 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from synapse.module_api import ModuleApi +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from tests.unittest import HomeserverTestCase class ModuleApiTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler()) @@ -52,3 +59,50 @@ class ModuleApiTestCase(HomeserverTestCase): # Check that the displayname was assigned displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + + def test_public_rooms(self): + """Tests that a room can be added and removed from the public rooms list, + as well as have its public rooms directory state queried. + """ + # Create a user and room to play with + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + room_id = self.helper.create_room_as(user_id, tok=tok) + + # The room should not currently be in the public rooms directory + is_in_public_rooms = self.get_success( + self.module_api.public_room_list_manager.room_is_in_public_room_list( + room_id + ) + ) + self.assertFalse(is_in_public_rooms) + + # Let's try adding it to the public rooms directory + self.get_success( + self.module_api.public_room_list_manager.add_room_to_public_room_list( + room_id + ) + ) + + # And checking whether it's in there... + is_in_public_rooms = self.get_success( + self.module_api.public_room_list_manager.room_is_in_public_room_list( + room_id + ) + ) + self.assertTrue(is_in_public_rooms) + + # Let's remove it again + self.get_success( + self.module_api.public_room_list_manager.remove_room_from_public_room_list( + room_id + ) + ) + + # Should be gone + is_in_public_rooms = self.get_success( + self.module_api.public_room_list_manager.room_is_in_public_room_list( + room_id + ) + ) + self.assertFalse(is_in_public_rooms) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index ae60874ec3..81ea985b9f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Any, Callable, List, Optional, Tuple import attr +import hiredis from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime +from twisted.internet.protocol import Protocol from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel @@ -27,7 +28,7 @@ from synapse.app.generic_worker import ( GenericWorkerServer, ) from synapse.http.server import JsonResource -from synapse.http.site import SynapseRequest +from synapse.http.site import SynapseRequest, SynapseSite from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol @@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = self.hs.get_replication_streamer() + # Fake in memory Redis server that servers can connect to. + self._redis_server = FakeRedisPubSubServer() + store = self.hs.get_datastore() self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" + self.reactor.lookups["localhost"] = "127.0.0.1" - self._worker_hs_to_resource = {} + # A map from a HS instance to the associated HTTP Site to use for + # handling inbound HTTP requests to that instance. + self._hs_to_site = {self.hs: self.site} + + if self.hs.config.redis.redis_enabled: + # Handle attempts to connect to fake redis server. + self.reactor.add_tcp_client_callback( + "localhost", 6379, self.connect_any_redis_attempts, + ) + + self.hs.get_tcp_replication().start_replication(self.hs) # When we see a connection attempt to the master replication listener we # automatically set up the connection. This is so that tests don't # manually have to go and explicitly set it up each time (plus sometimes # it is impossible to write the handling explicitly in the tests). + # + # Register the master replication listener: self.reactor.add_tcp_client_callback( - "1.2.3.4", 8765, self._handle_http_replication_attempt + "1.2.3.4", + 8765, + lambda: self._handle_http_replication_attempt(self.hs, 8765), ) def create_test_json_resource(self): @@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): **kwargs ) + # If the instance is in the `instance_map` config then workers may try + # and send HTTP requests to it, so we register it with + # `_handle_http_replication_attempt` like we do with the master HS. + instance_name = worker_hs.get_instance_name() + instance_loc = worker_hs.config.worker.instance_map.get(instance_name) + if instance_loc: + # Ensure the host is one that has a fake DNS entry. + if instance_loc.host not in self.reactor.lookups: + raise Exception( + "Host does not have an IP for instance_map[%r].host = %r" + % (instance_name, instance_loc.host,) + ) + + self.reactor.add_tcp_client_callback( + self.reactor.lookups[instance_loc.host], + instance_loc.port, + lambda: self._handle_http_replication_attempt( + worker_hs, instance_loc.port + ), + ) + store = worker_hs.get_datastore() store.db_pool._db_pool = self.database_pool._db_pool - repl_handler = ReplicationCommandHandler(worker_hs) - client = ClientReplicationStreamProtocol( - worker_hs, "client", "test", self.clock, repl_handler, - ) - server = self.server_factory.buildProtocol(None) + # Set up TCP replication between master and the new worker if we don't + # have Redis support enabled. + if not worker_hs.config.redis_enabled: + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) - client_transport = FakeTransport(server, self.reactor) - client.makeConnection(client_transport) + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) - server_transport = FakeTransport(client, self.reactor) - server.makeConnection(server_transport) + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) # Set up a resource for the worker - resource = ReplicationRestResource(self.hs) + resource = ReplicationRestResource(worker_hs) for servlet in self.servlets: servlet(worker_hs, resource) - self._worker_hs_to_resource[worker_hs] = resource + self._hs_to_site[worker_hs] = SynapseSite( + logger_name="synapse.access.http.fake", + site_tag="{}-{}".format( + worker_hs.config.server.server_name, worker_hs.get_instance_name() + ), + config=worker_hs.config.server.listeners[0], + resource=resource, + server_version_string="1", + ) + + if worker_hs.config.redis.redis_enabled: + worker_hs.get_tcp_replication().start_replication(worker_hs) return worker_hs @@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): return config def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): - render(request, self._worker_hs_to_resource[worker_hs], self.reactor) + render(request, self._hs_to_site[worker_hs].resource, self.reactor) def replicate(self): """Tell the master side of replication that something has happened, and then @@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump() - def _handle_http_replication_attempt(self): - """Handles a connection attempt to the master replication HTTP - listener. + def _handle_http_replication_attempt(self, hs, repl_port): + """Handles a connection attempt to the given HS replication HTTP + listener on the given port. """ # We should have at least one outbound connection attempt, where the @@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.assertGreaterEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() self.assertEqual(host, "1.2.3.4") - self.assertEqual(port, 8765) + self.assertEqual(port, repl_port) # Set up client side protocol client_protocol = client_factory.buildProtocol(None) @@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Set up the server side protocol channel = _PushHTTPChannel(self.reactor) channel.requestFactory = request_factory - channel.site = self.site + channel.site = self._hs_to_site[hs] # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # inside `connecTCP` before the connection has been passed back to the # code that requested the TCP connection. + def connect_any_redis_attempts(self): + """If redis is enabled we need to deal with workers connecting to a + redis server. We don't want to use a real Redis server so we use a + fake one. + """ + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "localhost") + self.assertEqual(port, 6379) + + client_protocol = client_factory.buildProtocol(None) + server_protocol = self._redis_server.buildProtocol(None) + + client_to_server_transport = FakeTransport( + server_protocol, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, server_protocol + ) + server_protocol.makeConnection(server_to_client_transport) + + return client_to_server_transport, server_to_client_transport + class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" @@ -467,3 +547,105 @@ class _PullToPushProducer: pass self.stopProducing() + + +class FakeRedisPubSubServer: + """A fake Redis server for pub/sub. + """ + + def __init__(self): + self._subscribers = set() + + def add_subscriber(self, conn): + """A connection has called SUBSCRIBE + """ + self._subscribers.add(conn) + + def remove_subscriber(self, conn): + """A connection has called UNSUBSCRIBE + """ + self._subscribers.discard(conn) + + def publish(self, conn, channel, msg) -> int: + """A connection want to publish a message to subscribers. + """ + for sub in self._subscribers: + sub.send(["message", channel, msg]) + + return len(self._subscribers) + + def buildProtocol(self, addr): + return FakeRedisPubSubProtocol(self) + + +class FakeRedisPubSubProtocol(Protocol): + """A connection from a client talking to the fake Redis server. + """ + + def __init__(self, server: FakeRedisPubSubServer): + self._server = server + self._reader = hiredis.Reader() + + def dataReceived(self, data): + self._reader.feed(data) + + # We might get multiple messages in one packet. + while True: + msg = self._reader.gets() + + if msg is False: + # No more messages. + return + + if not isinstance(msg, list): + # Inbound commands should always be a list + raise Exception("Expected redis list") + + self.handle_command(msg[0], *msg[1:]) + + def handle_command(self, command, *args): + """Received a Redis command from the client. + """ + + # We currently only support pub/sub. + if command == b"PUBLISH": + channel, message = args + num_subscribers = self._server.publish(self, channel, message) + self.send(num_subscribers) + elif command == b"SUBSCRIBE": + (channel,) = args + self._server.add_subscriber(self) + self.send(["subscribe", channel, 1]) + else: + raise Exception("Unknown command") + + def send(self, msg): + """Send a message back to the client. + """ + raw = self.encode(msg).encode("utf-8") + + self.transport.write(raw) + self.transport.flush() + + def encode(self, obj): + """Encode an object to its Redis format. + + Supports: strings/bytes, integers and list/tuples. + """ + + if isinstance(obj, bytes): + # We assume bytes are just unicode strings. + obj = obj.decode("utf-8") + + if isinstance(obj, str): + return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) + if isinstance(obj, int): + return ":{val}\r\n".format(val=obj) + if isinstance(obj, (list, tuple)): + items = "".join(self.encode(a) for a in obj) + return "*{len}\r\n{items}".format(len=len(obj), items=items) + + raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) + + def connectionLost(self, reason): + self._server.remove_subscriber(self) diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py new file mode 100644 index 0000000000..6068d14905 --- /dev/null +++ b/tests/replication/test_sharded_event_persister.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.utils import USE_POSTGRES_FOR_TESTS + +logger = logging.getLogger(__name__) + + +class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks event persisting sharding works + """ + + # Event persister sharding requires postgres (due to needing + # `MutliWriterIdGenerator`). + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # Register a user who sends a message that we'll get notified about + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + def default_config(self): + conf = super().default_config() + conf["redis"] = {"enabled": "true"} + conf["stream_writers"] = {"events": ["worker1", "worker2"]} + conf["instance_map"] = { + "worker1": {"host": "testserv", "port": 1001}, + "worker2": {"host": "testserv", "port": 1002}, + } + return conf + + def test_basic(self): + """Simple test to ensure that multiple rooms can be created and joined, + and that different rooms get handled by different instances. + """ + + self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker1"}, + ) + + self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker2"}, + ) + + persisted_on_1 = False + persisted_on_2 = False + + store = self.hs.get_datastore() + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Keep making new rooms until we see rooms being persisted on both + # workers. + for _ in range(10): + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join( + room=room, user=self.other_user_id, tok=self.other_access_token + ) + + # The other user sends some messages + rseponse = self.helper.send(room, body="Hi!", tok=self.other_access_token) + event_id = rseponse["event_id"] + + # The event position includes which instance persisted the event. + pos = self.get_success(store.get_position_for_event(event_id)) + + persisted_on_1 |= pos.instance_name == "worker1" + persisted_on_2 |= pos.instance_name == "worker2" + + if persisted_on_1 and persisted_on_2: + break + + self.assertTrue(persisted_on_1) + self.assertTrue(persisted_on_2) diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py index 8c24add530..715e87de08 100644 --- a/tests/rest/client/third_party_rules.py +++ b/tests/rest/client/third_party_rules.py @@ -12,18 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.types import Requester from tests import unittest class ThirdPartyRulesTestModule: - def __init__(self, config): + def __init__(self, config, *args, **kwargs): pass - def check_event_allowed(self, event, context): + async def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ): + return True + + async def check_event_allowed(self, event, context): if event.type == "foo.bar.forbidden": return False else: @@ -51,29 +56,31 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver(config=config) return self.hs + def prepare(self, reactor, clock, homeserver): + # Create a user and room to play with during the tests + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + def test_third_party_rules(self): """Tests that a forbidden event is forbidden from being sent, but an allowed one can be sent. """ - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - room_id = self.helper.create_room_as(user_id, tok=tok) - request, channel = self.make_request( "PUT", - "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id, + "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id, {}, - access_token=tok, + access_token=self.tok, ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) request, channel = self.make_request( "PUT", - "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id, + "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id, {}, - access_token=tok, + access_token=self.tok, ) self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 0a567b032f..0d809d25d5 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -905,6 +905,7 @@ class RoomMessageListTestCase(RoomBase): first_token = self.get_success( store.get_topological_token_for_event(first_event_id) ) + first_token_str = self.get_success(first_token.to_string(store)) # Send a second message in the room, which won't be removed, and which we'll # use as the marker to purge events before. @@ -912,6 +913,7 @@ class RoomMessageListTestCase(RoomBase): second_token = self.get_success( store.get_topological_token_for_event(second_event_id) ) + second_token_str = self.get_success(second_token.to_string(store)) # Send a third event in the room to ensure we don't fall under any edge case # due to our marker being the latest forward extremity in the room. @@ -921,7 +923,11 @@ class RoomMessageListTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), ) self.render(request) self.assertEqual(channel.code, 200, channel.json_body) @@ -936,7 +942,7 @@ class RoomMessageListTestCase(RoomBase): pagination_handler._purge_history( purge_id=purge_id, room_id=self.room_id, - token=second_token, + token=second_token_str, delete_local_events=True, ) ) @@ -946,7 +952,11 @@ class RoomMessageListTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), ) self.render(request) self.assertEqual(channel.code, 200, channel.json_body) @@ -960,7 +970,11 @@ class RoomMessageListTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})), + % ( + self.room_id, + first_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), ) self.render(request) self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 93f899d861..ae2cd67f35 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -732,6 +732,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) def test_next_link_domain_whitelist(self): """Tests next_link parameters must fit the whitelist if provided""" + + # Ensure not providing a next_link parameter still works + self._request_token( + "something@example.com", "some_secret", next_link=None, expect_code=200, + ) + self._request_token( "something@example.com", "some_secret", diff --git a/tests/server.py b/tests/server.py index b404ad4e2a..f7f5276b21 100644 --- a/tests/server.py +++ b/tests/server.py @@ -372,6 +372,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool.threadpool = ThreadPool(clock._reactor) pool.running = True + # We've just changed the Databases to run DB transactions on the same + # thread, so we need to disable the dedicated thread behaviour. + server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False + return server diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 46f94914ff..c905a38930 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): # must be done after inserts database = hs.get_datastores().databases[0] self.store = ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, make_conn(database._database_config, database.engine, "test"), hs ) def tearDown(self): @@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): db_config = hs.config.get_single_database() self.store = TestTransactionStore( - database, make_conn(db_config, self.engine), hs + database, make_conn(db_config, self.engine, "test"), hs ) def _add_service(self, url, as_token, id): @@ -448,7 +448,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): database = hs.get_datastores().databases[0] ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, make_conn(database._database_config, database.engine, "test"), hs ) @defer.inlineCallbacks @@ -467,7 +467,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): with self.assertRaises(ConfigError) as cm: database = hs.get_datastores().databases[0] ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, + make_conn(database._database_config, database.engine, "test"), + hs, ) e = cm.exception @@ -491,7 +493,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): with self.assertRaises(ConfigError) as cm: database = hs.get_datastores().databases[0] ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, + make_conn(database._database_config, database.engine, "test"), + hs, ) e = cm.exception diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 949846fe33..3957471f3f 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -52,14 +52,14 @@ class ExtremStatisticsTestCase(HomeserverTestCase): self.reactor.advance(60 * 60 * 1000) self.pump(1) - items = set( + items = list( filter( lambda x: b"synapse_forward_extremities_" in x, - generate_latest(REGISTRY).split(b"\n"), + generate_latest(REGISTRY, emit_help=False).split(b"\n"), ) ) - expected = { + expected = [ b'synapse_forward_extremities_bucket{le="1.0"} 0.0', b'synapse_forward_extremities_bucket{le="2.0"} 2.0', b'synapse_forward_extremities_bucket{le="3.0"} 2.0', @@ -72,9 +72,12 @@ class ExtremStatisticsTestCase(HomeserverTestCase): b'synapse_forward_extremities_bucket{le="100.0"} 3.0', b'synapse_forward_extremities_bucket{le="200.0"} 3.0', b'synapse_forward_extremities_bucket{le="500.0"} 3.0', - b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', - b"synapse_forward_extremities_count 3.0", - b"synapse_forward_extremities_sum 10.0", - } - + # per https://docs.google.com/document/d/1KwV0mAXwwbvvifBvDKH_LU1YjyXE_wxCkHNoCGq1GX0/edit#heading=h.wghdjzzh72j9, + # "inf" is valid: "this includes variants such as inf" + b'synapse_forward_extremities_bucket{le="inf"} 3.0', + b"# TYPE synapse_forward_extremities_gcount gauge", + b"synapse_forward_extremities_gcount 3.0", + b"# TYPE synapse_forward_extremities_gsum gauge", + b"synapse_forward_extremities_gsum 10.0", + ] self.assertEqual(items, expected) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index d4ff55fbff..392b08832b 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -12,9 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from synapse.storage.database import DatabasePool +from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.unittest import HomeserverTestCase @@ -59,7 +58,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): writers=writers, ) - return self.get_success(self.db_pool.runWithConnection(_create)) + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) def _insert_rows(self, instance_name: str, number: int): """Insert N rows as the given instance, inserting with stream IDs pulled @@ -391,17 +390,28 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Initial config has two writers id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_persisted_upto_position(), 3) + self.assertEqual(id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(id_gen.get_current_token_for_writer("second"), 5) # New config removes one of the configs. Note that if the writer is # removed from config we assume that it has been shut down and has # finished persisting, hence why the persisted upto position is 5. id_gen_2 = self._create_id_generator("second", writers=["second"]) self.assertEqual(id_gen_2.get_persisted_upto_position(), 5) + self.assertEqual(id_gen_2.get_current_token_for_writer("second"), 5) # This config points to a single, previously unused writer. id_gen_3 = self._create_id_generator("third", writers=["third"]) self.assertEqual(id_gen_3.get_persisted_upto_position(), 5) + # For new writers we assume their initial position to be the current + # persisted up to position. This stops Synapse from doing a full table + # scan when a new writer comes along. + self.assertEqual(id_gen_3.get_current_token_for_writer("third"), 5) + + id_gen_4 = self._create_id_generator("fourth", writers=["third"]) + self.assertEqual(id_gen_4.get_current_token_for_writer("third"), 5) + # Check that we get a sane next stream ID with this new config. async def _get_next_async(): @@ -411,6 +421,30 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(_get_next_async()) self.assertEqual(id_gen_3.get_persisted_upto_position(), 6) + # If we add back the old "first" then we shouldn't see the persisted up + # to position revert back to 3. + id_gen_5 = self._create_id_generator("five", writers=["first", "third"]) + self.assertEqual(id_gen_5.get_persisted_upto_position(), 6) + self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6) + self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) + + def test_sequence_consistency(self): + """Test that we error out if the table and sequence diverges. + """ + + # Prefill with some rows + self._insert_row_with_id("master", 3) + + # Now we add a row *without* updating the stream ID + def _insert(txn): + txn.execute("INSERT INTO foobar VALUES (26, 'master')") + + self.get_success(self.db_pool.runInteraction("_insert", _insert)) + + # Creating the ID gen should error + with self.assertRaises(IncorrectDatabaseSetup): + self._create_id_generator("first") + class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests MultiWriterIdGenerator that produce *negative* stream IDs. diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 643072bbaf..8d97b6d4cd 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -137,6 +137,21 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, 1) + def test_appservice_user_not_counted_in_mau(self): + self.get_success( + self.store.register_user( + user_id="@appservice_user:server", appservice_id="wibble" + ) + ) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) + + d = self.store.upsert_monthly_active_user("@appservice_user:server") + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) + def test_user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" @@ -383,7 +398,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.upsert_monthly_active_user(appservice2_user1)) count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, 4) + self.assertEqual(count, 1) d = self.store.get_monthly_active_count_by_service() result = self.get_success(d) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 918387733b..cc1f3c53c5 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -47,12 +47,15 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_storage() # Get the topological token - event = self.get_success( + token = self.get_success( store.get_topological_token_for_event(last["event_id"]) ) + token_str = self.get_success(token.to_string(self.hs.get_datastore())) # Purge everything before this topological token - self.get_success(storage.purge_events.purge_history(self.room_id, event, True)) + self.get_success( + storage.purge_events.purge_history(self.room_id, token_str, True) + ) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted # and last is not. @@ -74,12 +77,10 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_datastore() # Set the topological token higher than it should be - event = self.get_success( + token = self.get_success( storage.get_topological_token_for_event(last["event_id"]) ) - event = "t{}-{}".format( - *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) - ) + event = "t{}-{}".format(token.topological + 1, token.stream + 1) # Purge everything before this topological token purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py index 7657bddea5..e7aed092c2 100644 --- a/tests/test_phone_home.py +++ b/tests/test_phone_home.py @@ -17,7 +17,7 @@ import resource import mock -from synapse.app.homeserver import phone_stats_home +from synapse.app.phone_stats_home import phone_stats_home from tests.unittest import HomeserverTestCase diff --git a/tests/unittest.py b/tests/unittest.py index dabf69cff4..82ede9de34 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import gc import hashlib import hmac @@ -23,11 +22,12 @@ import logging import time from typing import Optional, Tuple, Type, TypeVar, Union -from mock import Mock +from mock import Mock, patch from canonicaljson import json from twisted.internet.defer import Deferred, ensureDeferred, succeed +from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool from twisted.trial import unittest @@ -169,6 +169,19 @@ def INFO(target): return target +def logcontext_clean(target): + """A decorator which marks the TestCase or method as 'logcontext_clean' + + ... ie, any logcontext errors should cause a test failure + """ + + def logcontext_error(msg): + raise AssertionError("logcontext error: %s" % (msg)) + + patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error) + return patcher(target) + + class HomeserverTestCase(TestCase): """ A base TestCase that reduces boilerplate for HomeServer-using test cases. @@ -228,7 +241,7 @@ class HomeserverTestCase(TestCase): # create a site to wrap the resource. self.site = SynapseSite( logger_name="synapse.access.http.fake", - site_tag="test", + site_tag=self.hs.config.server.server_name, config=self.hs.config.server.listeners[0], resource=self.resource, server_version_string="1", @@ -463,6 +476,35 @@ class HomeserverTestCase(TestCase): self.pump() return self.failureResultOf(d, exc) + def get_success_or_raise(self, d, by=0.0): + """Drive deferred to completion and return result or raise exception + on failure. + """ + + if inspect.isawaitable(d): + deferred = ensureDeferred(d) + if not isinstance(deferred, Deferred): + return d + + results = [] # type: list + deferred.addBoth(results.append) + + self.pump(by=by) + + if not results: + self.fail( + "Success result expected on {!r}, found no result instead".format( + deferred + ) + ) + + result = results[0] + + if isinstance(result, Failure): + result.raiseException() + + return result + def register_user(self, username, password, admin=False): """ Register a user. Requires the Admin API be registered. diff --git a/tests/utils.py b/tests/utils.py index 4673872f88..af563ffe0f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -38,6 +38,7 @@ from synapse.http.server import HttpServer from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import prepare_database from synapse.util.ratelimitutils import FederationRateLimiter @@ -88,6 +89,7 @@ def setupdb(): host=POSTGRES_HOST, password=POSTGRES_PASSWORD, ) + db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") prepare_database(db_conn, db_engine, None) db_conn.close() @@ -276,7 +278,7 @@ def setup_test_homeserver( hs.setup() if homeserverToUse.__name__ == "TestHomeServer": - hs.setup_master() + hs.setup_background_tasks() if isinstance(db_engine, PostgresEngine): database = hs.get_datastores().databases[0]