Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

pull/12117/head
Erik Johnston 2022-02-22 14:36:44 +00:00
commit 3d92936c14
123 changed files with 2888 additions and 859 deletions

View File

@ -8,7 +8,9 @@ export DEBIAN_FRONTEND=noninteractive
set -ex
apt-get update
-apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev
+apt-get install -y \
+    python3 python3-dev python3-pip python3-venv \
+    libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev
export LANG="C.UTF-8"

View File

@ -7,7 +7,7 @@ on:
  # of things breaking (but only build one set of debs)
  pull_request:
  push:
-    branches: ["develop"]
+    branches: ["develop", "release-*"]
    # we do the full build on tags.
    tags: ["v*"]
@ -91,17 +91,7 @@ jobs:
  build-sdist:
    name: "Build pypi distribution files"
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v2
-      - uses: actions/setup-python@v2
-      - run: pip install wheel
-      - run: |
-          python setup.py sdist bdist_wheel
-      - uses: actions/upload-artifact@v2
-        with:
-          name: python-dist
-          path: dist/*
+    uses: "matrix-org/backend-meta/.github/workflows/packaging.yml@v1"

  # if it's a tag, create a release and attach the artifacts to it
  attach-assets:

View File

@ -48,24 +48,10 @@ jobs:
    env:
      PULL_REQUEST_NUMBER: ${{ github.event.number }}

-  lint-sdist:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v2
-      - uses: actions/setup-python@v2
-        with:
-          python-version: "3.x"
-      - run: pip install wheel
-      - run: python setup.py sdist bdist_wheel
-      - uses: actions/upload-artifact@v2
-        with:
-          name: Python Distributions
-          path: dist/*
-
  # Dummy step to gate other tests on without repeating the whole list
  linting-done:
    if: ${{ !cancelled() }}  # Run this even if prior jobs were skipped
-    needs: [lint, lint-crlf, lint-newsfile, lint-sdist]
+    needs: [lint, lint-crlf, lint-newsfile]
    runs-on: ubuntu-latest
    steps:
      - run: "true"
@ -397,7 +383,6 @@ jobs:
      - lint
      - lint-crlf
      - lint-newsfile
-      - lint-sdist
      - trial
      - trial-olddeps
      - sytest

View File

@ -1,3 +1,9 @@
Synapse 1.53.0 (2022-02-22)
===========================

No significant changes since 1.53.0rc1.


Synapse 1.53.0rc1 (2022-02-15)
==============================
@ -5,7 +11,7 @@ Features
--------

- Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). ([\#11215](https://github.com/matrix-org/synapse/issues/11215), [\#11966](https://github.com/matrix-org/synapse/issues/11966))
-- Remove account data (including client config, push rules and ignored users) upon user deactivation. ([\#11655](https://github.com/matrix-org/synapse/issues/11655))
+- Add a background database update to purge account data for deactivated users. ([\#11655](https://github.com/matrix-org/synapse/issues/11655))
- Experimental support for [MSC3666](https://github.com/matrix-org/matrix-doc/pull/3666): including bundled aggregations in server side search results. ([\#11837](https://github.com/matrix-org/synapse/issues/11837))
- Enable cache time-based expiry by default. The `expiry_time` config flag has been superseded by `expire_caches` and `cache_entry_ttl`. ([\#11849](https://github.com/matrix-org/synapse/issues/11849))
- Add a callback to allow modules to allow or forbid a 3PID (email address, phone number) from being associated to a local account. ([\#11854](https://github.com/matrix-org/synapse/issues/11854))
@ -86,7 +92,7 @@ Note that [Twisted 22.1.0](https://github.com/twisted/twisted/releases/tag/twist
has recently been released, which fixes a [security issue](https://github.com/twisted/twisted/security/advisories/GHSA-92x2-jw7w-xvvx)
within the Twisted library. We do not believe Synapse is affected by this vulnerability,
though we advise server administrators who installed Synapse via pip to upgrade Twisted
-with `pip install --upgrade Twisted` as a matter of good practice. The Docker image
+with `pip install --upgrade Twisted treq` as a matter of good practice. The Docker image
`matrixdotorg/synapse` and the Debian packages from `packages.matrix.org` are using the
updated library.
@ -267,7 +273,7 @@ Bugfixes
Synapse 1.50.0 (2022-01-18)
===========================

**This release contains a critical bug that may prevent clients from being able to connect.
As such, it is not recommended to upgrade to 1.50.0. Instead, please upgrade straight to
to 1.50.1. Further details are available in [this issue](https://github.com/matrix-org/synapse/issues/11763).**

1
changelog.d/10870.misc Normal file
View File

@ -0,0 +1 @@
Deduplicate in-flight requests in `_get_state_for_groups`.

1
changelog.d/11608.misc Normal file
View File

@ -0,0 +1 @@
Deduplicate in-flight requests in `_get_state_for_groups`.

View File

@ -0,0 +1 @@
Make a `POST` to `/rooms/<room_id>/receipt/m.read/<event_id>` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push.

1
changelog.d/11972.misc Normal file
View File

@ -0,0 +1 @@
Add tests for device list changes between local users.

1
changelog.d/11974.misc Normal file
View File

@ -0,0 +1 @@
Optimise calculating device_list changes in `/sync`.

1
changelog.d/11984.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to storage classes.

View File

@ -0,0 +1 @@
Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama.

1
changelog.d/11991.misc Normal file
View File

@ -0,0 +1 @@
Refactor the search code for improved readability.

1
changelog.d/11992.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary.

1
changelog.d/11994.misc Normal file
View File

@ -0,0 +1 @@
Move common deduplication code down into `_auth_and_persist_outliers`.

1
changelog.d/11996.misc Normal file
View File

@ -0,0 +1 @@
Limit concurrent joins from application services.

1
changelog.d/11997.docker Normal file
View File

@ -0,0 +1 @@
The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage.

1
changelog.d/11999.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room.

View File

@ -0,0 +1 @@
Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time.

1
changelog.d/12003.doc Normal file
View File

@ -0,0 +1 @@
Explain the meaning of spam checker callbacks' return values.

1
changelog.d/12004.doc Normal file
View File

@ -0,0 +1 @@
Clarify information about external Identity Provider IDs.

1
changelog.d/12005.misc Normal file
View File

@ -0,0 +1 @@
Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`.

View File

@ -0,0 +1 @@
Remove support for the legacy structured logging configuration (please see the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration).

View File

@ -0,0 +1 @@
Enable modules to set a custom display name when registering a user.

1
changelog.d/12011.misc Normal file
View File

@ -0,0 +1 @@
Preparation for faster-room-join work: parse msc3706 fields in send_join response.

1
changelog.d/12013.misc Normal file
View File

@ -0,0 +1 @@
Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server.

1
changelog.d/12015.misc Normal file
View File

@ -0,0 +1 @@
Configure `tox` to use `venv` rather than `virtualenv`.

1
changelog.d/12016.misc Normal file
View File

@ -0,0 +1 @@
Fix bug in `StateFilter.return_expanded()` and add some tests.

View File

@ -0,0 +1 @@
Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported.

1
changelog.d/12019.misc Normal file
View File

@ -0,0 +1 @@
Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms.

View File

@ -0,0 +1 @@
Advertise Matrix 1.1 support on `/_matrix/client/versions`.

View File

@ -0,0 +1 @@
Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`.

View File

@ -0,0 +1 @@
Advertise Matrix 1.2 support on `/_matrix/client/versions`.

1
changelog.d/12024.bugfix Normal file
View File

@ -0,0 +1 @@
Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint.

1
changelog.d/12025.misc Normal file
View File

@ -0,0 +1 @@
Update the `olddeps` CI job to use an old version of `markupsafe`.

1
changelog.d/12030.misc Normal file
View File

@ -0,0 +1 @@
Upgrade mypy to version 0.931.

1
changelog.d/12033.misc Normal file
View File

@ -0,0 +1 @@
Deduplicate in-flight requests in `_get_state_for_groups`.

1
changelog.d/12034.misc Normal file
View File

@ -0,0 +1 @@
Minor typing fixes.

1
changelog.d/12039.misc Normal file
View File

@ -0,0 +1 @@
Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`.

1
changelog.d/12041.misc Normal file
View File

@ -0,0 +1 @@
After joining a room, create a dedicated logcontext to process the queued events.

1
changelog.d/12051.misc Normal file
View File

@ -0,0 +1 @@
Tidy up GitHub Actions config which builds distributions for PyPI.

1
changelog.d/12052.misc Normal file
View File

@ -0,0 +1 @@
Move `isort` configuration to `pyproject.toml`.

1
changelog.d/12056.bugfix Normal file
View File

@ -0,0 +1 @@
Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens.

View File

@ -0,0 +1 @@
Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)).

6
debian/changelog vendored
View File

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.53.0) stable; urgency=medium
* New synapse release 1.53.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 22 Feb 2022 11:32:06 +0000
matrix-synapse-py3 (1.53.0~rc1) stable; urgency=medium
  * New synapse release 1.53.0~rc1.

View File

@ -98,8 +98,6 @@ COPY --from=builder /install /usr/local
COPY ./docker/start.py /start.py
COPY ./docker/conf /conf
-VOLUME ["/data"]
EXPOSE 8008/tcp 8009/tcp 8448/tcp
ENTRYPOINT ["/start.py"]

View File

@ -126,7 +126,8 @@ Body parameters:
[Sample Configuration File](../usage/configuration/homeserver_sample_config.html)
section `sso` and `oidc_providers`.
- `auth_provider` - string. ID of the external identity provider. Value of `idp_id`
-  in homeserver configuration.
+  in the homeserver configuration. Note that no error is raised if the provided
+  value is not in the homeserver configuration.
- `external_id` - string, user ID in the external identity provider.
- `avatar_url` - string, optional, must be a
  [MXC URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris).
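For illustration only (not part of this diff), a request using the body parameters above might look like the sketch below, assuming this hunk comes from the Create or Modify Account admin API; the homeserver URL, access token, user IDs and identity-provider ID are placeholder assumptions:

```python
import requests

HOMESERVER = "https://synapse.example.com"  # assumption
ADMIN_TOKEN = "ADMIN_ACCESS_TOKEN"          # assumption

body = {
    "external_ids": [
        {
            # Should match an `idp_id` from the homeserver's `sso`/`oidc_providers`
            # config; per the note above, no error is raised if it does not.
            "auth_provider": "oidc-example",
            "external_id": "a9b8c7d6-e5f4-4321-abcd-0123456789ab",
        }
    ],
    "avatar_url": "mxc://example.com/abcdef123456",
}

resp = requests.put(
    f"{HOMESERVER}/_synapse/admin/v2/users/@alice:example.com",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    json=body,
)
resp.raise_for_status()
```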

View File

@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
-any of the subsequent implementations of this callback. If every callback return `None`,
+any of the subsequent implementations of this callback. If every callback returns `None`,
the authentication is denied.

### `on_logged_out`
@ -162,10 +162,38 @@ return `None`.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
-any of the subsequent implementations of this callback. If every callback return `None`,
+any of the subsequent implementations of this callback. If every callback returns `None`,
the username provided by the user is used, if any (otherwise one is automatically
generated).
### `get_displayname_for_registration`

_First introduced in Synapse v1.54.0_

```python
async def get_displayname_for_registration(
    uia_results: Dict[str, Any],
    params: Dict[str, Any],
) -> Optional[str]
```

Called when registering a new user. The module can return a display name to set for the
user being registered by returning it as a string, or `None` if it doesn't wish to force a
display name for this user.

This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
has been completed by the user. It is not called when registering a user via SSO. It is
passed two dictionaries, which include the information that the user has provided during
the registration process. These dictionaries are identical to the ones passed to
[`get_username_for_registration`](#get_username_for_registration), so refer to the
documentation of this callback for more information about them.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback. If every callback returns `None`,
the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`).
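As a rough sketch (not part of this commit), a module might implement this callback as below; the `register_password_auth_provider_callbacks` keyword argument and the use of `params["username"]` to derive a display name are assumptions made for the example:

```python
from typing import Any, Dict, Optional


class DisplaynameModule:
    def __init__(self, config: dict, api):
        self._api = api
        # Assumed to accept a keyword argument named after the callback above.
        api.register_password_auth_provider_callbacks(
            get_displayname_for_registration=self.get_displayname_for_registration,
        )

    async def get_displayname_for_registration(
        self,
        uia_results: Dict[str, Any],
        params: Dict[str, Any],
    ) -> Optional[str]:
        # Derive a friendlier display name from the requested username, or return
        # None to leave the default (username-based) behaviour untouched.
        username = params.get("username")
        return username.capitalize() if username else None
```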
### `is_3pid_allowed`

_First introduced in Synapse v1.53.0_
@ -194,8 +222,7 @@ The example module below implements authentication checkers for two different lo
  - Is checked by the method: `self.check_my_login`
- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based))
  - Expects a `password` field to be sent to `/login`
  - Is checked by the method: `self.check_pass`
```python
from typing import Awaitable, Callable, Optional, Tuple

View File

@ -16,10 +16,12 @@ _First introduced in Synapse v1.37.0_
async def check_event_for_spam(event: "synapse.events.EventBase") -> Union[bool, str]
```

-Called when receiving an event from a client or via federation. The module can return
-either a `bool` to indicate whether the event must be rejected because of spam, or a `str`
-to indicate the event must be rejected because of spam and to give a rejection reason to
-forward to clients.
+Called when receiving an event from a client or via federation. The callback must return
+either:
+
+- an error message string, to indicate the event must be rejected because of spam and
+  give a rejection reason to forward to clients;
+- the boolean `True`, to indicate that the event is spammy, but not provide further details; or
+- the boolean `False`, to indicate that the event is not considered spammy.
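For illustration only (this block is not part of the commit), a minimal module using these return values might look like the following sketch; `register_spam_checker_callbacks` is Synapse's module registration hook, while the keyword-matching heuristic is a placeholder assumption:

```python
from typing import Union


class ExampleSpamChecker:
    def __init__(self, config: dict, api):
        self._api = api
        # Register only the callback this example implements.
        api.register_spam_checker_callbacks(
            check_event_for_spam=self.check_event_for_spam,
        )

    async def check_event_for_spam(self, event) -> Union[bool, str]:
        body = event.content.get("body", "")
        if "buy cheap pills" in body.lower():
            # A string is forwarded to clients as the rejection reason.
            return "Message rejected as spam"
        # False means "not spammy"; Synapse then falls through to the next module.
        return False
```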
If multiple modules implement this callback, they will be considered in order. If a
callback returns `False`, Synapse falls through to the next one. The value of the first
@ -35,7 +37,10 @@ async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool
```

Called when a user is trying to join a room. The module must return a `bool` to indicate
-whether the user can join the room. The user is represented by their Matrix user ID (e.g.
+whether the user can join the room. Return `False` to prevent the user from joining the
+room; otherwise return `True` to permit the joining.
+
+The user is represented by their Matrix user ID (e.g.
`@alice:example.com`) and the room is represented by its Matrix ID (e.g.
`!room:example.com`). The module is also given a boolean to indicate whether the user
currently has a pending invite in the room.
@ -58,7 +63,8 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool
Called when processing an invitation. The module must return a `bool` indicating whether
the inviter can invite the invitee to the given room. Both inviter and invitee are
-represented by their Matrix user ID (e.g. `@alice:example.com`).
+represented by their Matrix user ID (e.g. `@alice:example.com`). Return `False` to prevent
+the invitation; otherwise return `True` to permit it.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
@ -80,7 +86,8 @@ async def user_may_send_3pid_invite(
Called when processing an invitation using a third-party identifier (also called a 3PID,
e.g. an email address or a phone number). The module must return a `bool` indicating
-whether the inviter can invite the invitee to the given room.
+whether the inviter can invite the invitee to the given room. Return `False` to prevent
+the invitation; otherwise return `True` to permit it.

The inviter is represented by their Matrix user ID (e.g. `@alice:example.com`), and the
invitee is represented by its medium (e.g. "email") and its address
@ -117,6 +124,7 @@ async def user_may_create_room(user: str) -> bool
Called when processing a room creation request. The module must return a `bool` indicating
whether the given user (represented by their Matrix user ID) is allowed to create a room.
+Return `False` to prevent room creation; otherwise return `True` to permit it.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
@ -133,7 +141,8 @@ async def user_may_create_room_alias(user: str, room_alias: "synapse.types.RoomA
Called when trying to associate an alias with an existing room. The module must return a
`bool` indicating whether the given user (represented by their Matrix user ID) is allowed
-to set the given alias.
+to set the given alias. Return `False` to prevent the alias creation; otherwise return
+`True` to permit it.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
@ -150,7 +159,8 @@ async def user_may_publish_room(user: str, room_id: str) -> bool
Called when trying to publish a room to the homeserver's public rooms directory. The
module must return a `bool` indicating whether the given user (represented by their
-Matrix user ID) is allowed to publish the given room.
+Matrix user ID) is allowed to publish the given room. Return `False` to prevent the
+room from being published; otherwise return `True` to permit its publication.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
@ -166,8 +176,11 @@ async def check_username_for_spam(user_profile: Dict[str, str]) -> bool
```

Called when computing search results in the user directory. The module must return a
-`bool` indicating whether the given user profile can appear in search results. The profile
-is represented as a dictionary with the following keys:
+`bool` indicating whether the given user should be excluded from user directory
+searches. Return `True` to indicate that the user is spammy and exclude them from
+search results; otherwise return `False`.
+
+The profile is represented as a dictionary with the following keys:

* `user_id`: The Matrix ID for this user.
* `display_name`: The user's display name.
@ -225,8 +238,9 @@ async def check_media_file_for_spam(
) -> bool
```

-Called when storing a local or remote file. The module must return a boolean indicating
-whether the given file can be stored in the homeserver's media store.
+Called when storing a local or remote file. The module must return a `bool` indicating
+whether the given file should be excluded from the homeserver's media store. Return
+`True` to prevent this file from being stored; otherwise return `False`.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `False`, Synapse falls through to the next one. The value of the first

View File

@ -163,7 +163,7 @@ presence:
# For example, for room version 1, default_room_version should be set
# to "1".
#
-#default_room_version: "6"
+#default_room_version: "9"

# The GC threshold parameters to pass to `gc.set_threshold`, if defined
#

View File

@ -81,14 +81,12 @@ remote endpoint at 10.1.2.3:9999.
## Upgrading from legacy structured logging configuration

-Versions of Synapse prior to v1.23.0 included a custom structured logging
-configuration which is deprecated. It used a `structured: true` flag and
-configured `drains` instead of ``handlers`` and `formatters`.
-
-Synapse currently automatically converts the old configuration to the new
-configuration, but this will be removed in a future version of Synapse. The
-following reference can be used to update your configuration. Based on the drain
-`type`, we can pick a new handler:
+Versions of Synapse prior to v1.54.0 automatically converted the legacy
+structured logging configuration, which was deprecated in v1.23.0, to the standard
+library logging configuration.
+
+The following reference can be used to update your configuration. Based on the
+drain `type`, we can pick a new handler:

1. For a type of `console`, `console_json`, or `console_json_terse`: a handler
   with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout`
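As a sketch only (not part of this diff): since Synapse ultimately feeds the parsed log config to `logging.config.dictConfig` (see the `_load_logging_config` change further down), the first rule above corresponds roughly to the following standard-library configuration; the logger level and formatter choice are illustrative assumptions:

```python
import logging.config

log_config = {
    "version": 1,
    "handlers": {
        # Replaces a legacy drain of type `console`; a `console_json` drain would
        # additionally set a formatter such as `synapse.logging.JsonFormatter`.
        "console": {
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
        },
    },
    "root": {"level": "INFO", "handlers": ["console"]},
}

logging.config.dictConfig(log_config)
```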

View File

@ -85,6 +85,15 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```

# Upgrading to v1.54.0

## Legacy structured logging configuration removal

This release removes support for the `structured: true` logging configuration
which was deprecated in Synapse v1.23.0. If your logging configuration contains
`structured: true` then it should be modified based on the
[structured logging documentation](structured_logging.md).

# Upgrading to v1.53.0

## Dropping support for `webclient` listeners and non-HTTP(S) `web_client_location`
@ -157,7 +166,7 @@ Note that [Twisted 22.1.0](https://github.com/twisted/twisted/releases/tag/twist
has recently been released, which fixes a [security issue](https://github.com/twisted/twisted/security/advisories/GHSA-92x2-jw7w-xvvx)
within the Twisted library. We do not believe Synapse is affected by this vulnerability,
though we advise server administrators who installed Synapse via pip to upgrade Twisted
-with `pip install --upgrade Twisted` as a matter of good practice. The Docker image
+with `pip install --upgrade Twisted treq` as a matter of good practice. The Docker image
`matrixdotorg/synapse` and the Debian packages from `packages.matrix.org` are using the
updated library.

View File

@ -31,14 +31,11 @@ exclude = (?x)
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
-|synapse/storage/databases/main/presence.py
-|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
-|synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/
|tests/api/test_auth.py

View File

@ -54,3 +54,15 @@ exclude = '''
)/
)
'''
[tool.isort]
line_length = 88
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TWISTED", "FIRSTPARTY", "TESTS", "LOCALFOLDER"]
default_section = "THIRDPARTY"
known_first_party = ["synapse"]
known_tests = ["tests"]
known_twisted = ["twisted", "OpenSSL"]
multi_line_output = 3
include_trailing_comma = true
combine_as_imports = true
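For illustration (not part of the diff), imports sorted under the section order above would group roughly as follows; the specific modules are arbitrary examples:

```python
from __future__ import annotations  # FUTURE

import os  # STDLIB

import attr  # THIRDPARTY

from twisted.internet import defer  # TWISTED

from synapse.api.constants import EventTypes  # FIRSTPARTY

from tests import unittest  # TESTS
```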

View File

@ -19,14 +19,3 @@ ignore =
# E731: do not assign a lambda expression, use a def
# E501: Line too long (black enforces this for us)
ignore=W503,W504,E203,E731,E501
[isort]
line_length = 88
sections=FUTURE,STDLIB,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER
default_section=THIRDPARTY
known_first_party = synapse
known_tests=tests
known_twisted=twisted,OpenSSL
multi_line_output=3
include_trailing_comma=true
combine_as_imports=true

View File

@ -103,8 +103,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
]

CONDITIONAL_REQUIREMENTS["mypy"] = [
-    "mypy==0.910",
+    "mypy==0.931",
-    "mypy-zope==0.3.2",
+    "mypy-zope==0.3.5",
    "types-bleach>=4.1.0",
    "types-jsonschema>=3.2.0",
    "types-opentracing>=2.4.2",

View File

@ -66,13 +66,18 @@ class SortedDict(Dict[_KT, _VT]):
    def __copy__(self: _SD) -> _SD: ...
    @classmethod
    @overload
-    def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ...
+    def fromkeys(
+        cls, seq: Iterable[_T_h], value: None = ...
+    ) -> SortedDict[_T_h, None]: ...
    @classmethod
    @overload
    def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ...
-    def keys(self) -> SortedKeysView[_KT]: ...
-    def items(self) -> SortedItemsView[_KT, _VT]: ...
-    def values(self) -> SortedValuesView[_VT]: ...
+    # As of Python 3.10, `dict_{keys,items,values}` have an extra `mapping` attribute and so
+    # `Sorted{Keys,Items,Values}View` are no longer compatible with them.
+    # See https://github.com/python/typeshed/issues/6837
+    def keys(self) -> SortedKeysView[_KT]: ...  # type: ignore[override]
+    def items(self) -> SortedItemsView[_KT, _VT]: ...  # type: ignore[override]
+    def values(self) -> SortedValuesView[_VT]: ...  # type: ignore[override]
    @overload
    def pop(self, key: _KT) -> _VT: ...
    @overload

View File

@ -47,7 +47,7 @@ try:
except ImportError:
    pass

-__version__ = "1.53.0rc1"
+__version__ = "1.53.0"

if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
    # We import here so that we don't have to install a bunch of deps when

View File

@ -41,9 +41,6 @@ class ExperimentalConfig(Config):
        # MSC3244 (room version capabilities)
        self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)

-        # MSC3283 (set displayname, avatar_url and change 3pid capabilities)
-        self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False)
-
        # MSC3266 (room summary api)
        self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
@ -64,3 +61,7 @@ class ExperimentalConfig(Config):
        # MSC3706 (server-side support for partial state in /send_join responses)
        self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
# experimental support for faster joins over federation (msc2775, msc3706)
# requires a target server with msc3706_enabled enabled.
self.faster_joins_enabled: bool = experimental.get("faster_joins", False)

View File

@ -33,7 +33,6 @@ from twisted.logger import (
    globalLogBeginner,
)

-from synapse.logging._structured import setup_structured_logging
from synapse.logging.context import LoggingContextFilter
from synapse.logging.filter import MetadataFilter
@ -138,6 +137,12 @@ Support for the log_file configuration option and --log-file command-line option
removed in Synapse 1.3.0. You should instead set up a separate log configuration file.
"""

STRUCTURED_ERROR = """\
Support for the structured configuration option was removed in Synapse 1.54.0.
You should instead use the standard logging configuration. See
https://matrix-org.github.io/synapse/v1.54/structured_logging.html
"""


class LoggingConfig(Config):
    section = "logging"
@ -292,10 +297,9 @@ def _load_logging_config(log_config_path: str) -> None:
    if not log_config:
        logging.warning("Loaded a blank logging config?")

-    # If the old structured logging configuration is being used, convert it to
-    # the new style configuration.
+    # If the old structured logging configuration is being used, raise an error.
    if "structured" in log_config and log_config.get("structured"):
-        log_config = setup_structured_logging(log_config)
+        raise ConfigError(STRUCTURED_ERROR)

    logging.config.dictConfig(log_config)

View File

@ -146,7 +146,7 @@ DEFAULT_IP_RANGE_BLACKLIST = [
"fec0::/10", "fec0::/10",
] ]
DEFAULT_ROOM_VERSION = "6" DEFAULT_ROOM_VERSION = "9"
ROOM_COMPLEXITY_TOO_GREAT = ( ROOM_COMPLEXITY_TOO_GREAT = (
"Your homeserver is unable to join rooms this large or complex. " "Your homeserver is unable to join rooms this large or complex. "

View File

@ -425,6 +425,33 @@ class EventClientSerializer:
        return serialized_event
def _apply_edit(
self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase
) -> None:
"""Replace the content, preserving existing relations of the serialized event.
Args:
orig_event: The original event.
serialized_event: The original event, serialized. This is modified.
edit: The event which edits the above.
"""
# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
edit_content = edit.content.copy()
# Unfreeze the event content if necessary, so that we may modify it below
edit_content = unfreeze(edit_content)
serialized_event["content"] = edit_content.get("m.new_content", {})
# Check for existing relations
relates_to = orig_event.content.get("m.relates_to")
if relates_to:
# Keep the relations, ensuring we use a dict copy of the original
serialized_event["content"]["m.relates_to"] = relates_to.copy()
else:
serialized_event["content"].pop("m.relates_to", None)
    def _inject_bundled_aggregations(
        self,
        event: EventBase,
@ -450,26 +477,11 @@ class EventClientSerializer:
        serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references

        if aggregations.replace:
-            # If there is an edit replace the content, preserving existing
-            # relations.
+            # If there is an edit, apply it to the event.
            edit = aggregations.replace
+            self._apply_edit(event, serialized_event, edit)

-            # Ensure we take copies of the edit content, otherwise we risk modifying
-            # the original event.
-            edit_content = edit.content.copy()
-
-            # Unfreeze the event content if necessary, so that we may modify it below
-            edit_content = unfreeze(edit_content)
-
-            serialized_event["content"] = edit_content.get("m.new_content", {})
-
-            # Check for existing relations
-            relates_to = event.content.get("m.relates_to")
-            if relates_to:
-                # Keep the relations, ensuring we use a dict copy of the original
-                serialized_event["content"]["m.relates_to"] = relates_to.copy()
-            else:
-                serialized_event["content"].pop("m.relates_to", None)
-
+            # Include information about it in the relations dict.
            serialized_aggregations[RelationTypes.REPLACE] = {
                "event_id": edit.event_id,
                "origin_server_ts": edit.origin_server_ts,
@ -478,13 +490,22 @@ class EventClientSerializer:
        # If this event is the start of a thread, include a summary of the replies.
        if aggregations.thread:
+            thread = aggregations.thread
+
+            # Don't bundle aggregations as this could recurse forever.
+            serialized_latest_event = self.serialize_event(
+                thread.latest_event, time_now, bundle_aggregations=None
+            )
+            # Manually apply an edit, if one exists.
+            if thread.latest_edit:
+                self._apply_edit(
+                    thread.latest_event, serialized_latest_event, thread.latest_edit
+                )
+
            serialized_aggregations[RelationTypes.THREAD] = {
-                # Don't bundle aggregations as this could recurse forever.
-                "latest_event": self.serialize_event(
-                    aggregations.thread.latest_event, time_now, bundle_aggregations=None
-                ),
-                "count": aggregations.thread.count,
-                "current_user_participated": aggregations.thread.current_user_participated,
+                "latest_event": serialized_latest_event,
+                "count": thread.count,
+                "current_user_participated": thread.current_user_participated,
            }

        # Include the bundled aggregations in the event.

View File

@ -47,6 +47,11 @@ class FederationBase:
    ) -> EventBase:
        """Checks that event is correctly signed by the sending server.

+        Also checks the content hash, and redacts the event if there is a mismatch.
+
+        Also runs the event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.
+
        Args:
            room_version: The room version of the PDU
            pdu: the event to be checked
@ -55,7 +60,10 @@ class FederationBase:
            * the original event if the checks pass
            * a redacted version of the event (if the signature
              matched but the hash did not)
-            * throws a SynapseError if the signature check failed."""
+
+        Raises:
+            SynapseError if the signature check failed.
+        """
        try:
            await _check_sigs_on_pdu(self.keyring, room_version, pdu)
        except SynapseError as e:

View File

@ -1,4 +1,4 @@
-# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2015-2022 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -89,6 +89,12 @@ class SendJoinResult:
    state: List[EventBase]
    auth_chain: List[EventBase]

+    # True if 'state' elides non-critical membership events
+    partial_state: bool
+
+    # if 'partial_state' is set, a list of the servers in the room (otherwise empty)
+    servers_in_room: List[str]


class FederationClient(FederationBase):
    def __init__(self, hs: "HomeServer"):
@ -413,26 +419,90 @@ class FederationClient(FederationBase):
return state_event_ids, auth_event_ids return state_event_ids, auth_event_ids
async def get_room_state(
self,
destination: str,
room_id: str,
event_id: str,
room_version: RoomVersion,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Calls the /state endpoint to fetch the state at a particular point
in the room.
Any invalid events (those with incorrect or unverifiable signatures or hashes)
are filtered out from the response, and any duplicate events are removed.
(Size limits and other event-format checks are *not* performed.)
Note that the result is not ordered, so callers must be careful to process
the events in an order that handles dependencies.
Returns:
a tuple of (state events, auth events)
"""
result = await self.transport_layer.get_room_state(
room_version,
destination,
room_id,
event_id,
)
state_events = result.state
auth_events = result.auth_events
# we may as well filter out any duplicates from the response, to save
# processing them multiple times. (In particular, events may be present in
# `auth_events` as well as `state`, which is redundant).
#
# We don't rely on the sort order of the events, so we can just stick them
# in a dict.
state_event_map = {event.event_id: event for event in state_events}
auth_event_map = {
event.event_id: event
for event in auth_events
if event.event_id not in state_event_map
}
logger.info(
"Processing from /state: %d state events, %d auth events",
len(state_event_map),
len(auth_event_map),
)
valid_auth_events = await self._check_sigs_and_hash_and_fetch(
destination, auth_event_map.values(), room_version
)
valid_state_events = await self._check_sigs_and_hash_and_fetch(
destination, state_event_map.values(), room_version
)
return valid_state_events, valid_auth_events
    async def _check_sigs_and_hash_and_fetch(
        self,
        origin: str,
        pdus: Collection[EventBase],
        room_version: RoomVersion,
    ) -> List[EventBase]:
-        """Takes a list of PDUs and checks the signatures and hashes of each
-        one. If a PDU fails its signature check then we check if we have it in
-        the database and if not then request if from the originating server of
-        that PDU.
+        """Checks the signatures and hashes of a list of events.
+
+        If a PDU fails its signature check then we check if we have it in
+        the database, and if not then request it from the sender's server (if that
+        is different from `origin`). If that still fails, the event is omitted from
+        the returned list.

        If a PDU fails its content hash check then it is redacted.

-        The given list of PDUs are not modified, instead the function returns
+        Also runs each event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.
+
+        The given list of PDUs are not modified; instead the function returns
        a new list.

        Args:
-            origin
-            pdu
-            room_version
+            origin: The server that sent us these events
+            pdus: The events to be checked
+            room_version: the version of the room these events are in

        Returns:
            A list of PDUs that have valid signatures and hashes.
@ -463,11 +533,16 @@ class FederationClient(FederationBase):
        origin: str,
        room_version: RoomVersion,
    ) -> Optional[EventBase]:
-        """Takes a PDU and checks its signatures and hashes. If the PDU fails
-        its signature check then we check if we have it in the database and if
-        not then request if from the originating server of that PDU.
-
-        If then PDU fails its content hash check then it is redacted.
+        """Takes a PDU and checks its signatures and hashes.
+
+        If the PDU fails its signature check then we check if we have it in the
+        database; if not, we then request it from sender's server (if that is not the
+        same as `origin`). If that still fails, we return None.
+
+        If the PDU fails its content hash check, it is redacted.
+
+        Also runs the event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.

        Args:
            origin
@ -864,23 +939,32 @@ class FederationClient(FederationBase):
            for s in signed_state:
                s.internal_metadata = copy.deepcopy(s.internal_metadata)

-            # double-check that the same create event has ended up in the auth chain
+            # double-check that the auth chain doesn't include a different create event
            auth_chain_create_events = [
                e.event_id
                for e in signed_auth
                if (e.type, e.state_key) == (EventTypes.Create, "")
            ]
-            if auth_chain_create_events != [create_event.event_id]:
+            if auth_chain_create_events and auth_chain_create_events != [
+                create_event.event_id
+            ]:
                raise InvalidResponseError(
                    "Unexpected create event(s) in auth chain: %s"
                    % (auth_chain_create_events,)
                )

+            if response.partial_state and not response.servers_in_room:
+                raise InvalidResponseError(
+                    "partial_state was set, but no servers were listed in the room"
+                )
+
            return SendJoinResult(
                event=event,
                state=signed_state,
                auth_chain=signed_auth,
                origin=destination,
+                partial_state=response.partial_state,
+                servers_in_room=response.servers_in_room or [],
            )

        # MSC3083 defines additional error codes for room joins.

View File

@ -381,7 +381,9 @@ class PerDestinationQueue:
                )
            )

-        if self._last_successful_stream_ordering is None:
+        last_successful_stream_ordering = self._last_successful_stream_ordering
+        if last_successful_stream_ordering is None:
            # if it's still None, then this means we don't have the information
            # in our database ­ we haven't successfully sent a PDU to this server
            # (at least since the introduction of the feature tracking
@ -394,8 +396,7 @@ class PerDestinationQueue:
        # get at most 50 catchup room/PDUs
        while True:
            event_ids = await self._store.get_catch_up_room_event_ids(
-                self._destination,
-                self._last_successful_stream_ordering,
+                self._destination, last_successful_stream_ordering
            )

            if not event_ids:
@ -403,7 +404,7 @@ class PerDestinationQueue:
                # of a race condition, so we check that no new events have been
                # skipped due to us being in catch-up mode
-                if self._catchup_last_skipped > self._last_successful_stream_ordering:
+                if self._catchup_last_skipped > last_successful_stream_ordering:
                    # another event has been skipped because we were in catch-up mode
                    continue
@ -470,7 +471,7 @@ class PerDestinationQueue:
                # offline
                if (
                    p.internal_metadata.stream_ordering
-                    < self._last_successful_stream_ordering
+                    < last_successful_stream_ordering
                ):
                    continue
@ -513,12 +514,11 @@ class PerDestinationQueue:
                # from the *original* PDU, rather than the PDU(s) we actually
                # send. This is because we use it to mark our position in the
                # queue of missed PDUs to process.
-                self._last_successful_stream_ordering = (
-                    pdu.internal_metadata.stream_ordering
-                )
+                last_successful_stream_ordering = pdu.internal_metadata.stream_ordering

+                self._last_successful_stream_ordering = last_successful_stream_ordering
                await self._store.set_destination_last_successful_stream_ordering(
-                    self._destination, self._last_successful_stream_ordering
+                    self._destination, last_successful_stream_ordering
                )

    def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:

View File

@ -1,4 +1,4 @@
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -60,17 +60,17 @@ class TransportLayerClient:
    def __init__(self, hs):
        self.server_name = hs.hostname
        self.client = hs.get_federation_http_client()
+        self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled

    async def get_room_state_ids(
        self, destination: str, room_id: str, event_id: str
    ) -> JsonDict:
-        """Requests all state for a given room from the given server at the
-        given event. Returns the state's event_id's
+        """Requests the IDs of all state for a given room at the given event.

        Args:
            destination: The host name of the remote homeserver we want
                to get the state from.
-            context: The name of the context we want the state of
+            room_id: the room we want the state of
            event_id: The event we want the context at.

        Returns:
@ -86,6 +86,29 @@ class TransportLayerClient:
            try_trailing_slash_on_400=True,
        )
async def get_room_state(
self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
) -> "StateRequestResponse":
"""Requests the full state for a given room at the given event.
Args:
room_version: the version of the room (required to build the event objects)
destination: The host name of the remote homeserver we want
to get the state from.
room_id: the room we want the state of
event_id: The event we want the context at.
Returns:
Results in a dict received from the remote homeserver.
"""
path = _create_v1_path("/state/%s", room_id)
return await self.client.get_json(
destination,
path=path,
args={"event_id": event_id},
parser=_StateParser(room_version),
)
    async def get_event(
        self, destination: str, event_id: str, timeout: Optional[int] = None
    ) -> JsonDict:
@ -336,10 +359,15 @@ class TransportLayerClient:
        content: JsonDict,
    ) -> "SendJoinResponse":
        path = _create_v2_path("/send_join/%s/%s", room_id, event_id)

+        query_params: Dict[str, str] = {}
+        if self._faster_joins_enabled:
+            # lazy-load state on join
+            query_params["org.matrix.msc3706.partial_state"] = "true"
+
        return await self.client.put_json(
            destination=destination,
            path=path,
+            args=query_params,
            data=content,
            parser=SendJoinParser(room_version, v1_api=False),
            max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
@ -1271,6 +1299,20 @@ class SendJoinResponse:
# "event" is not included in the response. # "event" is not included in the response.
event: Optional[EventBase] = None event: Optional[EventBase] = None
# The room state is incomplete
partial_state: bool = False
# List of servers in the room
servers_in_room: Optional[List[str]] = None
@attr.s(slots=True, auto_attribs=True)
class StateRequestResponse:
"""The parsed response of a `/state` request."""
auth_events: List[EventBase]
state: List[EventBase]
@ijson.coroutine
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
@ -1297,6 +1339,32 @@ def _event_list_parser(
        events.append(event)
@ijson.coroutine
def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
"""Helper function for use with `ijson.items_coro`
Parses the partial_state field in send_join responses
"""
while True:
val = yield
if not isinstance(val, bool):
raise TypeError("partial_state must be a boolean")
response.partial_state = val
@ijson.coroutine
def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
"""Helper function for use with `ijson.items_coro`
Parses the servers_in_room field in send_join responses
"""
while True:
val = yield
if not isinstance(val, list) or any(not isinstance(x, str) for x in val):
raise TypeError("servers_in_room must be a list of strings")
response.servers_in_room = val
class SendJoinParser(ByteParser[SendJoinResponse]):
    """A parser for the response to `/send_join` requests.
@ -1308,44 +1376,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
CONTENT_TYPE = "application/json" CONTENT_TYPE = "application/json"
def __init__(self, room_version: RoomVersion, v1_api: bool): def __init__(self, room_version: RoomVersion, v1_api: bool):
self._response = SendJoinResponse([], [], {}) self._response = SendJoinResponse([], [], event_dict={})
self._room_version = room_version self._room_version = room_version
self._coros = []
# The V1 API has the shape of `[200, {...}]`, which we handle by # The V1 API has the shape of `[200, {...}]`, which we handle by
# prefixing with `item.*`. # prefixing with `item.*`.
prefix = "item." if v1_api else "" prefix = "item." if v1_api else ""
self._coro_state = ijson.items_coro( self._coros = [
_event_list_parser(room_version, self._response.state), ijson.items_coro(
prefix + "state.item", _event_list_parser(room_version, self._response.state),
use_float=True, prefix + "state.item",
) use_float=True,
self._coro_auth = ijson.items_coro( ),
_event_list_parser(room_version, self._response.auth_events), ijson.items_coro(
prefix + "auth_chain.item", _event_list_parser(room_version, self._response.auth_events),
use_float=True, prefix + "auth_chain.item",
) use_float=True,
# TODO Remove the unstable prefix when servers have updated. ),
# # TODO Remove the unstable prefix when servers have updated.
# By re-using the same event dictionary this will cause the parsing of #
# org.matrix.msc3083.v2.event and event to stomp over each other. # By re-using the same event dictionary this will cause the parsing of
# Generally this should be fine. # org.matrix.msc3083.v2.event and event to stomp over each other.
self._coro_unstable_event = ijson.kvitems_coro( # Generally this should be fine.
_event_parser(self._response.event_dict), ijson.kvitems_coro(
prefix + "org.matrix.msc3083.v2.event", _event_parser(self._response.event_dict),
use_float=True, prefix + "org.matrix.msc3083.v2.event",
) use_float=True,
self._coro_event = ijson.kvitems_coro( ),
_event_parser(self._response.event_dict), ijson.kvitems_coro(
prefix + "event", _event_parser(self._response.event_dict),
use_float=True, prefix + "event",
) use_float=True,
),
]
if not v1_api:
self._coros.append(
ijson.items_coro(
_partial_state_parser(self._response),
"org.matrix.msc3706.partial_state",
use_float="True",
)
)
self._coros.append(
ijson.items_coro(
_servers_in_room_parser(self._response),
"org.matrix.msc3706.servers_in_room",
use_float="True",
)
)
def write(self, data: bytes) -> int: def write(self, data: bytes) -> int:
self._coro_state.send(data) for c in self._coros:
self._coro_auth.send(data) c.send(data)
self._coro_unstable_event.send(data)
self._coro_event.send(data)
return len(data) return len(data)
@ -1355,3 +1441,37 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
self._response.event_dict, self._room_version self._response.event_dict, self._room_version
) )
return self._response return self._response
class _StateParser(ByteParser[StateRequestResponse]):
"""A parser for the response to `/state` requests.
Args:
room_version: The version of the room.
"""
CONTENT_TYPE = "application/json"
def __init__(self, room_version: RoomVersion):
self._response = StateRequestResponse([], [])
self._room_version = room_version
self._coros = [
ijson.items_coro(
_event_list_parser(room_version, self._response.state),
"pdus.item",
use_float=True,
),
ijson.items_coro(
_event_list_parser(room_version, self._response.auth_events),
"auth_chain.item",
use_float=True,
),
]
def write(self, data: bytes) -> int:
for c in self._coros:
c.send(data)
return len(data)
def finish(self) -> StateRequestResponse:
return self._response

View File

@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict], [JsonDict, JsonDict],
Awaitable[Optional[str]], Awaitable[Optional[str]],
] ]
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
@ -2080,6 +2084,9 @@ class PasswordAuthProvider:
self.get_username_for_registration_callbacks: List[ self.get_username_for_registration_callbacks: List[
GET_USERNAME_FOR_REGISTRATION_CALLBACK GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = [] ] = []
self.get_displayname_for_registration_callbacks: List[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = []
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to login parameters # Mapping from login type to login parameters
@ -2099,6 +2106,9 @@ class PasswordAuthProvider:
get_username_for_registration: Optional[ get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None, ] = None,
get_displayname_for_registration: Optional[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = None,
) -> None: ) -> None:
# Register check_3pid_auth callback # Register check_3pid_auth callback
if check_3pid_auth is not None: if check_3pid_auth is not None:
@ -2148,6 +2158,11 @@ class PasswordAuthProvider:
get_username_for_registration, get_username_for_registration,
) )
if get_displayname_for_registration is not None:
self.get_displayname_for_registration_callbacks.append(
get_displayname_for_registration,
)
if is_3pid_allowed is not None: if is_3pid_allowed is not None:
self.is_3pid_allowed_callbacks.append(is_3pid_allowed) self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
@ -2350,6 +2365,49 @@ class PasswordAuthProvider:
return None return None
async def get_displayname_for_registration(
self,
uia_results: JsonDict,
params: JsonDict,
) -> Optional[str]:
"""Defines the display name to use when registering the user, using the
credentials and parameters provided during the UIA flow.
Stops at the first callback that returns a string.
Args:
uia_results: The credentials provided during the UIA flow.
params: The parameters provided by the registration request.
Returns:
A string to use as the user's display name, or None if no callback returned one.
"""
for callback in self.get_displayname_for_registration_callbacks:
try:
res = await callback(uia_results, params)
if isinstance(res, str):
return res
elif res is not None:
# mypy complains that this line is unreachable because it assumes the
# data returned by the module fits the expected type. We just want
# to make sure this is the case.
logger.warning( # type: ignore[unreachable]
"Ignoring non-string value returned by"
" get_displayname_for_registration callback %s: %s",
callback,
res,
)
except Exception as e:
logger.error(
"Module raised an exception in get_displayname_for_registration: %s",
e,
)
raise SynapseError(code=500, msg="Internal Server Error")
return None
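As a worked example, a module might supply a callback of the shape this method expects. Everything below is a hypothetical module-side sketch, not code from this diff; it assumes the registration request carried a `username` parameter and that `JsonDict`/`Optional` are imported in the module:

    # Hypothetical module callback: derive a display name from the registration
    # parameters, returning None so later callbacks (or the default) can apply.
    async def get_displayname_for_registration(
        uia_results: JsonDict, params: JsonDict
    ) -> Optional[str]:
        username = params.get("username")
        if isinstance(username, str) and username:
            return username.capitalize()
        return None

Such a callback would be registered through the module API's password auth provider hook, shown further down in this diff.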
async def is_3pid_allowed( async def is_3pid_allowed(
self, self,
medium: str, medium: str,

View File

@ -49,8 +49,8 @@ from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
nested_logging_context, nested_logging_context,
preserve_fn, preserve_fn,
run_in_background,
) )
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.federation import ( from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet, ReplicationCleanRoomRestServlet,
ReplicationStoreRoomOnOutlierMembershipRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet,
@ -516,7 +516,7 @@ class FederationHandler:
await self.store.upsert_room_on_join( await self.store.upsert_room_on_join(
room_id=room_id, room_id=room_id,
room_version=room_version_obj, room_version=room_version_obj,
auth_events=auth_chain, state_events=state,
) )
max_stream_id = await self._federation_event_handler.process_remote_join( max_stream_id = await self._federation_event_handler.process_remote_join(
@ -559,7 +559,9 @@ class FederationHandler:
# lots of requests for missing prev_events which we do actually # lots of requests for missing prev_events which we do actually
# have. Hence we fire off the background task, but don't wait for it. # have. Hence we fire off the background task, but don't wait for it.
run_in_background(self._handle_queued_pdus, room_queue) run_as_background_process(
"handle_queued_pdus", self._handle_queued_pdus, room_queue
)
async def do_knock( async def do_knock(
self, self,

View File

@ -419,10 +419,8 @@ class FederationEventHandler:
Raises: Raises:
SynapseError if the response is in some way invalid. SynapseError if the response is in some way invalid.
""" """
event_map = {e.event_id: e for e in itertools.chain(auth_events, state)}
create_event = None create_event = None
for e in auth_events: for e in state:
if (e.type, e.state_key) == (EventTypes.Create, ""): if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e create_event = e
break break
@ -439,11 +437,6 @@ class FederationEventHandler:
if room_version.identifier != room_version_id: if room_version.identifier != room_version_id:
raise SynapseError(400, "Room version mismatch") raise SynapseError(400, "Room version mismatch")
# filter out any events we have already seen
seen_remotes = await self._store.have_seen_events(room_id, event_map.keys())
for s in seen_remotes:
event_map.pop(s, None)
# persist the auth chain and state events. # persist the auth chain and state events.
# #
# any invalid events here will be marked as rejected, and we'll carry on. # any invalid events here will be marked as rejected, and we'll carry on.
@ -455,7 +448,9 @@ class FederationEventHandler:
# signatures right now doesn't mean that we will *never* be able to, so it # signatures right now doesn't mean that we will *never* be able to, so it
# is premature to reject them. # is premature to reject them.
# #
await self._auth_and_persist_outliers(room_id, event_map.values()) await self._auth_and_persist_outliers(
room_id, itertools.chain(auth_events, state)
)
# and now persist the join event itself. # and now persist the join event itself.
logger.info("Peristing join-via-remote %s", event) logger.info("Peristing join-via-remote %s", event)
@ -1245,6 +1240,16 @@ class FederationEventHandler:
""" """
event_map = {event.event_id: event for event in events} event_map = {event.event_id: event for event in events}
# filter out any events we have already seen. This might happen because
# the events were eagerly pushed to us (eg, during a room join), or because
# another thread has raced against us since we decided to request the event.
#
# This is just an optimisation, so it doesn't need to be watertight - the event
# persister does another round of deduplication.
seen_remotes = await self._store.have_seen_events(room_id, event_map.keys())
for s in seen_remotes:
event_map.pop(s, None)
# XXX: it might be possible to kick this process off in parallel with fetching # XXX: it might be possible to kick this process off in parallel with fetching
# the events. # the events.
while event_map: while event_map:
@ -1717,31 +1722,22 @@ class FederationEventHandler:
event_id: the event for which we are lacking auth events event_id: the event for which we are lacking auth events
""" """
try: try:
remote_event_map = { remote_events = await self._federation_client.get_event_auth(
e.event_id: e destination, room_id, event_id
for e in await self._federation_client.get_event_auth( )
destination, room_id, event_id
)
}
except RequestSendFailed as e1: except RequestSendFailed as e1:
# The other side isn't around or doesn't implement the # The other side isn't around or doesn't implement the
# endpoint, so lets just bail out. # endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e1) logger.info("Failed to get event auth from remote: %s", e1)
return return
logger.info("/event_auth returned %i events", len(remote_event_map)) logger.info("/event_auth returned %i events", len(remote_events))
# `event` may be returned, but we should not yet process it. # `event` may be returned, but we should not yet process it.
remote_event_map.pop(event_id, None) remote_auth_events = (e for e in remote_events if e.event_id != event_id)
# nor should we reprocess any events we have already seen. await self._auth_and_persist_outliers(room_id, remote_auth_events)
seen_remotes = await self._store.have_seen_events(
room_id, remote_event_map.keys()
)
for s in seen_remotes:
remote_event_map.pop(s, None)
await self._auth_and_persist_outliers(room_id, remote_event_map.values())
async def _update_context_for_auth_events( async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]

View File

@ -550,10 +550,11 @@ class EventCreationHandler:
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version_id = event_dict["content"]["room_version"] room_version_id = event_dict["content"]["room_version"]
room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version_obj: if not maybe_room_version_obj:
# this can happen if support is withdrawn for a room version # this can happen if support is withdrawn for a room version
raise UnsupportedRoomVersionError(room_version_id) raise UnsupportedRoomVersionError(room_version_id)
room_version_obj = maybe_room_version_obj
else: else:
try: try:
room_version_obj = await self.store.get_room_version( room_version_obj = await self.store.get_room_version(
@ -1145,12 +1146,13 @@ class EventCreationHandler:
room_version_id = event.content.get( room_version_id = event.content.get(
"room_version", RoomVersions.V1.identifier "room_version", RoomVersions.V1.identifier
) )
room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version_obj: if not maybe_room_version_obj:
raise UnsupportedRoomVersionError( raise UnsupportedRoomVersionError(
"Attempt to create a room with unsupported room version %s" "Attempt to create a room with unsupported room version %s"
% (room_version_id,) % (room_version_id,)
) )
room_version_obj = maybe_room_version_obj
else: else:
room_version_obj = await self.store.get_room_version(event.room_id) room_version_obj = await self.store.get_room_version(event.room_id)

View File

@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC):
Returns: Returns:
dict: `user_id` -> `UserPresenceState` dict: `user_id` -> `UserPresenceState`
""" """
states = { states = {}
user_id: self.user_to_current_state.get(user_id, None) missing = []
for user_id in user_ids for user_id in user_ids:
} state = self.user_to_current_state.get(user_id, None)
if state:
states[user_id] = state
else:
missing.append(user_id)
missing = [user_id for user_id, state in states.items() if not state]
if missing: if missing:
# There are things not in our in memory cache. Lets pull them out of # There are things not in our in memory cache. Lets pull them out of
# the database. # the database.
res = await self.store.get_presence_for_users(missing) res = await self.store.get_presence_for_users(missing)
states.update(res) states.update(res)
missing = [user_id for user_id, state in states.items() if not state] for user_id in missing:
if missing: # if user has no state in database, create the state
new = { if not res.get(user_id, None):
user_id: UserPresenceState.default(user_id) for user_id in missing new_state = UserPresenceState.default(user_id)
} states[user_id] = new_state
states.update(new) self.user_to_current_state[user_id] = new_state
self.user_to_current_state.update(new)
return states return states

View File

@ -320,12 +320,12 @@ class RegistrationHandler:
if fail_count > 10: if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID") raise SynapseError(500, "Unable to find a suitable guest user ID")
localpart = await self.store.generate_user_id() generated_localpart = await self.store.generate_user_id()
user = UserID(localpart, self.hs.hostname) user = UserID(generated_localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id) self.check_user_id_not_appservice_exclusive(user_id)
if generate_display_name: if generate_display_name:
default_display_name = localpart default_display_name = generated_localpart
try: try:
await self.register_with_store( await self.register_with_store(
user_id=user_id, user_id=user_id,

View File

@ -82,7 +82,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.event_auth_handler = hs.get_event_auth_handler() self.event_auth_handler = hs.get_event_auth_handler()
self.member_linearizer: Linearizer = Linearizer(name="member") self.member_linearizer: Linearizer = Linearizer(name="member")
self.member_limiter = Linearizer(max_count=10, name="member_as_limiter") self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
@ -507,7 +507,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
then = self.clock.time_msec() then = self.clock.time_msec()
with (await self.member_limiter.queue(as_id)): # We first linearise by the application service (to try to limit concurrent joins
# by application services), and then by room ID.
with (await self.member_as_limiter.queue(as_id)):
diff = self.clock.time_msec() - then diff = self.clock.time_msec() - then
if diff > 80 * 1000: if diff > 80 * 1000:

View File

@ -14,8 +14,9 @@
import itertools import itertools
import logging import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
import attr
from unpaddedbase64 import decode_base64, encode_base64 from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -32,6 +33,20 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _SearchResult:
# The count of results.
count: int
# A mapping of event ID to the rank of that event.
rank_map: Dict[str, int]
# A list of the resulting events.
allowed_events: List[EventBase]
# A map of room ID to results.
room_groups: Dict[str, JsonDict]
# A set of event IDs to highlight.
highlights: Set[str]
class SearchHandler: class SearchHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -100,7 +115,7 @@ class SearchHandler:
"""Performs a full text search for a user. """Performs a full text search for a user.
Args: Args:
user user: The user performing the search.
content: Search parameters content: Search parameters
batch: The next_batch parameter. Used for pagination. batch: The next_batch parameter. Used for pagination.
@ -156,6 +171,8 @@ class SearchHandler:
# Include context around each event? # Include context around each event?
event_context = room_cat.get("event_context", None) event_context = room_cat.get("event_context", None)
before_limit = after_limit = None
include_profile = False
# Group results together? May allow clients to paginate within a # Group results together? May allow clients to paginate within a
# group # group
@ -182,6 +199,73 @@ class SearchHandler:
% (set(group_keys) - {"room_id", "sender"},), % (set(group_keys) - {"room_id", "sender"},),
) )
return await self._search(
user,
batch_group,
batch_group_key,
batch_token,
search_term,
keys,
filter_dict,
order_by,
include_state,
group_keys,
event_context,
before_limit,
after_limit,
include_profile,
)
async def _search(
self,
user: UserID,
batch_group: Optional[str],
batch_group_key: Optional[str],
batch_token: Optional[str],
search_term: str,
keys: List[str],
filter_dict: JsonDict,
order_by: str,
include_state: bool,
group_keys: List[str],
event_context: Optional[bool],
before_limit: Optional[int],
after_limit: Optional[int],
include_profile: bool,
) -> JsonDict:
"""Performs a full text search for a user.
Args:
user: The user performing the search.
batch_group: Pagination information.
batch_group_key: Pagination information.
batch_token: Pagination information.
search_term: Search term to search for
keys: List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
filter_dict: The JSON to build a filter out of.
order_by: How to order the results. Valid values are "rank" and "recent".
include_state: True if the state of the room at each result should
be included.
group_keys: A list of ways to group the results. Valid values are
"room_id" and "sender".
event_context: True to include contextual events around results.
before_limit:
The number of events before a result to include as context.
Only used if event_context is True.
after_limit:
The number of events after a result to include as context.
Only used if event_context is True.
include_profile: True if historical profile information should be
included in the event context.
Only used if event_context is True.
Returns:
dict to be returned to the client with results of search
"""
search_filter = Filter(self.hs, filter_dict) search_filter = Filter(self.hs, filter_dict)
# TODO: Search through left rooms too # TODO: Search through left rooms too
@ -216,209 +300,57 @@ class SearchHandler:
} }
} }
rank_map = {} # event_id -> rank of event sender_group: Optional[Dict[str, JsonDict]]
allowed_events = []
# Holds result of grouping by room, if applicable
room_groups: Dict[str, JsonDict] = {}
# Holds result of grouping by sender, if applicable
sender_group: Dict[str, JsonDict] = {}
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
highlights = set()
count = None
if order_by == "rank": if order_by == "rank":
search_result = await self.store.search_msgs(room_ids, search_term, keys) search_result, sender_group = await self._search_by_rank(
user, room_ids, search_term, keys, search_filter
count = search_result["count"]
if search_result["highlights"]:
highlights.update(search_result["highlights"])
results = search_result["results"]
rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
) )
# Unused return values for rank search.
events.sort(key=lambda e: -rank_map[e.event_id]) global_next_batch = None
allowed_events = events[: search_filter.limit]
for e in allowed_events:
rm = room_groups.setdefault(
e.room_id, {"results": [], "order": rank_map[e.event_id]}
)
rm["results"].append(e.event_id)
s = sender_group.setdefault(
e.sender, {"results": [], "order": rank_map[e.event_id]}
)
s["results"].append(e.event_id)
elif order_by == "recent": elif order_by == "recent":
room_events: List[EventBase] = [] search_result, global_next_batch = await self._search_by_recent(
i = 0 user,
room_ids,
pagination_token = batch_token search_term,
keys,
# We keep looping and we keep filtering until we reach the limit search_filter,
# or we run out of things. batch_group,
# But only go around 5 times since otherwise synapse will be sad. batch_group_key,
while len(room_events) < search_filter.limit and i < 5: batch_token,
i += 1 )
search_result = await self.store.search_rooms( # Unused return values for recent search.
room_ids, sender_group = None
search_term,
keys,
search_filter.limit * 2,
pagination_token=pagination_token,
)
if search_result["highlights"]:
highlights.update(search_result["highlights"])
count = search_result["count"]
results = search_result["results"]
results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = await search_filter.filter(
[r["event"] for r in results]
)
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)
room_events.extend(events)
room_events = room_events[: search_filter.limit]
if len(results) < search_filter.limit * 2:
pagination_token = None
break
else:
pagination_token = results[-1]["pagination_token"]
for event in room_events:
group = room_groups.setdefault(event.room_id, {"results": []})
group["results"].append(event.event_id)
if room_events and len(room_events) >= search_filter.limit:
last_event_id = room_events[-1].event_id
pagination_token = results_map[last_event_id]["pagination_token"]
# We want to respect the given batch group and group keys so
# that if people blindly use the top level `next_batch` token
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key:
global_next_batch = encode_base64(
(
"%s\n%s\n%s"
% (batch_group, batch_group_key, pagination_token)
).encode("ascii")
)
else:
global_next_batch = encode_base64(
("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
)
for room_id, group in room_groups.items():
group["next_batch"] = encode_base64(
("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
"ascii"
)
)
allowed_events.extend(room_events)
else: else:
# We should never get here due to the guard earlier. # We should never get here due to the guard earlier.
raise NotImplementedError() raise NotImplementedError()
logger.info("Found %d events to return", len(allowed_events)) logger.info("Found %d events to return", len(search_result.allowed_events))
# If client has asked for "context" for each event (i.e. some surrounding # If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that # events and state), fetch that
if event_context is not None: if event_context is not None:
now_token = self.hs.get_event_sources().get_current_token() # Note that before and after limit must be set in this case.
assert before_limit is not None
assert after_limit is not None
contexts = {} contexts = await self._calculate_event_contexts(
for event in allowed_events: user,
res = await self.store.get_events_around( search_result.allowed_events,
event.room_id, event.event_id, before_limit, after_limit before_limit,
) after_limit,
include_profile,
logger.info( )
"Context for search returned %d and %d events",
len(res.events_before),
len(res.events_after),
)
events_before = await filter_events_for_client(
self.storage, user.to_string(), res.events_before
)
events_after = await filter_events_for_client(
self.storage, user.to_string(), res.events_after
)
context = {
"events_before": events_before,
"events_after": events_after,
"start": await now_token.copy_and_replace(
"room_key", res.start
).to_string(self.store),
"end": await now_token.copy_and_replace(
"room_key", res.end
).to_string(self.store),
}
if include_profile:
senders = {
ev.sender
for ev in itertools.chain(events_before, [event], events_after)
}
if events_after:
last_event_id = events_after[-1].event_id
else:
last_event_id = event.event_id
state_filter = StateFilter.from_types(
[(EventTypes.Member, sender) for sender in senders]
)
state = await self.state_store.get_state_for_event(
last_event_id, state_filter
)
context["profile_info"] = {
s.state_key: {
"displayname": s.content.get("displayname", None),
"avatar_url": s.content.get("avatar_url", None),
}
for s in state.values()
if s.type == EventTypes.Member and s.state_key in senders
}
contexts[event.event_id] = context
else: else:
contexts = {} contexts = {}
# TODO: Add a limit # TODO: Add a limit
time_now = self.clock.time_msec() state_results = {}
if include_state:
for room_id in {e.room_id for e in search_result.allowed_events}:
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
aggregations = None aggregations = None
if self._msc3666_enabled: if self._msc3666_enabled:
@ -432,11 +364,16 @@ class SearchHandler:
for context in contexts.values() for context in contexts.values()
), ),
# The returned events. # The returned events.
allowed_events, search_result.allowed_events,
), ),
user.to_string(), user.to_string(),
) )
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise, the 'age' will be wrong.
time_now = self.clock.time_msec()
for context in contexts.values(): for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events( context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
@ -445,44 +382,33 @@ class SearchHandler:
context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
) )
state_results = {} results = [
if include_state: {
for room_id in {e.room_id for e in allowed_events}: "rank": search_result.rank_map[e.event_id],
state = await self.state_handler.get_current_state(room_id) "result": self._event_serializer.serialize_event(
state_results[room_id] = list(state.values()) e, time_now, bundle_aggregations=aggregations
),
"context": contexts.get(e.event_id, {}),
}
for e in search_result.allowed_events
]
# We're now about to serialize the events. We should not make any rooms_cat_res: JsonDict = {
# blocking calls after this. Otherwise the 'age' will be wrong
results = []
for e in allowed_events:
results.append(
{
"rank": rank_map[e.event_id],
"result": self._event_serializer.serialize_event(
e, time_now, bundle_aggregations=aggregations
),
"context": contexts.get(e.event_id, {}),
}
)
rooms_cat_res = {
"results": results, "results": results,
"count": count, "count": search_result.count,
"highlights": list(highlights), "highlights": list(search_result.highlights),
} }
if state_results: if state_results:
s = {} rooms_cat_res["state"] = {
for room_id, state_events in state_results.items(): room_id: self._event_serializer.serialize_events(state_events, time_now)
s[room_id] = self._event_serializer.serialize_events( for room_id, state_events in state_results.items()
state_events, time_now }
)
rooms_cat_res["state"] = s if search_result.room_groups and "room_id" in group_keys:
rooms_cat_res.setdefault("groups", {})[
if room_groups and "room_id" in group_keys: "room_id"
rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups ] = search_result.room_groups
if sender_group and "sender" in group_keys: if sender_group and "sender" in group_keys:
rooms_cat_res.setdefault("groups", {})["sender"] = sender_group rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
@ -491,3 +417,282 @@ class SearchHandler:
rooms_cat_res["next_batch"] = global_next_batch rooms_cat_res["next_batch"] = global_next_batch
return {"search_categories": {"room_events": rooms_cat_res}} return {"search_categories": {"room_events": rooms_cat_res}}
async def _search_by_rank(
self,
user: UserID,
room_ids: Collection[str],
search_term: str,
keys: Iterable[str],
search_filter: Filter,
) -> Tuple[_SearchResult, Dict[str, JsonDict]]:
"""
Performs a full text search for a user ordering by rank.
Args:
user: The user performing the search.
room_ids: List of room ids to search in
search_term: Search term to search for
keys: List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
search_filter: The event filter to use.
Returns:
A tuple of:
The search results.
A map of sender ID to results.
"""
rank_map = {} # event_id -> rank of event
# Holds result of grouping by room, if applicable
room_groups: Dict[str, JsonDict] = {}
# Holds result of grouping by sender, if applicable
sender_group: Dict[str, JsonDict] = {}
search_result = await self.store.search_msgs(room_ids, search_term, keys)
if search_result["highlights"]:
highlights = search_result["highlights"]
else:
highlights = set()
results = search_result["results"]
# event_id -> rank of event
rank_map = {r["event"].event_id: r["rank"] for r in results}
filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
allowed_events = events[: search_filter.limit]
for e in allowed_events:
rm = room_groups.setdefault(
e.room_id, {"results": [], "order": rank_map[e.event_id]}
)
rm["results"].append(e.event_id)
s = sender_group.setdefault(
e.sender, {"results": [], "order": rank_map[e.event_id]}
)
s["results"].append(e.event_id)
return (
_SearchResult(
search_result["count"],
rank_map,
allowed_events,
room_groups,
highlights,
),
sender_group,
)
async def _search_by_recent(
self,
user: UserID,
room_ids: Collection[str],
search_term: str,
keys: Iterable[str],
search_filter: Filter,
batch_group: Optional[str],
batch_group_key: Optional[str],
batch_token: Optional[str],
) -> Tuple[_SearchResult, Optional[str]]:
"""
Performs a full text search for a user ordering by recent.
Args:
user: The user performing the search.
room_ids: List of room ids to search in
search_term: Search term to search for
keys: List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
search_filter: The event filter to use.
batch_group: Pagination information.
batch_group_key: Pagination information.
batch_token: Pagination information.
Returns:
A tuple of:
The search results.
Optionally, a pagination token.
"""
rank_map = {} # event_id -> rank of event
# Holds result of grouping by room, if applicable
room_groups: Dict[str, JsonDict] = {}
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
highlights = set()
room_events: List[EventBase] = []
i = 0
pagination_token = batch_token
# We keep looping and we keep filtering until we reach the limit
# or we run out of things.
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit and i < 5:
i += 1
search_result = await self.store.search_rooms(
room_ids,
search_term,
keys,
search_filter.limit * 2,
pagination_token=pagination_token,
)
if search_result["highlights"]:
highlights.update(search_result["highlights"])
count = search_result["count"]
results = search_result["results"]
results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)
room_events.extend(events)
room_events = room_events[: search_filter.limit]
if len(results) < search_filter.limit * 2:
break
else:
pagination_token = results[-1]["pagination_token"]
for event in room_events:
group = room_groups.setdefault(event.room_id, {"results": []})
group["results"].append(event.event_id)
if room_events and len(room_events) >= search_filter.limit:
last_event_id = room_events[-1].event_id
pagination_token = results_map[last_event_id]["pagination_token"]
# We want to respect the given batch group and group keys so
# that if people blindly use the top level `next_batch` token
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key:
global_next_batch = encode_base64(
(
"%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token)
).encode("ascii")
)
else:
global_next_batch = encode_base64(
("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
)
for room_id, group in room_groups.items():
group["next_batch"] = encode_base64(
("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
"ascii"
)
)
return (
_SearchResult(count, rank_map, room_events, room_groups, highlights),
global_next_batch,
)
async def _calculate_event_contexts(
self,
user: UserID,
allowed_events: List[EventBase],
before_limit: int,
after_limit: int,
include_profile: bool,
) -> Dict[str, JsonDict]:
"""
Calculates the contextual events for any search results.
Args:
user: The user performing the search.
allowed_events: The search results.
before_limit:
The number of events before a result to include as context.
after_limit:
The number of events after a result to include as context.
include_profile: True if historical profile information should be
included in the event context.
Returns:
A map of event ID to contextual information.
"""
now_token = self.hs.get_event_sources().get_current_token()
contexts = {}
for event in allowed_events:
res = await self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit
)
logger.info(
"Context for search returned %d and %d events",
len(res.events_before),
len(res.events_after),
)
events_before = await filter_events_for_client(
self.storage, user.to_string(), res.events_before
)
events_after = await filter_events_for_client(
self.storage, user.to_string(), res.events_after
)
context: JsonDict = {
"events_before": events_before,
"events_after": events_after,
"start": await now_token.copy_and_replace(
"room_key", res.start
).to_string(self.store),
"end": await now_token.copy_and_replace("room_key", res.end).to_string(
self.store
),
}
if include_profile:
senders = {
ev.sender
for ev in itertools.chain(events_before, [event], events_after)
}
if events_after:
last_event_id = events_after[-1].event_id
else:
last_event_id = event.event_id
state_filter = StateFilter.from_types(
[(EventTypes.Member, sender) for sender in senders]
)
state = await self.state_store.get_state_for_event(
last_event_id, state_filter
)
context["profile_info"] = {
s.state_key: {
"displayname": s.content.get("displayname", None),
"avatar_url": s.content.get("avatar_url", None),
}
for s in state.values()
if s.type == EventTypes.Member and s.state_key in senders
}
contexts[event.event_id] = context
return contexts

View File

@ -1289,23 +1289,54 @@ class SyncHandler:
# room with by looking at all users that have left a room plus users # room with by looking at all users that have left a room plus users
# that were in a room we've left. # that were in a room we've left.
users_who_share_room = await self.store.get_users_who_share_room_with_user( users_that_have_changed = set()
user_id
joined_rooms = sync_result_builder.joined_room_ids
# Step 1a, check for changes in devices of users we share a room
# with
#
# We do this in two different ways depending on what we have cached.
# If we already have a list of all the user that have changed since
# the last sync then it's likely more efficient to compare the rooms
# they're in with the rooms the syncing user is in.
#
# If we don't have that info cached then we get all the users that
# share a room with our user and check if those users have changed.
changed_users = self.store.get_cached_device_list_changes(
since_token.device_list_key
) )
if changed_users is not None:
result = await self.store.get_rooms_for_users_with_stream_ordering(
changed_users
)
# Always tell the user about their own devices. We check as the user for changed_user_id, entries in result.items():
# ID is almost certainly already included (unless they're not in any # Check if the changed user shares any rooms with the user,
# rooms) and taking a copy of the set is relatively expensive. # or if the changed user is the syncing user (as we always
if user_id not in users_who_share_room: # want to include device list updates of their own devices).
users_who_share_room = set(users_who_share_room) if user_id == changed_user_id or any(
users_who_share_room.add(user_id) e.room_id in joined_rooms for e in entries
):
users_that_have_changed.add(changed_user_id)
else:
users_who_share_room = (
await self.store.get_users_who_share_room_with_user(user_id)
)
tracked_users = users_who_share_room # Always tell the user about their own devices. We check as the user
# ID is almost certainly already included (unless they're not in any
# rooms) and taking a copy of the set is relatively expensive.
if user_id not in users_who_share_room:
users_who_share_room = set(users_who_share_room)
users_who_share_room.add(user_id)
# Step 1a, check for changes in devices of users we share a room with tracked_users = users_who_share_room
users_that_have_changed = await self.store.get_users_whose_devices_changed( users_that_have_changed = (
since_token.device_list_key, tracked_users await self.store.get_users_whose_devices_changed(
) since_token.device_list_key, tracked_users
)
)
# Step 1b, check for newly joined rooms # Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms: for room_id in newly_joined_rooms:
@ -1329,7 +1360,14 @@ class SyncHandler:
newly_left_users.update(left_users) newly_left_users.update(left_users)
# Remove any users that we still share a room with. # Remove any users that we still share a room with.
newly_left_users -= users_who_share_room left_users_rooms = (
await self.store.get_rooms_for_users_with_stream_ordering(
newly_left_users
)
)
for user_id, entries in left_users_rooms.items():
if any(e.room_id in joined_rooms for e in entries):
newly_left_users.discard(user_id)
return DeviceLists(changed=users_that_have_changed, left=newly_left_users) return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
else: else:

View File

@ -958,6 +958,7 @@ class MatrixFederationHttpClient:
) )
return body return body
@overload
async def get_json( async def get_json(
self, self,
destination: str, destination: str,
@ -967,7 +968,38 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None, timeout: Optional[int] = None,
ignore_backoff: bool = False, ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False, try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
max_response_size: Optional[int] = None,
) -> Union[JsonDict, list]: ) -> Union[JsonDict, list]:
...
@overload
async def get_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = ...,
retry_on_dns_fail: bool = ...,
timeout: Optional[int] = ...,
ignore_backoff: bool = ...,
try_trailing_slash_on_400: bool = ...,
parser: ByteParser[T] = ...,
max_response_size: Optional[int] = ...,
) -> T:
...
async def get_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
max_response_size: Optional[int] = None,
):
"""GETs some json from the given host homeserver and path """GETs some json from the given host homeserver and path
Args: Args:
@ -992,6 +1024,13 @@ class MatrixFederationHttpClient:
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3. the request. Workaround for #3622 in Synapse <= v0.99.3.
parser: The parser to use to decode the response. Defaults to
parsing as JSON.
max_response_size: The maximum size to read from the response. If None,
uses the default.
Returns: Returns:
Succeeds when we get a 2xx HTTP response. The Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body. result will be the decoded JSON body.
@ -1026,8 +1065,17 @@ class MatrixFederationHttpClient:
else: else:
_sec_timeout = self.default_timeout _sec_timeout = self.default_timeout
if parser is None:
parser = JsonParser()
body = await _handle_response( body = await _handle_response(
self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() self.reactor,
_sec_timeout,
request,
response,
start_ms,
parser=parser,
max_response_size=max_response_size,
) )
return body return body
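To illustrate the new keyword arguments, a hedged sketch of a call that streams the response through a custom parser instead of buffering it as JSON; this mirrors TransportLayerClient.get_room_state earlier in this diff, and the surrounding variable names are assumptions:

    # Sketch only: stream the federation /state response through a ByteParser.
    resp = await self.client.get_json(
        destination,
        path="/_matrix/federation/v1/state/%s" % (room_id,),
        args={"event_id": event_id},
        parser=_StateParser(room_version),  # returns a StateRequestResponse
        # max_response_size=None -> use the default response size limit
    )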

View File

@ -1,163 +0,0 @@
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
from typing import Any, Dict, Generator, Optional, Tuple
from constantly import NamedConstant, Names
from synapse.config._base import ConfigError
class DrainType(Names):
CONSOLE = NamedConstant()
CONSOLE_JSON = NamedConstant()
CONSOLE_JSON_TERSE = NamedConstant()
FILE = NamedConstant()
FILE_JSON = NamedConstant()
NETWORK_JSON_TERSE = NamedConstant()
DEFAULT_LOGGERS = {"synapse": {"level": "info"}}
def parse_drain_configs(
drains: dict,
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
"""
Parse the drain configurations.
Args:
drains (dict): A list of drain configurations.
Yields:
dict instances representing a logging handler.
Raises:
ConfigError: If any of the drain configuration items are invalid.
"""
for name, config in drains.items():
if "type" not in config:
raise ConfigError("Logging drains require a 'type' key.")
try:
logging_type = DrainType.lookupByName(config["type"].upper())
except ValueError:
raise ConfigError(
"%s is not a known logging drain type." % (config["type"],)
)
# Either use the default formatter or the tersejson one.
if logging_type in (
DrainType.CONSOLE_JSON,
DrainType.FILE_JSON,
):
formatter: Optional[str] = "json"
elif logging_type in (
DrainType.CONSOLE_JSON_TERSE,
DrainType.NETWORK_JSON_TERSE,
):
formatter = "tersejson"
else:
# A formatter of None implies using the default formatter.
formatter = None
if logging_type in [
DrainType.CONSOLE,
DrainType.CONSOLE_JSON,
DrainType.CONSOLE_JSON_TERSE,
]:
location = config.get("location")
if location is None or location not in ["stdout", "stderr"]:
raise ConfigError(
(
"The %s drain needs the 'location' key set to "
"either 'stdout' or 'stderr'."
)
% (logging_type,)
)
yield name, {
"class": "logging.StreamHandler",
"formatter": formatter,
"stream": "ext://sys." + location,
}
elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
if "location" not in config:
raise ConfigError(
"The %s drain needs the 'location' key set." % (logging_type,)
)
location = config.get("location")
if os.path.abspath(location) != location:
raise ConfigError(
"File paths need to be absolute, '%s' is a relative path"
% (location,)
)
yield name, {
"class": "logging.FileHandler",
"formatter": formatter,
"filename": location,
}
elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
host = config.get("host")
port = config.get("port")
maximum_buffer = config.get("maximum_buffer", 1000)
yield name, {
"class": "synapse.logging.RemoteHandler",
"formatter": formatter,
"host": host,
"port": port,
"maximum_buffer": maximum_buffer,
}
else:
raise ConfigError(
"The %s drain type is currently not implemented."
% (config["type"].upper(),)
)
def setup_structured_logging(
log_config: dict,
) -> dict:
"""
Convert a legacy structured logging configuration (from Synapse < v1.23.0)
to one compatible with the new standard library handlers.
"""
if "drains" not in log_config:
raise ConfigError("The logging configuration requires a list of drains.")
new_config = {
"version": 1,
"formatters": {
"json": {"class": "synapse.logging.JsonFormatter"},
"tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
},
"handlers": {},
"loggers": log_config.get("loggers", DEFAULT_LOGGERS),
"root": {"handlers": []},
}
for handler_name, handler in parse_drain_configs(log_config["drains"]):
new_config["handlers"][handler_name] = handler
# Add each handler to the root logger.
new_config["root"]["handlers"].append(handler_name)
return new_config

View File

@ -70,6 +70,7 @@ from synapse.handlers.account_validity import (
from synapse.handlers.auth import ( from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK, CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK, CHECK_AUTH_CALLBACK,
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
GET_USERNAME_FOR_REGISTRATION_CALLBACK, GET_USERNAME_FOR_REGISTRATION_CALLBACK,
IS_3PID_ALLOWED_CALLBACK, IS_3PID_ALLOWED_CALLBACK,
ON_LOGGED_OUT_CALLBACK, ON_LOGGED_OUT_CALLBACK,
@ -317,6 +318,9 @@ class ModuleApi:
get_username_for_registration: Optional[ get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None, ] = None,
get_displayname_for_registration: Optional[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = None,
) -> None: ) -> None:
"""Registers callbacks for password auth provider capabilities. """Registers callbacks for password auth provider capabilities.
@ -328,6 +332,7 @@ class ModuleApi:
is_3pid_allowed=is_3pid_allowed, is_3pid_allowed=is_3pid_allowed,
auth_checkers=auth_checkers, auth_checkers=auth_checkers,
get_username_for_registration=get_username_for_registration, get_username_for_registration=get_username_for_registration,
get_displayname_for_registration=get_displayname_for_registration,
) )
def register_background_update_controller_callbacks( def register_background_update_controller_callbacks(
@ -648,7 +653,11 @@ class ModuleApi:
Added in Synapse v1.9.0. Added in Synapse v1.9.0.
Args: Args:
auth_provider: identifier for the remote auth provider auth_provider: identifier for the remote auth provider, see `sso` and
`oidc_providers` in the homeserver configuration.
Note that no error is raised if the provided value is not in the
homeserver configuration.
external_id: id on that system external_id: id on that system
user_id: complete mxid that it is mapped to user_id: complete mxid that it is mapped to
""" """

View File

@ -138,7 +138,7 @@ class _NotifierUserStream:
self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
self.last_notified_token = self.current_token self.last_notified_token = self.current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred notify_deferred = self.notify_deferred
log_kv( log_kv(
{ {
@ -153,7 +153,7 @@ class _NotifierUserStream:
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token) notify_deferred.callback(self.current_token)
def remove(self, notifier: "Notifier") -> None: def remove(self, notifier: "Notifier") -> None:
"""Remove this listener from all the indexes in the Notifier """Remove this listener from all the indexes in the Notifier

View File

@ -130,7 +130,9 @@ def make_base_prepend_rules(
return rules return rules
BASE_APPEND_CONTENT_RULES = [ # We have to annotate these types, otherwise mypy infers them as
# `List[Dict[str, Sequence[Collection[str]]]]`.
BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
{ {
"rule_id": "global/content/.m.rule.contains_user_name", "rule_id": "global/content/.m.rule.contains_user_name",
"conditions": [ "conditions": [
@ -149,7 +151,7 @@ BASE_APPEND_CONTENT_RULES = [
] ]
BASE_PREPEND_OVERRIDE_RULES = [ BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
{ {
"rule_id": "global/override/.m.rule.master", "rule_id": "global/override/.m.rule.master",
"enabled": False, "enabled": False,
@ -159,7 +161,7 @@ BASE_PREPEND_OVERRIDE_RULES = [
] ]
BASE_APPEND_OVERRIDE_RULES = [ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
{ {
"rule_id": "global/override/.m.rule.suppress_notices", "rule_id": "global/override/.m.rule.suppress_notices",
"conditions": [ "conditions": [
@ -278,7 +280,7 @@ BASE_APPEND_OVERRIDE_RULES = [
] ]
BASE_APPEND_UNDERRIDE_RULES = [ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
{ {
"rule_id": "global/underride/.m.rule.call", "rule_id": "global/underride/.m.rule.call",
"conditions": [ "conditions": [

View File

@ -114,6 +114,7 @@ class HttpPusher(Pusher):
self.data_minus_url = {} self.data_minus_url = {}
self.data_minus_url.update(self.data) self.data_minus_url.update(self.data)
del self.data_minus_url["url"] del self.data_minus_url["url"]
self.badge_count_last_call: Optional[int] = None
def on_started(self, should_check_for_notifs: bool) -> None: def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started. """Called when this pusher has been started.
@ -141,7 +142,9 @@ class HttpPusher(Pusher):
self.user_id, self.user_id,
group_by_room=self._group_unread_count_by_room, group_by_room=self._group_unread_count_by_room,
) )
await self._send_badge(badge) if self.badge_count_last_call is None or self.badge_count_last_call != badge:
self.badge_count_last_call = badge
await self._send_badge(badge)
def on_timer(self) -> None: def on_timer(self) -> None:
self._start_processing() self._start_processing()
@ -327,7 +330,7 @@ class HttpPusher(Pusher):
# This was checked in the __init__, but mypy doesn't seem to know that. # This was checked in the __init__, but mypy doesn't seem to know that.
assert self.data is not None assert self.data is not None
if self.data.get("format") == "event_id_only": if self.data.get("format") == "event_id_only":
d = { d: Dict[str, Any] = {
"notification": { "notification": {
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
@ -407,6 +410,8 @@ class HttpPusher(Pusher):
rejected = [] rejected = []
if "rejected" in resp: if "rejected" in resp:
rejected = resp["rejected"] rejected = resp["rejected"]
else:
self.badge_count_last_call = badge
return rejected return rejected
async def _send_badge(self, badge: int) -> None: async def _send_badge(self, badge: int) -> None:

View File

@ -87,7 +87,8 @@ REQUIREMENTS = [
# We enforce that we have a `cryptography` version that bundles an `openssl` # We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches. # with the latest security patches.
"cryptography>=3.4.7", "cryptography>=3.4.7",
"ijson>=3.1", # ijson 3.1.4 fixes a bug with "." in property names
"ijson>=3.1.4",
"matrix-common~=1.1.0", "matrix-common~=1.1.0",
] ]

View File

@ -5,8 +5,6 @@
"endpoints": [ "endpoints": [
{ {
"schemes": [ "schemes": [
"https://twitter.com/*/status/*",
"https://*.twitter.com/*/status/*",
"https://twitter.com/*/moments/*", "https://twitter.com/*/moments/*",
"https://*.twitter.com/*/moments/*" "https://*.twitter.com/*/moments/*"
], ],
@ -14,4 +12,4 @@
} }
] ]
} }
] ]

View File

@ -886,7 +886,9 @@ class WhoamiRestServlet(RestServlet):
response = { response = {
"user_id": requester.user.to_string(), "user_id": requester.user.to_string(),
# MSC: https://github.com/matrix-org/matrix-doc/pull/3069 # MSC: https://github.com/matrix-org/matrix-doc/pull/3069
# Entered spec in Matrix 1.2
"org.matrix.msc3069.is_guest": bool(requester.is_guest), "org.matrix.msc3069.is_guest": bool(requester.is_guest),
"is_guest": bool(requester.is_guest),
} }
# Appservices and similar accounts do not have device IDs # Appservices and similar accounts do not have device IDs

View File

@ -62,7 +62,7 @@ class AuthRestServlet(RestServlet):
if stagetype == LoginType.RECAPTCHA: if stagetype == LoginType.RECAPTCHA:
html = self.recaptcha_template.render( html = self.recaptcha_template.render(
session=session, session=session,
myurl="%s/r0/auth/%s/fallback/web" myurl="%s/v3/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA), % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
sitekey=self.hs.config.captcha.recaptcha_public_key, sitekey=self.hs.config.captcha.recaptcha_public_key,
) )
@ -74,7 +74,7 @@ class AuthRestServlet(RestServlet):
self.hs.config.server.public_baseurl, self.hs.config.server.public_baseurl,
self.hs.config.consent.user_consent_version, self.hs.config.consent.user_consent_version,
), ),
myurl="%s/r0/auth/%s/fallback/web" myurl="%s/v3/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS), % (CLIENT_API_PREFIX, LoginType.TERMS),
) )
@ -118,7 +118,7 @@ class AuthRestServlet(RestServlet):
# Authentication failed, let user try again # Authentication failed, let user try again
html = self.recaptcha_template.render( html = self.recaptcha_template.render(
session=session, session=session,
myurl="%s/r0/auth/%s/fallback/web" myurl="%s/v3/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA), % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
sitekey=self.hs.config.captcha.recaptcha_public_key, sitekey=self.hs.config.captcha.recaptcha_public_key,
error=e.msg, error=e.msg,
@ -143,7 +143,7 @@ class AuthRestServlet(RestServlet):
self.hs.config.server.public_baseurl, self.hs.config.server.public_baseurl,
self.hs.config.consent.user_consent_version, self.hs.config.consent.user_consent_version,
), ),
myurl="%s/r0/auth/%s/fallback/web" myurl="%s/v3/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS), % (CLIENT_API_PREFIX, LoginType.TERMS),
error=e.msg, error=e.msg,
) )

View File

@ -72,20 +72,6 @@ class CapabilitiesRestServlet(RestServlet):
"org.matrix.msc3244.room_capabilities" "org.matrix.msc3244.room_capabilities"
] = MSC3244_CAPABILITIES ] = MSC3244_CAPABILITIES
# Must be removed in later versions.
# Is only included for migration.
# Also the parts in `synapse/config/experimental.py`.
if self.config.experimental.msc3283_enabled:
response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
"enabled": self.config.registration.enable_set_displayname
}
response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
"enabled": self.config.registration.enable_set_avatar_url
}
response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
"enabled": self.config.registration.enable_3pid_changes
}
if self.config.experimental.msc3440_enabled: if self.config.experimental.msc3440_enabled:
response["capabilities"]["io.element.thread"] = {"enabled": True} response["capabilities"]["io.element.thread"] = {"enabled": True}

View File

@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet):
session_id session_id
) )
display_name = await (
self.password_auth_provider.get_displayname_for_registration(
auth_result, params
)
)
registered_user_id = await self.registration_handler.register_user( registered_user_id = await self.registration_handler.register_user(
localpart=desired_username, localpart=desired_username,
password_hash=password_hash, password_hash=password_hash,
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
threepid=threepid, threepid=threepid,
default_display_name=display_name,
address=client_addr, address=client_addr,
user_agent_ips=entries, user_agent_ips=entries,
) )
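The new call above asks a password auth provider for a display name to use at registration time. A minimal, hypothetical sketch of the callback-chain idea (the callback type, names and first-non-None behaviour here are assumptions for illustration, not the actual Synapse module API):

from typing import Awaitable, Callable, List, Optional

# Hypothetical callback: given the UIA auth results and the request
# parameters, a module may propose a display name for the new account
# (or return None to pass).
GetDisplaynameCallback = Callable[[dict, dict], Awaitable[Optional[str]]]

async def displayname_for_registration(
    callbacks: List[GetDisplaynameCallback], auth_result: dict, params: dict
) -> Optional[str]:
    # The first callback returning a non-None value wins; None means the
    # default display name (usually the localpart) is used instead.
    for callback in callbacks:
        name = await callback(auth_result, params)
        if name is not None:
            return name
    return None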

View File

@ -73,6 +73,8 @@ class VersionsRestServlet(RestServlet):
"r0.5.0", "r0.5.0",
"r0.6.0", "r0.6.0",
"r0.6.1", "r0.6.1",
"v1.1",
"v1.2",
], ],
# as per MSC1497: # as per MSC1497:
"unstable_features": { "unstable_features": {

View File

@ -402,7 +402,15 @@ class PreviewUrlResource(DirectServeJsonResource):
url, url,
output_stream=output_stream, output_stream=output_stream,
max_size=self.max_spider_size, max_size=self.max_spider_size,
headers={"Accept-Language": self.url_preview_accept_language}, headers={
b"Accept-Language": self.url_preview_accept_language,
# Use a custom user agent for the preview because some sites will only return
# Open Graph metadata to crawler user agents. Omit the Synapse version
# string to avoid leaking information.
b"User-Agent": [
"Synapse (bot; +https://github.com/matrix-org/synapse)"
],
},
is_allowed_content_type=_is_previewable, is_allowed_content_type=_is_previewable,
) )
except SynapseError: except SynapseError:
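A minimal standard-library sketch of the fetch behaviour described in the comments above; the URL, timeout and size cap are illustrative and not Synapse's actual spidering code:

from urllib.request import Request, urlopen

req = Request(
    "https://example.com/some/page",
    headers={
        # Bot-style user agent with no version string, so sites that only
        # serve Open Graph metadata to crawlers still return it, without
        # leaking which Synapse version is asking.
        "User-Agent": "Synapse (bot; +https://github.com/matrix-org/synapse)",
        "Accept-Language": "en-US,en;q=0.9",
    },
)
with urlopen(req, timeout=10) as resp:
    html = resp.read(1_000_000)  # cap the read, analogous to max_spider_size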

View File

@ -670,6 +670,16 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices device["device_id"]: db_to_json(device["content"]) for device in devices
} }
def get_cached_device_list_changes(
self,
from_key: int,
) -> Optional[Set[str]]:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
return self._device_list_stream_cache.get_all_entities_changed(from_key)
async def get_users_whose_devices_changed( async def get_users_whose_devices_changed(
self, from_key: int, user_ids: Iterable[str] self, from_key: int, user_ids: Iterable[str]
) -> Set[str]: ) -> Set[str]:
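A toy stand-in for the stream-change cache backing get_cached_device_list_changes; this is not Synapse's StreamChangeCache, and the earliest-known-key fallback is an assumption made for illustration:

from typing import Dict, Optional, Set

class TinyStreamChangeCache:
    def __init__(self, earliest_known_key: int) -> None:
        self._earliest_key = earliest_known_key
        self._last_change: Dict[str, int] = {}  # user_id -> last stream position

    def entity_has_changed(self, entity: str, key: int) -> None:
        self._last_change[entity] = max(key, self._last_change.get(entity, 0))

    def get_all_entities_changed(self, from_key: int) -> Optional[Set[str]]:
        # If asked about a window older than anything we have cached we
        # cannot answer, and the caller must fall back to the database.
        if from_key < self._earliest_key:
            return None
        return {e for e, k in self._last_change.items() if k > from_key}

cache = TinyStreamChangeCache(earliest_known_key=100)
cache.entity_has_changed("@alice:example.org", 105)
assert cache.get_all_entities_changed(100) == {"@alice:example.org"}
assert cache.get_all_entities_changed(50) is None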

View File

@ -975,6 +975,17 @@ class PersistEventsStore:
to_delete = delta_state.to_delete to_delete = delta_state.to_delete
to_insert = delta_state.to_insert to_insert = delta_state.to_insert
# Figure out the changes of membership to invalidate the
# `get_rooms_for_user` cache.
# We find out which membership events we may have deleted
# and which we have added, then we invalidate the caches for all
# those users.
members_changed = {
state_key
for ev_type, state_key in itertools.chain(to_delete, to_insert)
if ev_type == EventTypes.Member
}
if delta_state.no_longer_in_room: if delta_state.no_longer_in_room:
# Server is no longer in the room so we delete the room from # Server is no longer in the room so we delete the room from
# current_state_events, being careful we've already updated the # current_state_events, being careful we've already updated the
@ -993,6 +1004,11 @@ class PersistEventsStore:
""" """
txn.execute(sql, (stream_id, self._instance_name, room_id)) txn.execute(sql, (stream_id, self._instance_name, room_id))
# We also want to invalidate the membership caches for users
# that were in the room.
users_in_room = self.store.get_users_in_room_txn(txn, room_id)
members_changed.update(users_in_room)
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="current_state_events", table="current_state_events",
@ -1102,17 +1118,6 @@ class PersistEventsStore:
# Invalidate the various caches # Invalidate the various caches
# Figure out the changes of membership to invalidate the
# `get_rooms_for_user` cache.
# We find out which membership events we may have deleted
# and which we have added, then we invalidate the caches for all
# those users.
members_changed = {
state_key
for ev_type, state_key in itertools.chain(to_delete, to_insert)
if ev_type == EventTypes.Member
}
for member in members_changed: for member in members_changed:
txn.call_after( txn.call_after(
self.store.get_rooms_for_user_with_stream_ordering.invalidate, self.store.get_rooms_for_user_with_stream_ordering.invalidate,
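The membership-delta comments above reduce to a small, self-contained computation; the state keys below are made up for the example:

import itertools

# Hypothetical delta of current state, keyed by (event type, state key).
to_delete = {("m.room.member", "@alice:example.org"), ("m.room.topic", "")}
to_insert = {("m.room.member", "@bob:example.org")}

# Any membership event that was added or removed means that user's
# rooms-for-user caches may be stale and must be invalidated.
members_changed = {
    state_key
    for ev_type, state_key in itertools.chain(to_delete, to_insert)
    if ev_type == "m.room.member"
}
assert members_changed == {"@alice:example.org", "@bob:example.org"}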

View File

@ -408,7 +408,7 @@ class EventsWorkerStore(SQLBaseStore):
include the previous states content in the unsigned field. include the previous states content in the unsigned field.
allow_rejected: If True, return rejected events. Otherwise, allow_rejected: If True, return rejected events. Otherwise,
omits rejeted events from the response. omits rejected events from the response.
Returns: Returns:
A mapping from event_id to event. A mapping from event_id to event.
@ -1854,7 +1854,7 @@ class EventsWorkerStore(SQLBaseStore):
forward_edge_query = """ forward_edge_query = """
SELECT 1 FROM event_edges SELECT 1 FROM event_edges
/* Check to make sure the event referencing our event in question is not rejected */ /* Check to make sure the event referencing our event in question is not rejected */
LEFT JOIN rejections ON event_edges.event_id == rejections.event_id LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
WHERE WHERE
event_edges.room_id = ? event_edges.room_id = ?
AND event_edges.prev_event_id = ? AND event_edges.prev_event_id = ?

View File

@ -12,15 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
from synapse.api.presence import PresenceState, UserPresenceState from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
database: DatabasePool, database: DatabasePool,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
hs: "HomeServer", hs: "HomeServer",
): ) -> None:
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# Used by `PresenceStore._get_active_presence()` # Used by `PresenceStore._get_active_presence()`
@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
database: DatabasePool, database: DatabasePool,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
hs: "HomeServer", hs: "HomeServer",
): ) -> None:
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
self._presence_id_gen: AbstractStreamIdGenerator
self._can_persist_presence = ( self._can_persist_presence = (
hs.get_instance_name() in hs.config.worker.writers.presence self._instance_name in hs.config.worker.writers.presence
) )
if isinstance(database.engine, PostgresEngine): if isinstance(database.engine, PostgresEngine):
@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return stream_orderings[-1], self._presence_id_gen.get_current_token() return stream_orderings[-1], self._presence_id_gen.get_current_token()
def _update_presence_txn(self, txn, stream_orderings, presence_states): def _update_presence_txn(
self, txn: LoggingTransaction, stream_orderings, presence_states
) -> None:
for stream_id, state in zip(stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after( txn.call_after(
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
if last_id == current_id: if last_id == current_id:
return [], current_id, False return [], current_id, False
def get_all_presence_updates_txn(txn): def get_all_presence_updates_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """ sql = """
SELECT stream_id, user_id, state, last_active_ts, SELECT stream_id, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, last_federation_update_ts, last_user_sync_ts,
status_msg, status_msg, currently_active
currently_active
FROM presence_stream FROM presence_stream
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC ORDER BY stream_id ASC
LIMIT ? LIMIT ?
""" """
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], row[1:]) for row in txn] updates = cast(
List[Tuple[int, list]],
[(row[0], row[1:]) for row in txn],
)
upper_bound = current_id upper_bound = current_id
limited = False limited = False
@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
) )
@cached() @cached()
def _get_presence_for_user(self, user_id): def _get_presence_for_user(self, user_id: str) -> None:
raise NotImplementedError() raise NotImplementedError()
@cachedList( @cachedList(
@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
list_name="user_ids", list_name="user_ids",
num_args=1, num_args=1,
) )
async def get_presence_for_users(self, user_ids): async def get_presence_for_users(
self, user_ids: Iterable[str]
) -> Dict[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="presence_stream", table="presence_stream",
column="user_id", column="user_id",
@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
True if the user should have full presence sent to them, False otherwise. True if the user should have full presence sent to them, False otherwise.
""" """
def _should_user_receive_full_presence_with_token_txn(txn): def _should_user_receive_full_presence_with_token_txn(
txn: LoggingTransaction,
) -> bool:
sql = """ sql = """
SELECT 1 FROM users_to_send_full_presence_to SELECT 1 FROM users_to_send_full_presence_to
WHERE user_id = ? WHERE user_id = ?
@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
_should_user_receive_full_presence_with_token_txn, _should_user_receive_full_presence_with_token_txn,
) )
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
"""Adds to the list of users who should receive a full snapshot of presence """Adds to the list of users who should receive a full snapshot of presence
upon their next sync. upon their next sync.
@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return users_to_state return users_to_state
def get_current_presence_token(self): def get_current_presence_token(self) -> int:
return self._presence_id_gen.get_current_token() return self._presence_id_gen.get_current_token()
def _get_active_presence(self, db_conn: Connection): def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
"""Fetch non-offline presence from the database so that we can register """Fetch non-offline presence from the database so that we can register
the appropriate time outs. the appropriate time outs.
""" """
@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return [UserPresenceState(**row) for row in rows] return [UserPresenceState(**row) for row in rows]
def take_presence_startup_info(self): def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup active_on_startup = self._presence_on_startup
self._presence_on_startup = None self._presence_on_startup = []
return active_on_startup return active_on_startup
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
if stream_name == PresenceStream.NAME: if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token) self._presence_id_gen.advance(instance_name, token)
for row in rows: for row in rows:

View File

@ -13,9 +13,10 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, List, Set, Tuple from typing import Any, List, Set, Tuple, cast
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
) )
def _purge_history_txn( def _purge_history_txn(
self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool self,
txn: LoggingTransaction,
room_id: str,
token: RoomStreamToken,
delete_local_events: bool,
) -> Set[int]: ) -> Set[int]:
# Tables that should be pruned: # Tables that should be pruned:
# event_auth # event_auth
@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
""", """,
(room_id,), (room_id,),
) )
(min_depth,) = txn.fetchone() (min_depth,) = cast(Tuple[int], txn.fetchone())
logger.info("[purge] updating room_depth to %d", min_depth) logger.info("[purge] updating room_depth to %d", min_depth)
@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"purge_room", self._purge_room_txn, room_id "purge_room", self._purge_room_txn, room_id
) )
def _purge_room_txn(self, txn, room_id: str) -> List[int]: def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
# First we fetch all the state groups that should be deleted, before # First we fetch all the state groups that should be deleted, before
# we delete that information. # we delete that information.
txn.execute( txn.execute(

View File

@ -622,10 +622,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) -> None: ) -> None:
"""Record a mapping from an external user id to a mxid """Record a mapping from an external user id to a mxid
See notes in _record_user_external_id_txn about what constitutes valid data.
Args: Args:
auth_provider: identifier for the remote auth provider auth_provider: identifier for the remote auth provider
external_id: id on that system external_id: id on that system
user_id: complete mxid that it is mapped to user_id: complete mxid that it is mapped to
Raises: Raises:
ExternalIDReuseException if the new external_id could not be mapped. ExternalIDReuseException if the new external_id could not be mapped.
""" """
@ -648,6 +651,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
external_id: str, external_id: str,
user_id: str, user_id: str,
) -> None: ) -> None:
"""
Record a mapping from an external user id to a mxid.
Note that the auth provider IDs (and the external IDs) are not validated
against configured IdPs as Synapse does not know its relationship to
external systems. For example, it might be useful to pre-configure users
before enabling a new IdP or an IdP might be temporarily offline, but
still valid.
Args:
txn: The database transaction.
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
self.db_pool.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
@ -687,10 +705,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Replace mappings from external user ids to a mxid in a single transaction. """Replace mappings from external user ids to a mxid in a single transaction.
All mappings are deleted and the new ones are created. All mappings are deleted and the new ones are created.
See notes in _record_user_external_id_txn about what constitutes valid data.
Args: Args:
record_external_ids: record_external_ids:
List with tuple of auth_provider and external_id to record List with tuple of auth_provider and external_id to record
user_id: complete mxid that it is mapped to user_id: complete mxid that it is mapped to
Raises: Raises:
ExternalIDReuseException if the new external_id could not be mapped. ExternalIDReuseException if the new external_id could not be mapped.
""" """
@ -1660,7 +1681,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_id=row[1], user_id=row[1],
device_id=row[2], device_id=row[2],
next_token_id=row[3], next_token_id=row[3],
has_next_refresh_token_been_refreshed=row[4], # SQLite returns 0 or 1 for false/true, so convert to a bool.
has_next_refresh_token_been_refreshed=bool(row[4]),
# This column is nullable, ensure it's a boolean # This column is nullable, ensure it's a boolean
has_next_access_token_been_used=(row[5] or False), has_next_access_token_been_used=(row[5] or False),
expiry_ts=row[6], expiry_ts=row[6],
@ -1676,12 +1698,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Set the successor of a refresh token, removing the existing successor Set the successor of a refresh token, removing the existing successor
if any. if any.
This also deletes the predecessor refresh and access tokens,
since they cannot be valid anymore.
Args: Args:
token_id: ID of the refresh token to update. token_id: ID of the refresh token to update.
next_token_id: ID of its successor. next_token_id: ID of its successor.
""" """
def _replace_refresh_token_txn(txn) -> None: def _replace_refresh_token_txn(txn: LoggingTransaction) -> None:
# First check if there was an existing refresh token # First check if there was an existing refresh token
old_next_token_id = self.db_pool.simple_select_one_onecol_txn( old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
@ -1707,6 +1732,16 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"id": old_next_token_id}, {"id": old_next_token_id},
) )
# Delete the previous refresh token, since we only want to keep the
# last 2 refresh tokens in the database.
# (The predecessor of the latest refresh token is still useful in
# case the refresh was interrupted and the client re-uses the old
# one.)
# This cascades to delete the associated access token.
self.db_pool.simple_delete_txn(
txn, "refresh_tokens", {"next_token_id": token_id}
)
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"replace_refresh_token", _replace_refresh_token_txn "replace_refresh_token", _replace_refresh_token_txn
) )
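A toy in-memory model of the rotation rule in the comments above; the real logic runs as SQL against the refresh_tokens table, so this dict-based sketch only shows the keep-the-last-two-tokens behaviour:

from typing import Dict, Optional

# id -> next_token_id; a single starting token with no successor.
refresh_tokens: Dict[int, Optional[int]] = {1: None}

def replace_refresh_token(token_id: int, next_token_id: int) -> None:
    # Point the refreshed token at its successor...
    refresh_tokens[token_id] = next_token_id
    refresh_tokens[next_token_id] = None
    # ...and prune the token's own predecessor, so only the newest token and
    # its immediate parent survive (the parent covers a refresh that was
    # interrupted before the client saw the new token).
    for old_id, successor in list(refresh_tokens.items()):
        if successor == token_id:
            del refresh_tokens[old_id]

replace_refresh_token(1, 2)  # tokens kept: {1, 2}
replace_refresh_token(2, 3)  # token 1 is pruned; tokens kept: {2, 3}
assert set(refresh_tokens) == {2, 3}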

View File

@ -53,8 +53,13 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation: class _ThreadAggregation:
# The latest event in the thread.
latest_event: EventBase latest_event: EventBase
# The latest edit to the latest event in the thread.
latest_edit: Optional[EventBase]
# The total number of events in the thread.
count: int count: int
# True if the current user has sent an event to the thread.
current_user_participated: bool current_user_participated: bool
@ -461,8 +466,8 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids") @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries( async def _get_thread_summaries(
self, event_ids: Collection[str] self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase]]]: ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
"""Get the number of threaded replies and the latest reply (if any) for the given event. """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
Args: Args:
event_ids: Summarize the thread related to this event ID. event_ids: Summarize the thread related to this event ID.
@ -471,8 +476,10 @@ class RelationsWorkerStore(SQLBaseStore):
A map of the thread summary for each event. A missing event implies there A map of the thread summary for each event. A missing event implies there
are no threaded replies. are no threaded replies.
Each summary includes the number of items in the thread and the most Each summary is a tuple of:
recent response. The number of events in the thread.
The most recent event in the thread.
The most recent edit to the most recent event in the thread, if applicable.
""" """
def _get_thread_summaries_txn( def _get_thread_summaries_txn(
@ -482,7 +489,7 @@ class RelationsWorkerStore(SQLBaseStore):
# TODO Should this only allow m.room.message events. # TODO Should this only allow m.room.message events.
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
# The `DISTINCT ON` clause will pick the *first* row it encounters, # The `DISTINCT ON` clause will pick the *first* row it encounters,
# so ordering by topologica ordering + stream ordering desc will # so ordering by topological ordering + stream ordering desc will
# ensure we get the latest event in the thread. # ensure we get the latest event in the thread.
sql = """ sql = """
SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
@ -558,6 +565,9 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
# Check to see if any of those events are edited.
latest_edits = await self._get_applicable_edits(latest_event_ids.values())
# Map to the event IDs to the thread summary. # Map to the event IDs to the thread summary.
# #
# There might not be a summary due to there not being a thread or # There might not be a summary due to there not being a thread or
@ -568,7 +578,8 @@ class RelationsWorkerStore(SQLBaseStore):
summary = None summary = None
if latest_event: if latest_event:
summary = (counts[parent_event_id], latest_event) latest_edit = latest_edits.get(latest_event_id)
summary = (counts[parent_event_id], latest_event, latest_edit)
summaries[parent_event_id] = summary summaries[parent_event_id] = summary
return summaries return summaries
@ -828,11 +839,12 @@ class RelationsWorkerStore(SQLBaseStore):
) )
for event_id, summary in summaries.items(): for event_id, summary in summaries.items():
if summary: if summary:
thread_count, latest_thread_event = summary thread_count, latest_thread_event, edit = summary
results.setdefault( results.setdefault(
event_id, BundledAggregations() event_id, BundledAggregations()
).thread = _ThreadAggregation( ).thread = _ThreadAggregation(
latest_event=latest_thread_event, latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count, count=thread_count,
# If there's a thread summary it must also exist in the # If there's a thread summary it must also exist in the
# participated dictionary. # participated dictionary.
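Assembling the (count, latest event, latest edit) tuples described above, using made-up stand-in dictionaries instead of the database lookups:

from typing import Dict, Optional, Tuple

counts = {"$thread_root": 3}
latest_event_ids = {"$thread_root": "$latest_reply"}
latest_events = {"$latest_reply": {"body": "original reply"}}
latest_edits = {"$latest_reply": {"body": "edited reply"}}

# Each summary is (reply count, latest reply, latest edit of that reply or
# None), matching the tuple shape introduced in the hunk above.
summaries: Dict[str, Optional[Tuple[int, dict, Optional[dict]]]] = {}
for parent_id, latest_id in latest_event_ids.items():
    latest_event = latest_events.get(latest_id)
    if latest_event:
        summaries[parent_id] = (
            counts[parent_id],
            latest_event,
            latest_edits.get(latest_id),
        )
assert summaries["$thread_root"][0] == 3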

View File

@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
async def upsert_room_on_join( async def upsert_room_on_join(
self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] self, room_id: str, room_version: RoomVersion, state_events: List[EventBase]
) -> None: ) -> None:
"""Ensure that the room is stored in the table """Ensure that the room is stored in the table
@ -1511,7 +1511,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
has_auth_chain_index = await self.has_auth_chain_index(room_id) has_auth_chain_index = await self.has_auth_chain_index(room_id)
create_event = None create_event = None
for e in auth_events: for e in state_events:
if (e.type, e.state_key) == (EventTypes.Create, ""): if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e create_event = e
break break

View File

@ -504,6 +504,68 @@ class RoomMemberWorkerStore(EventsWorkerStore):
for room_id, instance, stream_id in txn for room_id, instance, stream_id in txn
) )
@cachedList(
cached_method_name="get_rooms_for_user_with_stream_ordering",
list_name="user_ids",
)
async def get_rooms_for_users_with_stream_ordering(
self, user_ids: Collection[str]
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
"""A batched version of `get_rooms_for_user_with_stream_ordering`.
Returns:
Map from user_id to the set of rooms that user is currently in.
"""
return await self.db_pool.runInteraction(
"get_rooms_for_users_with_stream_ordering",
self._get_rooms_for_users_with_stream_ordering_txn,
user_ids,
)
def _get_rooms_for_users_with_stream_ordering_txn(
self, txn, user_ids: Collection[str]
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
clause, args = make_in_list_sql_clause(
self.database_engine,
"c.state_key",
user_ids,
)
if self._current_state_events_membership_up_to_date:
sql = f"""
SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
FROM current_state_events AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND c.membership = ?
AND {clause}
"""
else:
sql = f"""
SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (room_id, event_id)
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND m.membership = ?
AND {clause}
"""
txn.execute(sql, [Membership.JOIN] + args)
result = {user_id: set() for user_id in user_ids}
for user_id, room_id, instance, stream_id in txn:
result[user_id].add(
GetRoomsForUserWithStreamOrdering(
room_id, PersistedEventPosition(instance, stream_id)
)
)
return {user_id: frozenset(v) for user_id, v in result.items()}
async def get_users_server_still_shares_room_with( async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str] self, user_ids: Collection[str]
) -> Set[str]: ) -> Set[str]:
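A sketch of the batching pattern above against a hypothetical SQLite table; the schema here is a simplified assumption (Synapse's real query joins the events table for the stream ordering), shown only to illustrate one IN-clause query replacing one query per user:

import sqlite3
from collections import defaultdict
from typing import Dict, FrozenSet, Iterable, Tuple

def rooms_for_users(
    conn: sqlite3.Connection, user_ids: Iterable[str]
) -> Dict[str, FrozenSet[Tuple[str, int]]]:
    user_ids = list(user_ids)
    if not user_ids:
        return {}
    placeholders = ", ".join("?" for _ in user_ids)
    sql = f"""
        SELECT state_key, room_id, stream_ordering
        FROM current_state_events
        WHERE type = 'm.room.member'
          AND membership = 'join'
          AND state_key IN ({placeholders})
    """
    grouped: Dict[str, set] = defaultdict(set)
    for user_id, room_id, stream_ordering in conn.execute(sql, user_ids):
        grouped[user_id].add((room_id, stream_ordering))
    # Every requested user gets an entry, even if they are joined to nothing.
    return {u: frozenset(grouped[u]) for u in user_ids}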

View File

@ -28,6 +28,7 @@ from synapse.storage.database import (
) )
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -381,17 +382,19 @@ class SearchStore(SearchBackgroundUpdateStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys): async def search_msgs(
self, room_ids: Collection[str], search_term: str, keys: Iterable[str]
) -> JsonDict:
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.
Args: Args:
room_ids (list): List of room ids to search in room_ids: List of room ids to search in
search_term (str): Search term to search for search_term: Search term to search for
keys (list): List of keys to search in, currently supports keys: List of keys to search in, currently supports
"content.body", "content.name", "content.topic" "content.body", "content.name", "content.topic"
Returns: Returns:
list of dicts Dictionary of results
""" """
clauses = [] clauses = []
@ -499,10 +502,10 @@ class SearchStore(SearchBackgroundUpdateStore):
self, self,
room_ids: Collection[str], room_ids: Collection[str],
search_term: str, search_term: str,
keys: List[str], keys: Iterable[str],
limit, limit,
pagination_token: Optional[str] = None, pagination_token: Optional[str] = None,
) -> List[dict]: ) -> JsonDict:
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.
Args: Args:

View File

@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
database: DatabasePool, database: DatabasePool,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
hs: "HomeServer", hs: "HomeServer",
): ) -> None:
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name = hs.hostname
@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
processed_event_count = 0 processed_event_count = 0
for room_id, event_count in rooms_to_work_on: for room_id, event_count in rooms_to_work_on:
is_in_room = await self.is_host_joined(room_id, self.server_name) is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined]
if is_in_room: if is_in_room:
users_with_profile = await self.get_users_in_room_with_profiles(room_id) users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined]
# Throw away users excluded from the directory. # Throw away users excluded from the directory.
users_with_profile = { users_with_profile = {
user_id: profile user_id: profile
@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
for user_id in users_to_work_on: for user_id in users_to_work_on:
if await self.should_include_local_user_in_dir(user_id): if await self.should_include_local_user_in_dir(user_id):
profile = await self.get_profileinfo(get_localpart_from_id(user_id)) profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined]
await self.update_profile_in_user_dir( await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url user_id, profile.display_name, profile.avatar_url
) )
@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# technically it could be DM-able. In the future, this could potentially # technically it could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice sender can be # be configurable per-appservice whether the appservice sender can be
# contacted. # contacted.
if self.get_app_service_by_user_id(user) is not None: if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined]
return False return False
# We're opting to exclude appservice users (anyone matching the user # We're opting to exclude appservice users (anyone matching the user
@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# they could be DM-able. In the future, this could potentially # they could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice users can be # be configurable per-appservice whether the appservice users can be
# contacted. # contacted.
if self.get_if_app_services_interested_in_user(user): if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined]
# TODO we might want to make this configurable for each app service # TODO we might want to make this configurable for each app service
return False return False
# Support users are for diagnostics and should not appear in the user directory. # Support users are for diagnostics and should not appear in the user directory.
if await self.is_support_user(user): if await self.is_support_user(user): # type: ignore[attr-defined]
return False return False
# Deactivated users aren't contactable, so should not appear in the user directory. # Deactivated users aren't contactable, so should not appear in the user directory.
try: try:
if await self.get_user_deactivated_status(user): if await self.get_user_deactivated_status(user): # type: ignore[attr-defined]
return False return False
except StoreError: except StoreError:
# No such user in the users table. No need to do this when calling # No such user in the users table. No need to do this when calling
@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
) )
current_state_ids = await self.get_filtered_current_state_ids( current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
room_id, StateFilter.from_types(types_to_filter) room_id, StateFilter.from_types(types_to_filter)
) )
join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id: if join_rules_id:
join_rule_ev = await self.get_event(join_rules_id, allow_none=True) join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined]
if join_rule_ev: if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
return True return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id: if hist_vis_id:
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined]
if hist_vis_ev: if hist_vis_ev:
if ( if (
hist_vis_ev.content.get("history_visibility") hist_vis_ev.content.get("history_visibility")

View File

@ -13,11 +13,23 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
Optional,
Sequence,
Set,
Tuple,
)
import attr import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
@ -29,6 +41,12 @@ from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import (
AbstractObservableDeferred,
ObservableDeferred,
yieldable_gather_results,
)
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
@ -37,8 +55,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100 MAX_STATE_DELTA_HOPS = 100
MAX_INFLIGHT_REQUESTS_PER_GROUP = 5
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -106,6 +124,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
500000, 500000,
) )
# Current ongoing get_state_for_groups in-flight requests
# {group ID -> {StateFilter -> ObservableDeferred}}
self._state_group_inflight_requests: Dict[
int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
] = {}
def get_max_state_group_txn(txn: Cursor) -> int: def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
return txn.fetchone()[0] # type: ignore return txn.fetchone()[0] # type: ignore
@ -157,7 +181,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
) )
async def _get_state_groups_from_groups( async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter self, groups: Sequence[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]: ) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the """Returns the state groups for a given set of groups from the
database, filtering on types of state events. database, filtering on types of state events.
@ -228,6 +252,165 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types return state_filter.filter_state(state_dict_ids), not missing_types
def _get_state_for_group_gather_inflight_requests(
self, group: int, state_filter_left_over: StateFilter
) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
"""
Attempts to gather in-flight requests and re-use them to retrieve state
for the given state group, filtered with the given state filter.
If there are more than MAX_INFLIGHT_REQUESTS_PER_GROUP in-flight requests,
and there *still* isn't enough information to complete the request by solely
reusing others, a full state filter will be requested to ensure that subsequent
requests can reuse this request.
Used as part of _get_state_for_group_using_inflight_cache.
Returns:
Tuple of two values:
A sequence of ObservableDeferreds to observe
A StateFilter representing what else needs to be requested to fulfill the request
"""
inflight_requests = self._state_group_inflight_requests.get(group)
if inflight_requests is None:
# no requests for this group, need to retrieve it all ourselves
return (), state_filter_left_over
# The list of ongoing requests which will help narrow the current request.
reusable_requests = []
for (request_state_filter, request_deferred) in inflight_requests.items():
new_state_filter_left_over = state_filter_left_over.approx_difference(
request_state_filter
)
if new_state_filter_left_over == state_filter_left_over:
# Reusing this request would not gain us anything, so don't bother.
continue
reusable_requests.append(request_deferred)
state_filter_left_over = new_state_filter_left_over
if state_filter_left_over == StateFilter.none():
# we have managed to collect enough of the in-flight requests
# to cover our StateFilter and give us the state we need.
break
if (
state_filter_left_over != StateFilter.none()
and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP
):
# There are too many requests for this group.
# To prevent even more from building up, we request the whole
# state filter to guarantee that we can be reused by any subsequent
# requests for this state group.
return (), StateFilter.all()
return reusable_requests, state_filter_left_over
async def _get_state_for_group_fire_request(
self, group: int, state_filter: StateFilter
) -> StateMap[str]:
"""
Fires off a request to get the state at a state group,
potentially filtering by type and/or state key.
This request will be tracked in the in-flight request cache and automatically
removed when it is finished.
Used as part of _get_state_for_group_using_inflight_cache.
Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
"""
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
async def _the_request() -> StateMap[str]:
group_to_state_dict = await self._get_state_groups_from_groups(
(group,), state_filter=db_state_filter
)
# Now let's update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)
# Remove ourselves from the in-flight cache
group_request_dict = self._state_group_inflight_requests[group]
del group_request_dict[db_state_filter]
if not group_request_dict:
# If there are no more requests in-flight for this group,
# clean up the cache by removing the empty dictionary
del self._state_group_inflight_requests[group]
return group_to_state_dict[group]
# We don't immediately await the result, so must use run_in_background
# But we DO await the result before the current log context (request)
# finishes, so don't need to run it as a background process.
request_deferred = run_in_background(_the_request)
observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)
# Insert the ObservableDeferred into the cache
group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
group_request_dict[db_state_filter] = observable_deferred
return await make_deferred_yieldable(observable_deferred.observe())
async def _get_state_for_group_using_inflight_cache(
self, group: int, state_filter: StateFilter
) -> MutableStateMap[str]:
"""
Gets the state at a state group, potentially filtering by type and/or
state key.
1. Calls _get_state_for_group_gather_inflight_requests to gather any
ongoing requests which might overlap with the current request.
2. Fires a new request, using _get_state_for_group_fire_request,
for any state which cannot be gathered from ongoing requests.
Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
Returns:
state map
"""
# first, figure out whether we can re-use any in-flight requests
# (and if so, what would be left over)
(
reusable_requests,
state_filter_left_over,
) = self._get_state_for_group_gather_inflight_requests(group, state_filter)
if state_filter_left_over != StateFilter.none():
# Fetch remaining state
remaining = await self._get_state_for_group_fire_request(
group, state_filter_left_over
)
assembled_state: MutableStateMap[str] = dict(remaining)
else:
assembled_state = {}
gathered = await make_deferred_yieldable(
defer.gatherResults(
(r.observe() for r in reusable_requests), consumeErrors=True
)
).addErrback(unwrapFirstError)
# assemble our result.
for result_piece in gathered:
assembled_state.update(result_piece)
# Filter out any state that may be more than what we asked for.
return state_filter.filter_state(assembled_state)
async def _get_state_for_groups( async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]: ) -> Dict[int, MutableStateMap[str]]:
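The in-flight request cache added above amounts to de-duplicating concurrent lookups for the same key. A minimal asyncio sketch of that idea (this is not Synapse's ObservableDeferred machinery, and it ignores the per-group request cap and StateFilter narrowing):

import asyncio
from typing import Awaitable, Callable, Dict

class InflightDeduplicator:
    def __init__(self) -> None:
        self._inflight: Dict[str, asyncio.Future] = {}

    async def get(self, key: str, fetch: Callable[[], Awaitable[dict]]) -> dict:
        existing = self._inflight.get(key)
        if existing is not None:
            # Someone is already fetching this key: wait for their result.
            return await asyncio.shield(existing)
        task = asyncio.ensure_future(fetch())
        self._inflight[key] = task
        try:
            return await task
        finally:
            # Drop ourselves so later callers hit the real cache instead.
            self._inflight.pop(key, None)

async def main() -> None:
    calls = 0

    async def fetch() -> dict:
        nonlocal calls
        calls += 1
        await asyncio.sleep(0.01)
        return {"state_group": 1}

    dedup = InflightDeduplicator()
    await asyncio.gather(*(dedup.get("group-1", fetch) for _ in range(5)))
    assert calls == 1  # five concurrent callers shared a single fetch

asyncio.run(main())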
@ -269,31 +452,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not incomplete_groups: if not incomplete_groups:
return state return state
cache_sequence_nm = self._state_group_cache.sequence async def get_from_cache(group: int, state_filter: StateFilter) -> None:
cache_sequence_m = self._state_group_members_cache.sequence state[group] = await self._get_state_for_group_using_inflight_cache(
group, state_filter
)
# Help the cache hit ratio by expanding the filter a bit await yieldable_gather_results(
db_state_filter = state_filter.return_expanded() get_from_cache,
incomplete_groups,
group_to_state_dict = await self._get_state_groups_from_groups( state_filter,
list(incomplete_groups), state_filter=db_state_filter
) )
# Now lets update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)
# And finally update the result dict, by filtering out any extra
# stuff we pulled out of the database.
for group, group_state_dict in group_to_state_dict.items():
# We just replace any existing entries, as we will have loaded
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)
return state return state
def _get_state_for_groups_using_cache( def _get_state_for_groups_using_cache(

Some files were not shown because too many files have changed in this diff.