Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
commit
fdbccc1e74
22
CHANGES.md
22
CHANGES.md
|
@ -1,10 +1,26 @@
|
||||||
Synapse 1.28.0rc1 (2021-02-19)
|
Synapse 1.xx.0
|
||||||
==============================
|
==============
|
||||||
|
|
||||||
|
Note that synapse now expects an `X-Forwarded-Proto` header when used with a reverse proxy. Please see [UPGRADE.rst](UPGRADE.rst#upgrading-to-v1290) for more details on this change.
|
||||||
|
|
||||||
|
|
||||||
|
Synapse 1.28.0 (2021-02-25)
|
||||||
|
===========================
|
||||||
|
|
||||||
Note that this release drops support for ARMv7 in the official Docker images, due to repeated problems building for ARMv7 (and the associated maintenance burden this entails).
|
Note that this release drops support for ARMv7 in the official Docker images, due to repeated problems building for ARMv7 (and the associated maintenance burden this entails).
|
||||||
|
|
||||||
This release also fixes the documentation included in v1.27.0 around the callback URI for SAML2 identity providers. If your server is configured to use single sign-on via a SAML2 IdP, you may need to make configuration changes. Please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
|
This release also fixes the documentation included in v1.27.0 around the callback URI for SAML2 identity providers. If your server is configured to use single sign-on via a SAML2 IdP, you may need to make configuration changes. Please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
|
||||||
|
|
||||||
|
|
||||||
|
Internal Changes
|
||||||
|
----------------
|
||||||
|
|
||||||
|
- Revert change in v1.28.0rc1 to remove the deprecated SAML endpoint. ([\#9474](https://github.com/matrix-org/synapse/issues/9474))
|
||||||
|
|
||||||
|
|
||||||
|
Synapse 1.28.0rc1 (2021-02-19)
|
||||||
|
==============================
|
||||||
|
|
||||||
Removal warning
|
Removal warning
|
||||||
---------------
|
---------------
|
||||||
|
|
||||||
|
@ -31,7 +47,7 @@ Bugfixes
|
||||||
--------
|
--------
|
||||||
|
|
||||||
- Fix long-standing bug where sending email notifications would fail for rooms that the server had since left. ([\#9257](https://github.com/matrix-org/synapse/issues/9257))
|
- Fix long-standing bug where sending email notifications would fail for rooms that the server had since left. ([\#9257](https://github.com/matrix-org/synapse/issues/9257))
|
||||||
- Fix bug in Synapse 1.27.0rc1 which meant the "session expired" error page during SSO registration was badly formatted. ([\#9296](https://github.com/matrix-org/synapse/issues/9296))
|
- Fix bug introduced in Synapse 1.27.0rc1 which meant the "session expired" error page during SSO registration was badly formatted. ([\#9296](https://github.com/matrix-org/synapse/issues/9296))
|
||||||
- Assert a maximum length for some parameters for spec compliance. ([\#9321](https://github.com/matrix-org/synapse/issues/9321), [\#9393](https://github.com/matrix-org/synapse/issues/9393))
|
- Assert a maximum length for some parameters for spec compliance. ([\#9321](https://github.com/matrix-org/synapse/issues/9321), [\#9393](https://github.com/matrix-org/synapse/issues/9393))
|
||||||
- Fix additional errors when previewing URLs: "AttributeError 'NoneType' object has no attribute 'xpath'" and "ValueError: Unicode strings with encoding declaration are not supported. Please use bytes input or XML fragments without declaration.". ([\#9333](https://github.com/matrix-org/synapse/issues/9333))
|
- Fix additional errors when previewing URLs: "AttributeError 'NoneType' object has no attribute 'xpath'" and "ValueError: Unicode strings with encoding declaration are not supported. Please use bytes input or XML fragments without declaration.". ([\#9333](https://github.com/matrix-org/synapse/issues/9333))
|
||||||
- Fix a bug causing Synapse to impose the wrong type constraints on fields when processing responses from appservices to `/_matrix/app/v1/thirdparty/user/{protocol}`. ([\#9361](https://github.com/matrix-org/synapse/issues/9361))
|
- Fix a bug causing Synapse to impose the wrong type constraints on fields when processing responses from appservices to `/_matrix/app/v1/thirdparty/user/{protocol}`. ([\#9361](https://github.com/matrix-org/synapse/issues/9361))
|
||||||
|
|
20
UPGRADE.rst
20
UPGRADE.rst
|
@ -85,6 +85,26 @@ for example:
|
||||||
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||||
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||||
|
|
||||||
|
Upgrading to v1.29.0
|
||||||
|
====================
|
||||||
|
|
||||||
|
Requirement for X-Forwarded-Proto header
|
||||||
|
----------------------------------------
|
||||||
|
|
||||||
|
When using Synapse with a reverse proxy (in particular, when using the
|
||||||
|
`x_forwarded` option on an HTTP listener), Synapse now expects to receive an
|
||||||
|
`X-Forwarded-Proto` header on incoming HTTP requests. If it is not set, Synapse
|
||||||
|
will log a warning on each received request.
|
||||||
|
|
||||||
|
To avoid the warning, administrators using a reverse proxy should ensure that
|
||||||
|
the reverse proxy sets `X-Forwarded-Proto` header to `https` or `http` to
|
||||||
|
indicate the protocol used by the client. See the [reverse proxy
|
||||||
|
documentation](docs/reverse_proxy.md), where the example configurations have
|
||||||
|
been updated to show how to set this header.
|
||||||
|
|
||||||
|
(Users of `Caddy <https://caddyserver.com/>`_ are unaffected, since we believe it
|
||||||
|
sets `X-Forwarded-Proto` by default.)
|
||||||
|
|
||||||
Upgrading to v1.27.0
|
Upgrading to v1.27.0
|
||||||
====================
|
====================
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Add `order_by` to the admin API `GET /_synapse/admin/v1/users/<user_id>/media`. Contributed by @dklimpel.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug where users' pushers were not all deleted when they deactivated their account.
|
|
@ -0,0 +1 @@
|
||||||
|
Added a fix that invalidates cache for empty timed-out sync responses.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug that caused multiple calls to the experimental `shared_rooms` endpoint to return stale results.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug in single sign-on which could cause a "No session cookie found" error.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug introduced in v1.26.0 where some sequences were not properly configured when running `synapse_port_db`.
|
|
@ -0,0 +1 @@
|
||||||
|
Remove vestiges of `uploads_path` configuration setting.
|
|
@ -0,0 +1 @@
|
||||||
|
Update the example systemd config to propagate reloads to individual units.
|
|
@ -0,0 +1 @@
|
||||||
|
Add a comment about systemd-python.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix deleting pushers when using sharded pushers.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix deleting pushers when using sharded pushers.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix missing startup checks for the consistency of certain PostgreSQL sequences.
|
|
@ -0,0 +1 @@
|
||||||
|
Add support for `X-Forwarded-Proto` header when using a reverse proxy.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix deleting pushers when using sharded pushers.
|
|
@ -0,0 +1 @@
|
||||||
|
Test that we require validated email for email pushers.
|
|
@ -0,0 +1 @@
|
||||||
|
Add support for `X-Forwarded-Proto` header when using a reverse proxy.
|
|
@ -1,3 +1,9 @@
|
||||||
|
matrix-synapse-py3 (1.28.0) stable; urgency=medium
|
||||||
|
|
||||||
|
* New synapse release 1.28.0.
|
||||||
|
|
||||||
|
-- Synapse Packaging team <packages@matrix.org> Thu, 25 Feb 2021 10:21:57 +0000
|
||||||
|
|
||||||
matrix-synapse-py3 (1.27.0) stable; urgency=medium
|
matrix-synapse-py3 (1.27.0) stable; urgency=medium
|
||||||
|
|
||||||
[ Dan Callahan ]
|
[ Dan Callahan ]
|
||||||
|
|
|
@ -11,7 +11,6 @@ The image also does *not* provide a TURN server.
|
||||||
By default, the image expects a single volume, located at ``/data``, that will hold:
|
By default, the image expects a single volume, located at ``/data``, that will hold:
|
||||||
|
|
||||||
* configuration files;
|
* configuration files;
|
||||||
* temporary files during uploads;
|
|
||||||
* uploaded media and thumbnails;
|
* uploaded media and thumbnails;
|
||||||
* the SQLite database if you do not configure postgres;
|
* the SQLite database if you do not configure postgres;
|
||||||
* the appservices configuration.
|
* the appservices configuration.
|
||||||
|
|
|
@ -89,7 +89,6 @@ federation_rc_concurrent: 3
|
||||||
## Files ##
|
## Files ##
|
||||||
|
|
||||||
media_store_path: "/data/media"
|
media_store_path: "/data/media"
|
||||||
uploads_path: "/data/uploads"
|
|
||||||
max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "50M" }}"
|
max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "50M" }}"
|
||||||
max_image_pixels: "32M"
|
max_image_pixels: "32M"
|
||||||
dynamic_thumbnails: false
|
dynamic_thumbnails: false
|
||||||
|
|
|
@ -379,11 +379,12 @@ The following fields are returned in the JSON response body:
|
||||||
- ``total`` - Number of rooms.
|
- ``total`` - Number of rooms.
|
||||||
|
|
||||||
|
|
||||||
List media of an user
|
List media of a user
|
||||||
================================
|
====================
|
||||||
Gets a list of all local media that a specific ``user_id`` has created.
|
Gets a list of all local media that a specific ``user_id`` has created.
|
||||||
The response is ordered by creation date descending and media ID descending.
|
By default, the response is ordered by descending creation date and ascending media ID.
|
||||||
The newest media is on top.
|
The newest media is on top. You can change the order with parameters
|
||||||
|
``order_by`` and ``dir``.
|
||||||
|
|
||||||
The API is::
|
The API is::
|
||||||
|
|
||||||
|
@ -440,6 +441,35 @@ The following parameters should be set in the URL:
|
||||||
denoting the offset in the returned results. This should be treated as an opaque value and
|
denoting the offset in the returned results. This should be treated as an opaque value and
|
||||||
not explicitly set to anything other than the return value of ``next_token`` from a previous call.
|
not explicitly set to anything other than the return value of ``next_token`` from a previous call.
|
||||||
Defaults to ``0``.
|
Defaults to ``0``.
|
||||||
|
- ``order_by`` - The method by which to sort the returned list of media.
|
||||||
|
If the ordered field has duplicates, the second order is always by ascending ``media_id``,
|
||||||
|
which guarantees a stable ordering. Valid values are:
|
||||||
|
|
||||||
|
- ``media_id`` - Media are ordered alphabetically by ``media_id``.
|
||||||
|
- ``upload_name`` - Media are ordered alphabetically by name the media was uploaded with.
|
||||||
|
- ``created_ts`` - Media are ordered by when the content was uploaded in ms.
|
||||||
|
Smallest to largest. This is the default.
|
||||||
|
- ``last_access_ts`` - Media are ordered by when the content was last accessed in ms.
|
||||||
|
Smallest to largest.
|
||||||
|
- ``media_length`` - Media are ordered by length of the media in bytes.
|
||||||
|
Smallest to largest.
|
||||||
|
- ``media_type`` - Media are ordered alphabetically by MIME-type.
|
||||||
|
- ``quarantined_by`` - Media are ordered alphabetically by the user ID that
|
||||||
|
initiated the quarantine request for this media.
|
||||||
|
- ``safe_from_quarantine`` - Media are ordered by the status if this media is safe
|
||||||
|
from quarantining.
|
||||||
|
|
||||||
|
- ``dir`` - Direction of media order. Either ``f`` for forwards or ``b`` for backwards.
|
||||||
|
Setting this value to ``b`` will reverse the above sort order. Defaults to ``f``.
|
||||||
|
|
||||||
|
If neither ``order_by`` nor ``dir`` is set, the default order is newest media on top
|
||||||
|
(corresponds to ``order_by`` = ``created_ts`` and ``dir`` = ``b``).
|
||||||
|
|
||||||
|
Caution. The database only has indexes on the columns ``media_id``,
|
||||||
|
``user_id`` and ``created_ts``. This means that if a different sort order is used
|
||||||
|
(``upload_name``, ``last_access_ts``, ``media_length``, ``media_type``,
|
||||||
|
``quarantined_by`` or ``safe_from_quarantine``), this can cause a large load on the
|
||||||
|
database, especially for large environments.
|
||||||
|
|
||||||
**Response**
|
**Response**
|
||||||
|
|
||||||
|
|
|
@ -9,23 +9,23 @@ of doing so is that it means that you can expose the default https port
|
||||||
(443) to Matrix clients without needing to run Synapse with root
|
(443) to Matrix clients without needing to run Synapse with root
|
||||||
privileges.
|
privileges.
|
||||||
|
|
||||||
**NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
|
You should configure your reverse proxy to forward requests to `/_matrix` or
|
||||||
the requested URI in any way (for example, by decoding `%xx` escapes).
|
`/_synapse/client` to Synapse, and have it set the `X-Forwarded-For` and
|
||||||
Beware that Apache *will* canonicalise URIs unless you specify
|
`X-Forwarded-Proto` request headers.
|
||||||
`nocanon`.
|
|
||||||
|
|
||||||
When setting up a reverse proxy, remember that Matrix clients and other
|
You should remember that Matrix clients and other Matrix servers do not
|
||||||
Matrix servers do not necessarily need to connect to your server via the
|
necessarily need to connect to your server via the same server name or
|
||||||
same server name or port. Indeed, clients will use port 443 by default,
|
port. Indeed, clients will use port 443 by default, whereas servers default to
|
||||||
whereas servers default to port 8448. Where these are different, we
|
port 8448. Where these are different, we refer to the 'client port' and the
|
||||||
refer to the 'client port' and the 'federation port'. See [the Matrix
|
'federation port'. See [the Matrix
|
||||||
specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names)
|
specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names)
|
||||||
for more details of the algorithm used for federation connections, and
|
for more details of the algorithm used for federation connections, and
|
||||||
[delegate.md](<delegate.md>) for instructions on setting up delegation.
|
[delegate.md](<delegate.md>) for instructions on setting up delegation.
|
||||||
|
|
||||||
Endpoints that are part of the standardised Matrix specification are
|
**NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
|
||||||
located under `/_matrix`, whereas endpoints specific to Synapse are
|
the requested URI in any way (for example, by decoding `%xx` escapes).
|
||||||
located under `/_synapse/client`.
|
Beware that Apache *will* canonicalise URIs unless you specify
|
||||||
|
`nocanon`.
|
||||||
|
|
||||||
Let's assume that we expect clients to connect to our server at
|
Let's assume that we expect clients to connect to our server at
|
||||||
`https://matrix.example.com`, and other servers to connect at
|
`https://matrix.example.com`, and other servers to connect at
|
||||||
|
@ -52,6 +52,7 @@ server {
|
||||||
location ~* ^(\/_matrix|\/_synapse\/client) {
|
location ~* ^(\/_matrix|\/_synapse\/client) {
|
||||||
proxy_pass http://localhost:8008;
|
proxy_pass http://localhost:8008;
|
||||||
proxy_set_header X-Forwarded-For $remote_addr;
|
proxy_set_header X-Forwarded-For $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
# Nginx by default only allows file uploads up to 1M in size
|
# Nginx by default only allows file uploads up to 1M in size
|
||||||
# Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
|
# Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
|
||||||
client_max_body_size 50M;
|
client_max_body_size 50M;
|
||||||
|
@ -102,6 +103,7 @@ example.com:8448 {
|
||||||
SSLEngine on
|
SSLEngine on
|
||||||
ServerName matrix.example.com;
|
ServerName matrix.example.com;
|
||||||
|
|
||||||
|
RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME}
|
||||||
AllowEncodedSlashes NoDecode
|
AllowEncodedSlashes NoDecode
|
||||||
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
|
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
|
||||||
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
|
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
|
||||||
|
@ -113,6 +115,7 @@ example.com:8448 {
|
||||||
SSLEngine on
|
SSLEngine on
|
||||||
ServerName example.com;
|
ServerName example.com;
|
||||||
|
|
||||||
|
RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME}
|
||||||
AllowEncodedSlashes NoDecode
|
AllowEncodedSlashes NoDecode
|
||||||
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
|
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
|
||||||
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
|
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
|
||||||
|
@ -134,6 +137,9 @@ example.com:8448 {
|
||||||
```
|
```
|
||||||
frontend https
|
frontend https
|
||||||
bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
|
bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
|
||||||
|
http-request set-header X-Forwarded-Proto https if { ssl_fc }
|
||||||
|
http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
|
||||||
|
http-request set-header X-Forwarded-For %[src]
|
||||||
|
|
||||||
# Matrix client traffic
|
# Matrix client traffic
|
||||||
acl matrix-host hdr(host) -i matrix.example.com
|
acl matrix-host hdr(host) -i matrix.example.com
|
||||||
|
@ -144,6 +150,10 @@ frontend https
|
||||||
|
|
||||||
frontend matrix-federation
|
frontend matrix-federation
|
||||||
bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
|
bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
|
||||||
|
http-request set-header X-Forwarded-Proto https if { ssl_fc }
|
||||||
|
http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
|
||||||
|
http-request set-header X-Forwarded-For %[src]
|
||||||
|
|
||||||
default_backend matrix
|
default_backend matrix
|
||||||
|
|
||||||
backend matrix
|
backend matrix
|
||||||
|
|
|
@ -25,7 +25,7 @@ well as some specific methods:
|
||||||
* `check_username_for_spam`
|
* `check_username_for_spam`
|
||||||
* `check_registration_for_spam`
|
* `check_registration_for_spam`
|
||||||
|
|
||||||
The details of the each of these methods (as well as their inputs and outputs)
|
The details of each of these methods (as well as their inputs and outputs)
|
||||||
are documented in the `synapse.events.spamcheck.SpamChecker` class.
|
are documented in the `synapse.events.spamcheck.SpamChecker` class.
|
||||||
|
|
||||||
The `ModuleApi` class provides a way for the custom spam checker class to
|
The `ModuleApi` class provides a way for the custom spam checker class to
|
||||||
|
|
|
@ -4,6 +4,7 @@ AssertPathExists=/etc/matrix-synapse/workers/%i.yaml
|
||||||
|
|
||||||
# This service should be restarted when the synapse target is restarted.
|
# This service should be restarted when the synapse target is restarted.
|
||||||
PartOf=matrix-synapse.target
|
PartOf=matrix-synapse.target
|
||||||
|
ReloadPropagatedFrom=matrix-synapse.target
|
||||||
|
|
||||||
# if this is started at the same time as the main, let the main process start
|
# if this is started at the same time as the main, let the main process start
|
||||||
# first, to initialise the database schema.
|
# first, to initialise the database schema.
|
||||||
|
|
|
@ -3,6 +3,7 @@ Description=Synapse master
|
||||||
|
|
||||||
# This service should be restarted when the synapse target is restarted.
|
# This service should be restarted when the synapse target is restarted.
|
||||||
PartOf=matrix-synapse.target
|
PartOf=matrix-synapse.target
|
||||||
|
ReloadPropagatedFrom=matrix-synapse.target
|
||||||
|
|
||||||
[Service]
|
[Service]
|
||||||
Type=notify
|
Type=notify
|
||||||
|
|
|
@ -220,10 +220,6 @@ Asks the server for the current position of all streams.
|
||||||
|
|
||||||
Acknowledge receipt of some federation data
|
Acknowledge receipt of some federation data
|
||||||
|
|
||||||
#### REMOVE_PUSHER (C)
|
|
||||||
|
|
||||||
Inform the server a pusher should be removed
|
|
||||||
|
|
||||||
### REMOTE_SERVER_UP (S, C)
|
### REMOTE_SERVER_UP (S, C)
|
||||||
|
|
||||||
Inform other processes that a remote server may have come back online.
|
Inform other processes that a remote server may have come back online.
|
||||||
|
|
|
@ -22,7 +22,7 @@ import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Dict, Optional, Set
|
from typing import Dict, Iterable, Optional, Set
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
@ -629,7 +629,13 @@ class Porter(object):
|
||||||
await self._setup_state_group_id_seq()
|
await self._setup_state_group_id_seq()
|
||||||
await self._setup_user_id_seq()
|
await self._setup_user_id_seq()
|
||||||
await self._setup_events_stream_seqs()
|
await self._setup_events_stream_seqs()
|
||||||
await self._setup_device_inbox_seq()
|
await self._setup_sequence(
|
||||||
|
"device_inbox_sequence", ("device_inbox", "device_federation_outbox")
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"account_data_sequence", ("room_account_data", "room_tags_revisions", "account_data"))
|
||||||
|
await self._setup_sequence("receipts_sequence", ("receipts_linearized", ))
|
||||||
|
await self._setup_auth_chain_sequence()
|
||||||
|
|
||||||
# Step 3. Get tables.
|
# Step 3. Get tables.
|
||||||
self.progress.set_state("Fetching tables")
|
self.progress.set_state("Fetching tables")
|
||||||
|
@ -854,7 +860,7 @@ class Porter(object):
|
||||||
|
|
||||||
return done, remaining + done
|
return done, remaining + done
|
||||||
|
|
||||||
async def _setup_state_group_id_seq(self):
|
async def _setup_state_group_id_seq(self) -> None:
|
||||||
curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
|
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
|
||||||
)
|
)
|
||||||
|
@ -868,7 +874,7 @@ class Porter(object):
|
||||||
|
|
||||||
await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
|
await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
|
||||||
|
|
||||||
async def _setup_user_id_seq(self):
|
async def _setup_user_id_seq(self) -> None:
|
||||||
curr_id = await self.sqlite_store.db_pool.runInteraction(
|
curr_id = await self.sqlite_store.db_pool.runInteraction(
|
||||||
"setup_user_id_seq", find_max_generated_user_id_localpart
|
"setup_user_id_seq", find_max_generated_user_id_localpart
|
||||||
)
|
)
|
||||||
|
@ -877,9 +883,9 @@ class Porter(object):
|
||||||
next_id = curr_id + 1
|
next_id = curr_id + 1
|
||||||
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
|
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
|
||||||
|
|
||||||
return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
|
await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
|
||||||
|
|
||||||
async def _setup_events_stream_seqs(self):
|
async def _setup_events_stream_seqs(self) -> None:
|
||||||
"""Set the event stream sequences to the correct values.
|
"""Set the event stream sequences to the correct values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -908,35 +914,46 @@ class Porter(object):
|
||||||
(curr_backward_id + 1,),
|
(curr_backward_id + 1,),
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self.postgres_store.db_pool.runInteraction(
|
await self.postgres_store.db_pool.runInteraction(
|
||||||
"_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
|
"_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _setup_device_inbox_seq(self):
|
async def _setup_sequence(self, sequence_name: str, stream_id_tables: Iterable[str]) -> None:
|
||||||
"""Set the device inbox sequence to the correct value.
|
"""Set a sequence to the correct value.
|
||||||
"""
|
"""
|
||||||
curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
current_stream_ids = []
|
||||||
table="device_inbox",
|
for stream_id_table in stream_id_tables:
|
||||||
keyvalues={},
|
max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
retcol="COALESCE(MAX(stream_id), 1)",
|
table=stream_id_table,
|
||||||
allow_none=True,
|
keyvalues={},
|
||||||
)
|
retcol="COALESCE(MAX(stream_id), 1)",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
current_stream_ids.append(max_stream_id)
|
||||||
|
|
||||||
curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
next_id = max(current_stream_ids) + 1
|
||||||
table="device_federation_outbox",
|
|
||||||
keyvalues={},
|
|
||||||
retcol="COALESCE(MAX(stream_id), 1)",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
next_id = max(curr_local_id, curr_federation_id) + 1
|
def r(txn):
|
||||||
|
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name, )
|
||||||
|
txn.execute(sql + " %s", (next_id, ))
|
||||||
|
|
||||||
|
await self.postgres_store.db_pool.runInteraction("_setup_%s" % (sequence_name,), r)
|
||||||
|
|
||||||
|
async def _setup_auth_chain_sequence(self) -> None:
|
||||||
|
curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
|
table="event_auth_chains", keyvalues={}, retcol="MAX(chain_id)", allow_none=True
|
||||||
|
)
|
||||||
|
|
||||||
def r(txn):
|
def r(txn):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,)
|
"ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
|
||||||
|
(curr_chain_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r)
|
await self.postgres_store.db_pool.runInteraction(
|
||||||
|
"_setup_event_auth_chain_id", r,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
##############################################
|
##############################################
|
||||||
|
|
|
@ -48,7 +48,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
__version__ = "1.28.0rc1"
|
__version__ = "1.28.0"
|
||||||
|
|
||||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||||
# We import here so that we don't have to install a bunch of deps when
|
# We import here so that we don't have to install a bunch of deps when
|
||||||
|
|
|
@ -210,7 +210,9 @@ def start(config_options):
|
||||||
config.update_user_directory = False
|
config.update_user_directory = False
|
||||||
config.run_background_tasks = False
|
config.run_background_tasks = False
|
||||||
config.start_pushers = False
|
config.start_pushers = False
|
||||||
|
config.pusher_shard_config.instances = []
|
||||||
config.send_federation = False
|
config.send_federation = False
|
||||||
|
config.federation_shard_config.instances = []
|
||||||
|
|
||||||
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
|
|
|
@ -645,9 +645,6 @@ class GenericWorkerServer(HomeServer):
|
||||||
|
|
||||||
self.get_tcp_replication().start_replication(self)
|
self.get_tcp_replication().start_replication(self)
|
||||||
|
|
||||||
async def remove_pusher(self, app_id, push_key, user_id):
|
|
||||||
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
|
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_replication_data_handler(self):
|
def get_replication_data_handler(self):
|
||||||
return GenericWorkerReplicationHandler(self)
|
return GenericWorkerReplicationHandler(self)
|
||||||
|
@ -922,22 +919,6 @@ def start(config_options):
|
||||||
# For other worker types we force this to off.
|
# For other worker types we force this to off.
|
||||||
config.appservice.notify_appservices = False
|
config.appservice.notify_appservices = False
|
||||||
|
|
||||||
if config.worker_app == "synapse.app.pusher":
|
|
||||||
if config.server.start_pushers:
|
|
||||||
sys.stderr.write(
|
|
||||||
"\nThe pushers must be disabled in the main synapse process"
|
|
||||||
"\nbefore they can be run in a separate worker."
|
|
||||||
"\nPlease add ``start_pushers: false`` to the main config"
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Force the pushers to start since they will be disabled in the main config
|
|
||||||
config.server.start_pushers = True
|
|
||||||
else:
|
|
||||||
# For other worker types we force this to off.
|
|
||||||
config.server.start_pushers = False
|
|
||||||
|
|
||||||
if config.worker_app == "synapse.app.user_dir":
|
if config.worker_app == "synapse.app.user_dir":
|
||||||
if config.server.update_user_directory:
|
if config.server.update_user_directory:
|
||||||
sys.stderr.write(
|
sys.stderr.write(
|
||||||
|
@ -954,22 +935,6 @@ def start(config_options):
|
||||||
# For other worker types we force this to off.
|
# For other worker types we force this to off.
|
||||||
config.server.update_user_directory = False
|
config.server.update_user_directory = False
|
||||||
|
|
||||||
if config.worker_app == "synapse.app.federation_sender":
|
|
||||||
if config.worker.send_federation:
|
|
||||||
sys.stderr.write(
|
|
||||||
"\nThe send_federation must be disabled in the main synapse process"
|
|
||||||
"\nbefore they can be run in a separate worker."
|
|
||||||
"\nPlease add ``send_federation: false`` to the main config"
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Force the pushers to start since they will be disabled in the main config
|
|
||||||
config.worker.send_federation = True
|
|
||||||
else:
|
|
||||||
# For other worker types we force this to off.
|
|
||||||
config.worker.send_federation = False
|
|
||||||
|
|
||||||
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
||||||
hs = GenericWorkerServer(
|
hs = GenericWorkerServer(
|
||||||
|
|
|
@ -844,22 +844,23 @@ class ShardedWorkerHandlingConfig:
|
||||||
|
|
||||||
def should_handle(self, instance_name: str, key: str) -> bool:
|
def should_handle(self, instance_name: str, key: str) -> bool:
|
||||||
"""Whether this instance is responsible for handling the given key."""
|
"""Whether this instance is responsible for handling the given key."""
|
||||||
# If multiple instances are not defined we always return true
|
# If no instances are defined we assume some other worker is handling
|
||||||
if not self.instances or len(self.instances) == 1:
|
# this.
|
||||||
return True
|
if not self.instances:
|
||||||
|
return False
|
||||||
|
|
||||||
return self.get_instance(key) == instance_name
|
return self._get_instance(key) == instance_name
|
||||||
|
|
||||||
def get_instance(self, key: str) -> str:
|
def _get_instance(self, key: str) -> str:
|
||||||
"""Get the instance responsible for handling the given key.
|
"""Get the instance responsible for handling the given key.
|
||||||
|
|
||||||
Note: For things like federation sending the config for which instance
|
Note: For federation sending and pushers the config for which instance
|
||||||
is sending is known only to the sender instance if there is only one.
|
is sending is known only to the sender instance, so we don't expose this
|
||||||
Therefore `should_handle` should be used where possible.
|
method by default.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not self.instances:
|
if not self.instances:
|
||||||
return "master"
|
raise Exception("Unknown worker")
|
||||||
|
|
||||||
if len(self.instances) == 1:
|
if len(self.instances) == 1:
|
||||||
return self.instances[0]
|
return self.instances[0]
|
||||||
|
@ -876,4 +877,21 @@ class ShardedWorkerHandlingConfig:
|
||||||
return self.instances[remainder]
|
return self.instances[remainder]
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
|
||||||
|
"""A version of `ShardedWorkerHandlingConfig` that is used for config
|
||||||
|
options where all instances know which instances are responsible for the
|
||||||
|
sharded work.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __attrs_post_init__(self):
|
||||||
|
# We require that `self.instances` is non-empty.
|
||||||
|
if not self.instances:
|
||||||
|
raise Exception("Got empty list of instances for shard config")
|
||||||
|
|
||||||
|
def get_instance(self, key: str) -> str:
|
||||||
|
"""Get the instance responsible for handling the given key."""
|
||||||
|
return self._get_instance(key)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
|
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
|
||||||
|
|
|
@ -149,4 +149,6 @@ class ShardedWorkerHandlingConfig:
|
||||||
instances: List[str]
|
instances: List[str]
|
||||||
def __init__(self, instances: List[str]) -> None: ...
|
def __init__(self, instances: List[str]) -> None: ...
|
||||||
def should_handle(self, instance_name: str, key: str) -> bool: ...
|
def should_handle(self, instance_name: str, key: str) -> bool: ...
|
||||||
|
|
||||||
|
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
|
||||||
def get_instance(self, key: str) -> str: ...
|
def get_instance(self, key: str) -> str: ...
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import Config, ShardedWorkerHandlingConfig
|
from ._base import Config
|
||||||
|
|
||||||
|
|
||||||
class PushConfig(Config):
|
class PushConfig(Config):
|
||||||
|
@ -27,9 +27,6 @@ class PushConfig(Config):
|
||||||
"group_unread_count_by_room", True
|
"group_unread_count_by_room", True
|
||||||
)
|
)
|
||||||
|
|
||||||
pusher_instances = config.get("pusher_instances") or []
|
|
||||||
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
|
|
||||||
|
|
||||||
# There was a a 'redact_content' setting but mistakenly read from the
|
# There was a a 'redact_content' setting but mistakenly read from the
|
||||||
# 'email'section'. Check for the flag in the 'push' section, and log,
|
# 'email'section'. Check for the flag in the 'push' section, and log,
|
||||||
# but do not honour it to avoid nasty surprises when people upgrade.
|
# but do not honour it to avoid nasty surprises when people upgrade.
|
||||||
|
|
|
@ -206,7 +206,6 @@ class ContentRepositoryConfig(Config):
|
||||||
|
|
||||||
def generate_config_section(self, data_dir_path, **kwargs):
|
def generate_config_section(self, data_dir_path, **kwargs):
|
||||||
media_store = os.path.join(data_dir_path, "media_store")
|
media_store = os.path.join(data_dir_path, "media_store")
|
||||||
uploads_path = os.path.join(data_dir_path, "uploads")
|
|
||||||
|
|
||||||
formatted_thumbnail_sizes = "".join(
|
formatted_thumbnail_sizes = "".join(
|
||||||
THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES
|
THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES
|
||||||
|
|
|
@ -397,7 +397,6 @@ class ServerConfig(Config):
|
||||||
if self.public_baseurl is not None:
|
if self.public_baseurl is not None:
|
||||||
if self.public_baseurl[-1] != "/":
|
if self.public_baseurl[-1] != "/":
|
||||||
self.public_baseurl += "/"
|
self.public_baseurl += "/"
|
||||||
self.start_pushers = config.get("start_pushers", True)
|
|
||||||
|
|
||||||
# (undocumented) option for torturing the worker-mode replication a bit,
|
# (undocumented) option for torturing the worker-mode replication a bit,
|
||||||
# for testing. The value defines the number of milliseconds to pause before
|
# for testing. The value defines the number of milliseconds to pause before
|
||||||
|
|
|
@ -17,9 +17,28 @@ from typing import List, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
|
from ._base import (
|
||||||
|
Config,
|
||||||
|
ConfigError,
|
||||||
|
RoutableShardedWorkerHandlingConfig,
|
||||||
|
ShardedWorkerHandlingConfig,
|
||||||
|
)
|
||||||
from .server import ListenerConfig, parse_listener_def
|
from .server import ListenerConfig, parse_listener_def
|
||||||
|
|
||||||
|
_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR = """
|
||||||
|
The send_federation config option must be disabled in the main
|
||||||
|
synapse process before they can be run in a separate worker.
|
||||||
|
|
||||||
|
Please add ``send_federation: false`` to the main config
|
||||||
|
"""
|
||||||
|
|
||||||
|
_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR = """
|
||||||
|
The start_pushers config option must be disabled in the main
|
||||||
|
synapse process before they can be run in a separate worker.
|
||||||
|
|
||||||
|
Please add ``start_pushers: false`` to the main config
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
|
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
|
||||||
"""Helper for allowing parsing a string or list of strings to a config
|
"""Helper for allowing parsing a string or list of strings to a config
|
||||||
|
@ -103,6 +122,7 @@ class WorkerConfig(Config):
|
||||||
self.worker_replication_secret = config.get("worker_replication_secret", None)
|
self.worker_replication_secret = config.get("worker_replication_secret", None)
|
||||||
|
|
||||||
self.worker_name = config.get("worker_name", self.worker_app)
|
self.worker_name = config.get("worker_name", self.worker_app)
|
||||||
|
self.instance_name = self.worker_name or "master"
|
||||||
|
|
||||||
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
|
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
|
||||||
|
|
||||||
|
@ -118,12 +138,41 @@ class WorkerConfig(Config):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Whether to send federation traffic out in this process. This only
|
# Handle federation sender configuration.
|
||||||
# applies to some federation traffic, and so shouldn't be used to
|
#
|
||||||
# "disable" federation
|
# There are two ways of configuring which instances handle federation
|
||||||
self.send_federation = config.get("send_federation", True)
|
# sending:
|
||||||
|
# 1. The old way where "send_federation" is set to false and running a
|
||||||
|
# `synapse.app.federation_sender` worker app.
|
||||||
|
# 2. Specifying the workers sending federation in
|
||||||
|
# `federation_sender_instances`.
|
||||||
|
#
|
||||||
|
|
||||||
federation_sender_instances = config.get("federation_sender_instances") or []
|
send_federation = config.get("send_federation", True)
|
||||||
|
|
||||||
|
federation_sender_instances = config.get("federation_sender_instances")
|
||||||
|
if federation_sender_instances is None:
|
||||||
|
# Default to an empty list, which means "another, unknown, worker is
|
||||||
|
# responsible for it".
|
||||||
|
federation_sender_instances = []
|
||||||
|
|
||||||
|
# If no federation sender instances are set we check if
|
||||||
|
# `send_federation` is set, which means use master
|
||||||
|
if send_federation:
|
||||||
|
federation_sender_instances = ["master"]
|
||||||
|
|
||||||
|
if self.worker_app == "synapse.app.federation_sender":
|
||||||
|
if send_federation:
|
||||||
|
# If we're running federation senders, and not using
|
||||||
|
# `federation_sender_instances`, then we should have
|
||||||
|
# explicitly set `send_federation` to false.
|
||||||
|
raise ConfigError(
|
||||||
|
_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
|
federation_sender_instances = [self.worker_name]
|
||||||
|
|
||||||
|
self.send_federation = self.instance_name in federation_sender_instances
|
||||||
self.federation_shard_config = ShardedWorkerHandlingConfig(
|
self.federation_shard_config = ShardedWorkerHandlingConfig(
|
||||||
federation_sender_instances
|
federation_sender_instances
|
||||||
)
|
)
|
||||||
|
@ -164,7 +213,37 @@ class WorkerConfig(Config):
|
||||||
"Must only specify one instance to handle `receipts` messages."
|
"Must only specify one instance to handle `receipts` messages."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
|
if len(self.writers.events) == 0:
|
||||||
|
raise ConfigError("Must specify at least one instance to handle `events`.")
|
||||||
|
|
||||||
|
self.events_shard_config = RoutableShardedWorkerHandlingConfig(
|
||||||
|
self.writers.events
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle sharded push
|
||||||
|
start_pushers = config.get("start_pushers", True)
|
||||||
|
pusher_instances = config.get("pusher_instances")
|
||||||
|
if pusher_instances is None:
|
||||||
|
# Default to an empty list, which means "another, unknown, worker is
|
||||||
|
# responsible for it".
|
||||||
|
pusher_instances = []
|
||||||
|
|
||||||
|
# If no pushers instances are set we check if `start_pushers` is
|
||||||
|
# set, which means use master
|
||||||
|
if start_pushers:
|
||||||
|
pusher_instances = ["master"]
|
||||||
|
|
||||||
|
if self.worker_app == "synapse.app.pusher":
|
||||||
|
if start_pushers:
|
||||||
|
# If we're running pushers, and not using
|
||||||
|
# `pusher_instances`, then we should have explicitly set
|
||||||
|
# `start_pushers` to false.
|
||||||
|
raise ConfigError(_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR)
|
||||||
|
|
||||||
|
pusher_instances = [self.instance_name]
|
||||||
|
|
||||||
|
self.start_pushers = self.instance_name in pusher_instances
|
||||||
|
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
|
||||||
|
|
||||||
# Whether this worker should run background tasks or not.
|
# Whether this worker should run background tasks or not.
|
||||||
#
|
#
|
||||||
|
|
|
@ -120,6 +120,11 @@ class DeactivateAccountHandler(BaseHandler):
|
||||||
|
|
||||||
await self.store.user_set_password_hash(user_id, None)
|
await self.store.user_set_password_hash(user_id, None)
|
||||||
|
|
||||||
|
# Most of the pushers will have been deleted when we logged out the
|
||||||
|
# associated devices above, but we still need to delete pushers not
|
||||||
|
# associated with devices, e.g. email pushers.
|
||||||
|
await self.store.delete_all_pushers_for_user(user_id)
|
||||||
|
|
||||||
# Add the user to a table of users pending deactivation (ie.
|
# Add the user to a table of users pending deactivation (ie.
|
||||||
# removal from all the rooms they're a member of)
|
# removal from all the rooms they're a member of)
|
||||||
await self.store.add_user_pending_deactivation(user_id)
|
await self.store.add_user_pending_deactivation(user_id)
|
||||||
|
|
|
@ -278,8 +278,9 @@ class SyncHandler:
|
||||||
user_id = sync_config.user.to_string()
|
user_id = sync_config.user.to_string()
|
||||||
await self.auth.check_auth_blocking(requester=requester)
|
await self.auth.check_auth_blocking(requester=requester)
|
||||||
|
|
||||||
res = await self.response_cache.wrap(
|
res = await self.response_cache.wrap_conditional(
|
||||||
sync_config.request_key,
|
sync_config.request_key,
|
||||||
|
lambda result: since_token != result.next_batch,
|
||||||
self._wait_for_sync_for_user,
|
self._wait_for_sync_for_user,
|
||||||
sync_config,
|
sync_config,
|
||||||
since_token,
|
since_token,
|
||||||
|
|
|
@ -14,8 +14,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import re
|
import re
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from twisted.internet import task
|
from twisted.internet import address, task
|
||||||
from twisted.web.client import FileBodyProducer
|
from twisted.web.client import FileBodyProducer
|
||||||
from twisted.web.iweb import IRequest
|
from twisted.web.iweb import IRequest
|
||||||
|
|
||||||
|
@ -53,6 +54,40 @@ class QuieterFileBodyProducer(FileBodyProducer):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_request_uri(request: IRequest) -> bytes:
|
||||||
|
"""Return the full URI that was requested by the client"""
|
||||||
|
return b"%s://%s%s" % (
|
||||||
|
b"https" if request.isSecure() else b"http",
|
||||||
|
_get_requested_host(request),
|
||||||
|
# despite its name, "request.uri" is only the path and query-string.
|
||||||
|
request.uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_requested_host(request: IRequest) -> bytes:
|
||||||
|
hostname = request.getHeader(b"host")
|
||||||
|
if hostname:
|
||||||
|
return hostname
|
||||||
|
|
||||||
|
# no Host header, use the address/port that the request arrived on
|
||||||
|
host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address]
|
||||||
|
|
||||||
|
hostname = host.host.encode("ascii")
|
||||||
|
|
||||||
|
if request.isSecure() and host.port == 443:
|
||||||
|
# default port for https
|
||||||
|
return hostname
|
||||||
|
|
||||||
|
if not request.isSecure() and host.port == 80:
|
||||||
|
# default port for http
|
||||||
|
return hostname
|
||||||
|
|
||||||
|
return b"%s:%i" % (
|
||||||
|
hostname,
|
||||||
|
host.port,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_request_user_agent(request: IRequest, default: str = "") -> str:
|
def get_request_user_agent(request: IRequest, default: str = "") -> str:
|
||||||
"""Return the last User-Agent header, or the given default."""
|
"""Return the last User-Agent header, or the given default."""
|
||||||
# There could be raw utf-8 bytes in the User-Agent header.
|
# There could be raw utf-8 bytes in the User-Agent header.
|
||||||
|
|
|
@ -16,6 +16,10 @@ import logging
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.internet.interfaces import IAddress
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.web.server import Request, Site
|
from twisted.web.server import Request, Site
|
||||||
|
|
||||||
|
@ -333,27 +337,78 @@ class SynapseRequest(Request):
|
||||||
|
|
||||||
|
|
||||||
class XForwardedForRequest(SynapseRequest):
|
class XForwardedForRequest(SynapseRequest):
|
||||||
def __init__(self, *args, **kw):
|
"""Request object which honours proxy headers
|
||||||
SynapseRequest.__init__(self, *args, **kw)
|
|
||||||
|
|
||||||
"""
|
Extends SynapseRequest to replace getClientIP, getClientAddress, and isSecure with
|
||||||
Add a layer on top of another request that only uses the value of an
|
information from request headers.
|
||||||
X-Forwarded-For header as the result of C{getClientIP}.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def getClientIP(self):
|
# the client IP and ssl flag, as extracted from the headers.
|
||||||
"""
|
_forwarded_for = None # type: Optional[_XForwardedForAddress]
|
||||||
@return: The client address (the first address) in the value of the
|
_forwarded_https = False # type: bool
|
||||||
I{X-Forwarded-For header}. If the header is not present, return
|
|
||||||
C{b"-"}.
|
def requestReceived(self, command, path, version):
|
||||||
"""
|
# this method is called by the Channel once the full request has been
|
||||||
return (
|
# received, to dispatch the request to a resource.
|
||||||
self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0]
|
# We can use it to set the IP address and protocol according to the
|
||||||
.split(b",")[0]
|
# headers.
|
||||||
.strip()
|
self._process_forwarded_headers()
|
||||||
.decode("ascii")
|
return super().requestReceived(command, path, version)
|
||||||
|
|
||||||
|
def _process_forwarded_headers(self):
|
||||||
|
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
|
||||||
|
if not headers:
|
||||||
|
return
|
||||||
|
|
||||||
|
# for now, we just use the first x-forwarded-for header. Really, we ought
|
||||||
|
# to start from the client IP address, and check whether it is trusted; if it
|
||||||
|
# is, work backwards through the headers until we find an untrusted address.
|
||||||
|
# see https://github.com/matrix-org/synapse/issues/9471
|
||||||
|
self._forwarded_for = _XForwardedForAddress(
|
||||||
|
headers[0].split(b",")[0].strip().decode("ascii")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# if we got an x-forwarded-for header, also look for an x-forwarded-proto header
|
||||||
|
header = self.getHeader(b"x-forwarded-proto")
|
||||||
|
if header is not None:
|
||||||
|
self._forwarded_https = header.lower() == b"https"
|
||||||
|
else:
|
||||||
|
# this is done largely for backwards-compatibility so that people that
|
||||||
|
# haven't set an x-forwarded-proto header don't get a redirect loop.
|
||||||
|
logger.warning(
|
||||||
|
"forwarded request lacks an x-forwarded-proto header: assuming https"
|
||||||
|
)
|
||||||
|
self._forwarded_https = True
|
||||||
|
|
||||||
|
def isSecure(self):
|
||||||
|
if self._forwarded_https:
|
||||||
|
return True
|
||||||
|
return super().isSecure()
|
||||||
|
|
||||||
|
def getClientIP(self) -> str:
|
||||||
|
"""
|
||||||
|
Return the IP address of the client who submitted this request.
|
||||||
|
|
||||||
|
This method is deprecated. Use getClientAddress() instead.
|
||||||
|
"""
|
||||||
|
if self._forwarded_for is not None:
|
||||||
|
return self._forwarded_for.host
|
||||||
|
return super().getClientIP()
|
||||||
|
|
||||||
|
def getClientAddress(self) -> IAddress:
|
||||||
|
"""
|
||||||
|
Return the address of the client who submitted this request.
|
||||||
|
"""
|
||||||
|
if self._forwarded_for is not None:
|
||||||
|
return self._forwarded_for
|
||||||
|
return super().getClientAddress()
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IAddress)
|
||||||
|
@attr.s(frozen=True, slots=True)
|
||||||
|
class _XForwardedForAddress:
|
||||||
|
host = attr.ib(type=str)
|
||||||
|
|
||||||
|
|
||||||
class SynapseSite(Site):
|
class SynapseSite(Site):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -74,6 +74,7 @@ class HttpPusher(Pusher):
|
||||||
self.timed_call = None
|
self.timed_call = None
|
||||||
self._is_processing = False
|
self._is_processing = False
|
||||||
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
|
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
|
||||||
|
self._pusherpool = hs.get_pusherpool()
|
||||||
|
|
||||||
self.data = pusher_config.data
|
self.data = pusher_config.data
|
||||||
if self.data is None:
|
if self.data is None:
|
||||||
|
@ -304,7 +305,7 @@ class HttpPusher(Pusher):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Pushkey %s was rejected: removing", pk)
|
logger.info("Pushkey %s was rejected: removing", pk)
|
||||||
await self.hs.remove_pusher(self.app_id, pk, self.user_id)
|
await self._pusherpool.remove_pusher(self.app_id, pk, self.user_id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _build_notification_dict(
|
async def _build_notification_dict(
|
||||||
|
|
|
@ -19,12 +19,14 @@ from typing import TYPE_CHECKING, Dict, Iterable, Optional
|
||||||
|
|
||||||
from prometheus_client import Gauge
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.metrics.background_process_metrics import (
|
from synapse.metrics.background_process_metrics import (
|
||||||
run_as_background_process,
|
run_as_background_process,
|
||||||
wrap_as_background_process,
|
wrap_as_background_process,
|
||||||
)
|
)
|
||||||
from synapse.push import Pusher, PusherConfig, PusherConfigException
|
from synapse.push import Pusher, PusherConfig, PusherConfigException
|
||||||
from synapse.push.pusher import PusherFactory
|
from synapse.push.pusher import PusherFactory
|
||||||
|
from synapse.replication.http.push import ReplicationRemovePusherRestServlet
|
||||||
from synapse.types import JsonDict, RoomStreamToken
|
from synapse.types import JsonDict, RoomStreamToken
|
||||||
from synapse.util.async_helpers import concurrently_execute
|
from synapse.util.async_helpers import concurrently_execute
|
||||||
|
|
||||||
|
@ -58,7 +60,6 @@ class PusherPool:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.pusher_factory = PusherFactory(hs)
|
self.pusher_factory = PusherFactory(hs)
|
||||||
self._should_start_pushers = hs.config.start_pushers
|
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
|
|
||||||
|
@ -67,6 +68,16 @@ class PusherPool:
|
||||||
# We shard the handling of push notifications by user ID.
|
# We shard the handling of push notifications by user ID.
|
||||||
self._pusher_shard_config = hs.config.push.pusher_shard_config
|
self._pusher_shard_config = hs.config.push.pusher_shard_config
|
||||||
self._instance_name = hs.get_instance_name()
|
self._instance_name = hs.get_instance_name()
|
||||||
|
self._should_start_pushers = (
|
||||||
|
self._instance_name in self._pusher_shard_config.instances
|
||||||
|
)
|
||||||
|
|
||||||
|
# We can only delete pushers on master.
|
||||||
|
self._remove_pusher_client = None
|
||||||
|
if hs.config.worker.worker_app:
|
||||||
|
self._remove_pusher_client = ReplicationRemovePusherRestServlet.make_client(
|
||||||
|
hs
|
||||||
|
)
|
||||||
|
|
||||||
# Record the last stream ID that we were poked about so we can get
|
# Record the last stream ID that we were poked about so we can get
|
||||||
# changes since then. We set this to the current max stream ID on
|
# changes since then. We set this to the current max stream ID on
|
||||||
|
@ -103,6 +114,11 @@ class PusherPool:
|
||||||
The newly created pusher.
|
The newly created pusher.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if kind == "email":
|
||||||
|
email_owner = await self.store.get_user_id_by_threepid("email", pushkey)
|
||||||
|
if email_owner != user_id:
|
||||||
|
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
||||||
|
|
||||||
time_now_msec = self.clock.time_msec()
|
time_now_msec = self.clock.time_msec()
|
||||||
|
|
||||||
# create the pusher setting last_stream_ordering to the current maximum
|
# create the pusher setting last_stream_ordering to the current maximum
|
||||||
|
@ -175,9 +191,6 @@ class PusherPool:
|
||||||
user_id: user to remove pushers for
|
user_id: user to remove pushers for
|
||||||
access_tokens: access token *ids* to remove pushers for
|
access_tokens: access token *ids* to remove pushers for
|
||||||
"""
|
"""
|
||||||
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
|
|
||||||
return
|
|
||||||
|
|
||||||
tokens = set(access_tokens)
|
tokens = set(access_tokens)
|
||||||
for p in await self.store.get_pushers_by_user_id(user_id):
|
for p in await self.store.get_pushers_by_user_id(user_id):
|
||||||
if p.access_token in tokens:
|
if p.access_token in tokens:
|
||||||
|
@ -380,6 +393,12 @@ class PusherPool:
|
||||||
|
|
||||||
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
|
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
|
||||||
|
|
||||||
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
# We can only delete pushers on master.
|
||||||
app_id, pushkey, user_id
|
if self._remove_pusher_client:
|
||||||
)
|
await self._remove_pusher_client(
|
||||||
|
app_id=app_id, pushkey=pushkey, user_id=user_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||||
|
app_id, pushkey, user_id
|
||||||
|
)
|
||||||
|
|
|
@ -106,6 +106,9 @@ CONDITIONAL_REQUIREMENTS = {
|
||||||
"pysaml2>=4.5.0;python_version>='3.6'",
|
"pysaml2>=4.5.0;python_version>='3.6'",
|
||||||
],
|
],
|
||||||
"oidc": ["authlib>=0.14.0"],
|
"oidc": ["authlib>=0.14.0"],
|
||||||
|
# systemd-python is necessary for logging to the systemd journal via
|
||||||
|
# `systemd.journal.JournalHandler`, as is documented in
|
||||||
|
# `contrib/systemd/log_config.yaml`.
|
||||||
"systemd": ["systemd-python>=231"],
|
"systemd": ["systemd-python>=231"],
|
||||||
"url_preview": ["lxml>=3.5.0"],
|
"url_preview": ["lxml>=3.5.0"],
|
||||||
"sentry": ["sentry-sdk>=0.7.2"],
|
"sentry": ["sentry-sdk>=0.7.2"],
|
||||||
|
|
|
@ -21,6 +21,7 @@ from synapse.replication.http import (
|
||||||
login,
|
login,
|
||||||
membership,
|
membership,
|
||||||
presence,
|
presence,
|
||||||
|
push,
|
||||||
register,
|
register,
|
||||||
send_event,
|
send_event,
|
||||||
streams,
|
streams,
|
||||||
|
@ -42,6 +43,7 @@ class ReplicationRestResource(JsonResource):
|
||||||
membership.register_servlets(hs, self)
|
membership.register_servlets(hs, self)
|
||||||
streams.register_servlets(hs, self)
|
streams.register_servlets(hs, self)
|
||||||
account_data.register_servlets(hs, self)
|
account_data.register_servlets(hs, self)
|
||||||
|
push.register_servlets(hs, self)
|
||||||
|
|
||||||
# The following can't currently be instantiated on workers.
|
# The following can't currently be instantiated on workers.
|
||||||
if hs.config.worker.worker_app is None:
|
if hs.config.worker.worker_app is None:
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
|
||||||
|
"""Deletes the given pusher.
|
||||||
|
|
||||||
|
Request format:
|
||||||
|
|
||||||
|
POST /_synapse/replication/remove_pusher/:user_id
|
||||||
|
|
||||||
|
{
|
||||||
|
"app_id": "<some_id>",
|
||||||
|
"pushkey": "<some_key>"
|
||||||
|
}
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
NAME = "add_user_account_data"
|
||||||
|
PATH_ARGS = ("user_id",)
|
||||||
|
CACHE = False
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__(hs)
|
||||||
|
|
||||||
|
self.pusher_pool = hs.get_pusherpool()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _serialize_payload(app_id, pushkey, user_id):
|
||||||
|
payload = {
|
||||||
|
"app_id": app_id,
|
||||||
|
"pushkey": pushkey,
|
||||||
|
}
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
async def _handle_request(self, request, user_id):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
app_id = content["app_id"]
|
||||||
|
pushkey = content["pushkey"]
|
||||||
|
|
||||||
|
await self.pusher_pool.remove_pusher(app_id, pushkey, user_id)
|
||||||
|
|
||||||
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
ReplicationRemovePusherRestServlet(hs).register(http_server)
|
|
@ -325,31 +325,6 @@ class FederationAckCommand(Command):
|
||||||
return "%s %s" % (self.instance_name, self.token)
|
return "%s %s" % (self.instance_name, self.token)
|
||||||
|
|
||||||
|
|
||||||
class RemovePusherCommand(Command):
|
|
||||||
"""Sent by the client to request the master remove the given pusher.
|
|
||||||
|
|
||||||
Format::
|
|
||||||
|
|
||||||
REMOVE_PUSHER <app_id> <push_key> <user_id>
|
|
||||||
"""
|
|
||||||
|
|
||||||
NAME = "REMOVE_PUSHER"
|
|
||||||
|
|
||||||
def __init__(self, app_id, push_key, user_id):
|
|
||||||
self.user_id = user_id
|
|
||||||
self.app_id = app_id
|
|
||||||
self.push_key = push_key
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_line(cls, line):
|
|
||||||
app_id, push_key, user_id = line.split(" ", 2)
|
|
||||||
|
|
||||||
return cls(app_id, push_key, user_id)
|
|
||||||
|
|
||||||
def to_line(self):
|
|
||||||
return " ".join((self.app_id, self.push_key, self.user_id))
|
|
||||||
|
|
||||||
|
|
||||||
class UserIpCommand(Command):
|
class UserIpCommand(Command):
|
||||||
"""Sent periodically when a worker sees activity from a client.
|
"""Sent periodically when a worker sees activity from a client.
|
||||||
|
|
||||||
|
@ -416,7 +391,6 @@ _COMMANDS = (
|
||||||
ReplicateCommand,
|
ReplicateCommand,
|
||||||
UserSyncCommand,
|
UserSyncCommand,
|
||||||
FederationAckCommand,
|
FederationAckCommand,
|
||||||
RemovePusherCommand,
|
|
||||||
UserIpCommand,
|
UserIpCommand,
|
||||||
RemoteServerUpCommand,
|
RemoteServerUpCommand,
|
||||||
ClearUserSyncsCommand,
|
ClearUserSyncsCommand,
|
||||||
|
@ -443,7 +417,6 @@ VALID_CLIENT_COMMANDS = (
|
||||||
UserSyncCommand.NAME,
|
UserSyncCommand.NAME,
|
||||||
ClearUserSyncsCommand.NAME,
|
ClearUserSyncsCommand.NAME,
|
||||||
FederationAckCommand.NAME,
|
FederationAckCommand.NAME,
|
||||||
RemovePusherCommand.NAME,
|
|
||||||
UserIpCommand.NAME,
|
UserIpCommand.NAME,
|
||||||
ErrorCommand.NAME,
|
ErrorCommand.NAME,
|
||||||
RemoteServerUpCommand.NAME,
|
RemoteServerUpCommand.NAME,
|
||||||
|
|
|
@ -44,7 +44,6 @@ from synapse.replication.tcp.commands import (
|
||||||
PositionCommand,
|
PositionCommand,
|
||||||
RdataCommand,
|
RdataCommand,
|
||||||
RemoteServerUpCommand,
|
RemoteServerUpCommand,
|
||||||
RemovePusherCommand,
|
|
||||||
ReplicateCommand,
|
ReplicateCommand,
|
||||||
UserIpCommand,
|
UserIpCommand,
|
||||||
UserSyncCommand,
|
UserSyncCommand,
|
||||||
|
@ -373,23 +372,6 @@ class ReplicationCommandHandler:
|
||||||
if self._federation_sender:
|
if self._federation_sender:
|
||||||
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
|
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
|
||||||
|
|
||||||
def on_REMOVE_PUSHER(
|
|
||||||
self, conn: AbstractConnection, cmd: RemovePusherCommand
|
|
||||||
) -> Optional[Awaitable[None]]:
|
|
||||||
remove_pusher_counter.inc()
|
|
||||||
|
|
||||||
if self._is_master:
|
|
||||||
return self._handle_remove_pusher(cmd)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
|
|
||||||
await self._store.delete_pusher_by_app_id_pushkey_user_id(
|
|
||||||
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
self._notifier.on_new_replication_data()
|
|
||||||
|
|
||||||
def on_USER_IP(
|
def on_USER_IP(
|
||||||
self, conn: AbstractConnection, cmd: UserIpCommand
|
self, conn: AbstractConnection, cmd: UserIpCommand
|
||||||
) -> Optional[Awaitable[None]]:
|
) -> Optional[Awaitable[None]]:
|
||||||
|
@ -684,11 +666,6 @@ class ReplicationCommandHandler:
|
||||||
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
|
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
|
|
||||||
"""Poke the master to remove a pusher for a user"""
|
|
||||||
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
|
||||||
self.send_command(cmd)
|
|
||||||
|
|
||||||
def send_user_ip(
|
def send_user_ip(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|
|
@ -35,6 +35,7 @@ from synapse.rest.admin._base import (
|
||||||
assert_user_is_admin,
|
assert_user_is_admin,
|
||||||
)
|
)
|
||||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||||
|
from synapse.storage.databases.main.media_repository import MediaSortOrder
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -832,8 +833,33 @@ class UserMediaRestServlet(RestServlet):
|
||||||
errcode=Codes.INVALID_PARAM,
|
errcode=Codes.INVALID_PARAM,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If neither `order_by` nor `dir` is set, set the default order
|
||||||
|
# to newest media is on top for backward compatibility.
|
||||||
|
if b"order_by" not in request.args and b"dir" not in request.args:
|
||||||
|
order_by = MediaSortOrder.CREATED_TS.value
|
||||||
|
direction = "b"
|
||||||
|
else:
|
||||||
|
order_by = parse_string(
|
||||||
|
request,
|
||||||
|
"order_by",
|
||||||
|
default=MediaSortOrder.CREATED_TS.value,
|
||||||
|
allowed_values=(
|
||||||
|
MediaSortOrder.MEDIA_ID.value,
|
||||||
|
MediaSortOrder.UPLOAD_NAME.value,
|
||||||
|
MediaSortOrder.CREATED_TS.value,
|
||||||
|
MediaSortOrder.LAST_ACCESS_TS.value,
|
||||||
|
MediaSortOrder.MEDIA_LENGTH.value,
|
||||||
|
MediaSortOrder.MEDIA_TYPE.value,
|
||||||
|
MediaSortOrder.QUARANTINED_BY.value,
|
||||||
|
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
direction = parse_string(
|
||||||
|
request, "dir", default="f", allowed_values=("f", "b")
|
||||||
|
)
|
||||||
|
|
||||||
media, total = await self.store.get_local_media_by_user_paginate(
|
media, total = await self.store.get_local_media_by_user_paginate(
|
||||||
start, limit, user_id
|
start, limit, user_id, order_by, direction
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = {"media": media, "total": total}
|
ret = {"media": media, "total": total}
|
||||||
|
|
|
@ -20,6 +20,7 @@ from synapse.api.errors import Codes, LoginError, SynapseError
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.handlers.sso import SsoIdentityProvider
|
from synapse.handlers.sso import SsoIdentityProvider
|
||||||
|
from synapse.http import get_request_uri
|
||||||
from synapse.http.server import HttpServer, finish_request
|
from synapse.http.server import HttpServer, finish_request
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet,
|
RestServlet,
|
||||||
|
@ -354,6 +355,7 @@ class SsoRedirectServlet(RestServlet):
|
||||||
hs.get_oidc_handler()
|
hs.get_oidc_handler()
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
|
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
|
||||||
|
self._public_baseurl = hs.config.public_baseurl
|
||||||
|
|
||||||
def register(self, http_server: HttpServer) -> None:
|
def register(self, http_server: HttpServer) -> None:
|
||||||
super().register(http_server)
|
super().register(http_server)
|
||||||
|
@ -373,6 +375,32 @@ class SsoRedirectServlet(RestServlet):
|
||||||
async def on_GET(
|
async def on_GET(
|
||||||
self, request: SynapseRequest, idp_id: Optional[str] = None
|
self, request: SynapseRequest, idp_id: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if not self._public_baseurl:
|
||||||
|
raise SynapseError(400, "SSO requires a valid public_baseurl")
|
||||||
|
|
||||||
|
# if this isn't the expected hostname, redirect to the right one, so that we
|
||||||
|
# get our cookies back.
|
||||||
|
requested_uri = get_request_uri(request)
|
||||||
|
baseurl_bytes = self._public_baseurl.encode("utf-8")
|
||||||
|
if not requested_uri.startswith(baseurl_bytes):
|
||||||
|
# swap out the incorrect base URL for the right one.
|
||||||
|
#
|
||||||
|
# The idea here is to redirect from
|
||||||
|
# https://foo.bar/whatever/_matrix/...
|
||||||
|
# to
|
||||||
|
# https://public.baseurl/_matrix/...
|
||||||
|
#
|
||||||
|
i = requested_uri.index(b"/_matrix")
|
||||||
|
new_uri = baseurl_bytes[:-1] + requested_uri[i:]
|
||||||
|
logger.info(
|
||||||
|
"Requested URI %s is not canonical: redirecting to %s",
|
||||||
|
requested_uri.decode("utf-8", errors="replace"),
|
||||||
|
new_uri.decode("utf-8", errors="replace"),
|
||||||
|
)
|
||||||
|
request.redirect(new_uri)
|
||||||
|
finish_request(request)
|
||||||
|
return
|
||||||
|
|
||||||
client_redirect_url = parse_string(
|
client_redirect_url = parse_string(
|
||||||
request, "redirectUrl", required=True, encoding=None
|
request, "redirectUrl", required=True, encoding=None
|
||||||
)
|
)
|
||||||
|
|
|
@ -54,7 +54,12 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc
|
||||||
if hs.config.saml2_enabled:
|
if hs.config.saml2_enabled:
|
||||||
from synapse.rest.synapse.client.saml2 import SAML2Resource
|
from synapse.rest.synapse.client.saml2 import SAML2Resource
|
||||||
|
|
||||||
resources["/_synapse/client/saml2"] = SAML2Resource(hs)
|
res = SAML2Resource(hs)
|
||||||
|
resources["/_synapse/client/saml2"] = res
|
||||||
|
|
||||||
|
# This is also mounted under '/_matrix' for backwards-compatibility.
|
||||||
|
# To be removed in Synapse v1.32.0.
|
||||||
|
resources["/_matrix/saml2"] = res
|
||||||
|
|
||||||
return resources
|
return resources
|
||||||
|
|
||||||
|
|
|
@ -248,7 +248,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
self.start_time = None # type: Optional[int]
|
self.start_time = None # type: Optional[int]
|
||||||
|
|
||||||
self._instance_id = random_string(5)
|
self._instance_id = random_string(5)
|
||||||
self._instance_name = config.worker_name or "master"
|
self._instance_name = config.worker.instance_name
|
||||||
|
|
||||||
self.version_string = version_string
|
self.version_string = version_string
|
||||||
|
|
||||||
|
@ -758,12 +758,6 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
reconnect=True,
|
reconnect=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
|
|
||||||
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
|
||||||
|
|
||||||
def should_send_federation(self) -> bool:
|
def should_send_federation(self) -> bool:
|
||||||
"Should this server be sending federation traffic directly?"
|
"Should this server be sending federation traffic directly?"
|
||||||
return self.config.send_federation and (
|
return self.config.send_federation
|
||||||
not self.config.worker_app
|
|
||||||
or self.config.worker_app == "synapse.app.federation_sender"
|
|
||||||
)
|
|
||||||
|
|
|
@ -49,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.background_updates import BackgroundUpdater
|
from synapse.storage.background_updates import BackgroundUpdater
|
||||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||||
from synapse.storage.types import Connection, Cursor
|
from synapse.storage.types import Connection, Cursor
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
|
||||||
from synapse.types import Collection
|
from synapse.types import Collection
|
||||||
|
|
||||||
# python 3 does not have a maximum int value
|
# python 3 does not have a maximum int value
|
||||||
|
@ -381,7 +380,10 @@ class DatabasePool:
|
||||||
_TXN_ID = 0
|
_TXN_ID = 0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
self,
|
||||||
|
hs,
|
||||||
|
database_config: DatabaseConnectionConfig,
|
||||||
|
engine: BaseDatabaseEngine,
|
||||||
):
|
):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
|
@ -420,16 +422,6 @@ class DatabasePool:
|
||||||
self._check_safe_to_upsert,
|
self._check_safe_to_upsert,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We define this sequence here so that it can be referenced from both
|
|
||||||
# the DataStore and PersistEventStore.
|
|
||||||
def get_chain_id_txn(txn):
|
|
||||||
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
|
|
||||||
return txn.fetchone()[0]
|
|
||||||
|
|
||||||
self.event_chain_id_gen = build_sequence_generator(
|
|
||||||
engine, get_chain_id_txn, "event_auth_chain_id"
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
"""Is the database pool currently running"""
|
"""Is the database pool currently running"""
|
||||||
return self._db_pool.running
|
return self._db_pool.running
|
||||||
|
|
|
@ -79,7 +79,7 @@ class Databases:
|
||||||
# If we're on a process that can persist events also
|
# If we're on a process that can persist events also
|
||||||
# instantiate a `PersistEventsStore`
|
# instantiate a `PersistEventsStore`
|
||||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||||
persist_events = PersistEventsStore(hs, database, main)
|
persist_events = PersistEventsStore(hs, database, main, db_conn)
|
||||||
|
|
||||||
if "state" in database_config.databases:
|
if "state" in database_config.databases:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -42,7 +42,9 @@ from synapse.logging.utils import log_function
|
||||||
from synapse.storage._base import db_to_json, make_in_list_sql_clause
|
from synapse.storage._base import db_to_json, make_in_list_sql_clause
|
||||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||||
from synapse.storage.databases.main.search import SearchEntry
|
from synapse.storage.databases.main.search import SearchEntry
|
||||||
|
from synapse.storage.types import Connection
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
|
from synapse.storage.util.sequence import SequenceGenerator
|
||||||
from synapse.types import StateMap, get_domain_from_id
|
from synapse.types import StateMap, get_domain_from_id
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.iterutils import batch_iter, sorted_topologically
|
from synapse.util.iterutils import batch_iter, sorted_topologically
|
||||||
|
@ -90,7 +92,11 @@ class PersistEventsStore:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore"
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
db: DatabasePool,
|
||||||
|
main_data_store: "DataStore",
|
||||||
|
db_conn: Connection,
|
||||||
):
|
):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.db_pool = db
|
self.db_pool = db
|
||||||
|
@ -474,6 +480,7 @@ class PersistEventsStore:
|
||||||
self._add_chain_cover_index(
|
self._add_chain_cover_index(
|
||||||
txn,
|
txn,
|
||||||
self.db_pool,
|
self.db_pool,
|
||||||
|
self.store.event_chain_id_gen,
|
||||||
event_to_room_id,
|
event_to_room_id,
|
||||||
event_to_types,
|
event_to_types,
|
||||||
event_to_auth_chain,
|
event_to_auth_chain,
|
||||||
|
@ -484,6 +491,7 @@ class PersistEventsStore:
|
||||||
cls,
|
cls,
|
||||||
txn,
|
txn,
|
||||||
db_pool: DatabasePool,
|
db_pool: DatabasePool,
|
||||||
|
event_chain_id_gen: SequenceGenerator,
|
||||||
event_to_room_id: Dict[str, str],
|
event_to_room_id: Dict[str, str],
|
||||||
event_to_types: Dict[str, Tuple[str, str]],
|
event_to_types: Dict[str, Tuple[str, str]],
|
||||||
event_to_auth_chain: Dict[str, List[str]],
|
event_to_auth_chain: Dict[str, List[str]],
|
||||||
|
@ -630,6 +638,7 @@ class PersistEventsStore:
|
||||||
new_chain_tuples = cls._allocate_chain_ids(
|
new_chain_tuples = cls._allocate_chain_ids(
|
||||||
txn,
|
txn,
|
||||||
db_pool,
|
db_pool,
|
||||||
|
event_chain_id_gen,
|
||||||
event_to_room_id,
|
event_to_room_id,
|
||||||
event_to_types,
|
event_to_types,
|
||||||
event_to_auth_chain,
|
event_to_auth_chain,
|
||||||
|
@ -768,6 +777,7 @@ class PersistEventsStore:
|
||||||
def _allocate_chain_ids(
|
def _allocate_chain_ids(
|
||||||
txn,
|
txn,
|
||||||
db_pool: DatabasePool,
|
db_pool: DatabasePool,
|
||||||
|
event_chain_id_gen: SequenceGenerator,
|
||||||
event_to_room_id: Dict[str, str],
|
event_to_room_id: Dict[str, str],
|
||||||
event_to_types: Dict[str, Tuple[str, str]],
|
event_to_types: Dict[str, Tuple[str, str]],
|
||||||
event_to_auth_chain: Dict[str, List[str]],
|
event_to_auth_chain: Dict[str, List[str]],
|
||||||
|
@ -880,7 +890,7 @@ class PersistEventsStore:
|
||||||
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
|
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
|
||||||
|
|
||||||
# Generate new chain IDs for all unallocated chain IDs.
|
# Generate new chain IDs for all unallocated chain IDs.
|
||||||
newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
|
newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
|
||||||
txn, len(unallocated_chain_ids)
|
txn, len(unallocated_chain_ids)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||||
PersistEventsStore._add_chain_cover_index(
|
PersistEventsStore._add_chain_cover_index(
|
||||||
txn,
|
txn,
|
||||||
self.db_pool,
|
self.db_pool,
|
||||||
|
self.event_chain_id_gen,
|
||||||
event_to_room_id,
|
event_to_room_id,
|
||||||
event_to_types,
|
event_to_types,
|
||||||
event_to_auth_chain,
|
event_to_auth_chain,
|
||||||
|
|
|
@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||||
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import Collection, JsonDict, get_domain_from_id
|
from synapse.types import Collection, JsonDict, get_domain_from_id
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
self._event_fetch_list = []
|
self._event_fetch_list = []
|
||||||
self._event_fetch_ongoing = 0
|
self._event_fetch_ongoing = 0
|
||||||
|
|
||||||
|
# We define this sequence here so that it can be referenced from both
|
||||||
|
# the DataStore and PersistEventStore.
|
||||||
|
def get_chain_id_txn(txn):
|
||||||
|
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
|
||||||
|
return txn.fetchone()[0]
|
||||||
|
|
||||||
|
self.event_chain_id_gen = build_sequence_generator(
|
||||||
|
db_conn,
|
||||||
|
database.engine,
|
||||||
|
get_chain_id_txn,
|
||||||
|
"event_auth_chain_id",
|
||||||
|
table="event_auth_chains",
|
||||||
|
id_column="chain_id",
|
||||||
|
)
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == EventsStream.NAME:
|
if stream_name == EventsStream.NAME:
|
||||||
self._stream_id_gen.advance(instance_name, token)
|
self._stream_id_gen.advance(instance_name, token)
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
@ -23,6 +24,22 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MediaSortOrder(Enum):
|
||||||
|
"""
|
||||||
|
Enum to define the sorting method used when returning media with
|
||||||
|
get_local_media_by_user_paginate
|
||||||
|
"""
|
||||||
|
|
||||||
|
MEDIA_ID = "media_id"
|
||||||
|
UPLOAD_NAME = "upload_name"
|
||||||
|
CREATED_TS = "created_ts"
|
||||||
|
LAST_ACCESS_TS = "last_access_ts"
|
||||||
|
MEDIA_LENGTH = "media_length"
|
||||||
|
MEDIA_TYPE = "media_type"
|
||||||
|
QUARANTINED_BY = "quarantined_by"
|
||||||
|
SAFE_FROM_QUARANTINE = "safe_from_quarantine"
|
||||||
|
|
||||||
|
|
||||||
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
|
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
@ -118,7 +135,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_local_media_by_user_paginate(
|
async def get_local_media_by_user_paginate(
|
||||||
self, start: int, limit: int, user_id: str
|
self,
|
||||||
|
start: int,
|
||||||
|
limit: int,
|
||||||
|
user_id: str,
|
||||||
|
order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value,
|
||||||
|
direction: str = "f",
|
||||||
) -> Tuple[List[Dict[str, Any]], int]:
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
"""Get a paginated list of metadata for a local piece of media
|
"""Get a paginated list of metadata for a local piece of media
|
||||||
which an user_id has uploaded
|
which an user_id has uploaded
|
||||||
|
@ -127,6 +149,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
start: offset in the list
|
start: offset in the list
|
||||||
limit: maximum amount of media_ids to retrieve
|
limit: maximum amount of media_ids to retrieve
|
||||||
user_id: fully-qualified user id
|
user_id: fully-qualified user id
|
||||||
|
order_by: the sort order of the returned list
|
||||||
|
direction: sort ascending or descending
|
||||||
Returns:
|
Returns:
|
||||||
A paginated list of all metadata of user's media,
|
A paginated list of all metadata of user's media,
|
||||||
plus the total count of all the user's media
|
plus the total count of all the user's media
|
||||||
|
@ -134,6 +158,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
|
|
||||||
def get_local_media_by_user_paginate_txn(txn):
|
def get_local_media_by_user_paginate_txn(txn):
|
||||||
|
|
||||||
|
# Set ordering
|
||||||
|
order_by_column = MediaSortOrder(order_by).value
|
||||||
|
|
||||||
|
if direction == "b":
|
||||||
|
order = "DESC"
|
||||||
|
else:
|
||||||
|
order = "ASC"
|
||||||
|
|
||||||
args = [user_id]
|
args = [user_id]
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COUNT(*) as total_media
|
SELECT COUNT(*) as total_media
|
||||||
|
@ -155,9 +187,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||||
"safe_from_quarantine"
|
"safe_from_quarantine"
|
||||||
FROM local_media_repository
|
FROM local_media_repository
|
||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
ORDER BY created_ts DESC, media_id DESC
|
ORDER BY {order_by_column} {order}, media_id ASC
|
||||||
LIMIT ? OFFSET ?
|
LIMIT ? OFFSET ?
|
||||||
"""
|
""".format(
|
||||||
|
order_by_column=order_by_column,
|
||||||
|
order=order,
|
||||||
|
)
|
||||||
|
|
||||||
args += [limit, start]
|
args += [limit, start]
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
|
|
|
@ -373,3 +373,46 @@ class PusherStore(PusherWorkerStore):
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"delete_pusher", delete_pusher_txn, stream_id
|
"delete_pusher", delete_pusher_txn, stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def delete_all_pushers_for_user(self, user_id: str) -> None:
|
||||||
|
"""Delete all pushers associated with an account."""
|
||||||
|
|
||||||
|
# We want to generate a row in `deleted_pushers` for each pusher we're
|
||||||
|
# deleting, so we fetch the list now so we can generate the appropriate
|
||||||
|
# number of stream IDs.
|
||||||
|
#
|
||||||
|
# Note: technically there could be a race here between adding/deleting
|
||||||
|
# pushers, but a) the worst case if we don't stop a pusher until the
|
||||||
|
# next restart and b) this is only called when we're deactivating an
|
||||||
|
# account.
|
||||||
|
pushers = list(await self.get_pushers_by_user_id(user_id))
|
||||||
|
|
||||||
|
def delete_pushers_txn(txn, stream_ids):
|
||||||
|
self._invalidate_cache_and_stream( # type: ignore
|
||||||
|
txn, self.get_if_user_has_pusher, (user_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db_pool.simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="pushers",
|
||||||
|
keyvalues={"user_name": user_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db_pool.simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="deleted_pushers",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"app_id": pusher.app_id,
|
||||||
|
"pushkey": pusher.pushkey,
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
for stream_id, pusher in zip(stream_ids, pushers)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async with self._pushers_id_gen.get_next_mult(len(pushers)) as stream_ids:
|
||||||
|
await self.db_pool.runInteraction(
|
||||||
|
"delete_all_pushers_for_user", delete_pushers_txn, stream_ids
|
||||||
|
)
|
||||||
|
|
|
@ -23,7 +23,7 @@ import attr
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.databases.main.stats import StatsStore
|
from synapse.storage.databases.main.stats import StatsStore
|
||||||
from synapse.storage.types import Connection, Cursor
|
from synapse.storage.types import Connection, Cursor
|
||||||
|
@ -70,7 +70,12 @@ class TokenLookupResult:
|
||||||
|
|
||||||
|
|
||||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
|
@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
# call `find_max_generated_user_id_localpart` each time, which is
|
# call `find_max_generated_user_id_localpart` each time, which is
|
||||||
# expensive if there are many entries.
|
# expensive if there are many entries.
|
||||||
self._user_id_seq = build_sequence_generator(
|
self._user_id_seq = build_sequence_generator(
|
||||||
|
db_conn,
|
||||||
database.engine,
|
database.engine,
|
||||||
find_max_generated_user_id_localpart,
|
find_max_generated_user_id_localpart,
|
||||||
"user_id_seq",
|
"user_id_seq",
|
||||||
|
table=None,
|
||||||
|
id_column=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._account_validity = hs.config.account_validity
|
self._account_validity = hs.config.account_validity
|
||||||
|
@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
|
|
||||||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
-- We may not have deleted all pushers for deactivated accounts. Do so now.
|
||||||
|
--
|
||||||
|
-- Note: We don't bother updating the `deleted_pushers` table as it's just use
|
||||||
|
-- to stop pushers on workers, and that will happen when they get next restarted.
|
||||||
|
DELETE FROM pushers WHERE user_name IN (SELECT name FROM users WHERE deactivated = 1);
|
|
@ -0,0 +1,19 @@
|
||||||
|
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
-- Delete all pushers associated with deleted devices. This is to clear up after
|
||||||
|
-- a bug where they weren't correctly deleted when using workers.
|
||||||
|
DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens);
|
|
@ -497,8 +497,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
async def add_users_in_public_rooms(
|
async def add_users_in_public_rooms(
|
||||||
self, room_id: str, user_ids: Iterable[str]
|
self, room_id: str, user_ids: Iterable[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Insert entries into the users_who_share_private_rooms table. The first
|
"""Insert entries into the users_in_public_rooms table.
|
||||||
user should be a local user.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id
|
room_id
|
||||||
|
@ -670,7 +669,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||||
users.update(rows)
|
users.update(rows)
|
||||||
return list(users)
|
return list(users)
|
||||||
|
|
||||||
@cached()
|
|
||||||
async def get_shared_rooms_for_users(
|
async def get_shared_rooms_for_users(
|
||||||
self, user_id: str, other_user_id: str
|
self, user_id: str, other_user_id: str
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
|
|
|
@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
return txn.fetchone()[0]
|
return txn.fetchone()[0]
|
||||||
|
|
||||||
self._state_group_seq_gen = build_sequence_generator(
|
self._state_group_seq_gen = build_sequence_generator(
|
||||||
self.database_engine, get_max_state_group_txn, "state_group_id_seq"
|
db_conn,
|
||||||
)
|
self.database_engine,
|
||||||
self._state_group_seq_gen.check_consistency(
|
get_max_state_group_txn,
|
||||||
db_conn, table="state_groups", id_column="id"
|
"state_group_id_seq",
|
||||||
|
table="state_groups",
|
||||||
|
id_column="id",
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(max_entries=10000, iterable=True)
|
@cached(max_entries=10000, iterable=True)
|
||||||
|
|
|
@ -251,9 +251,14 @@ class LocalSequenceGenerator(SequenceGenerator):
|
||||||
|
|
||||||
|
|
||||||
def build_sequence_generator(
|
def build_sequence_generator(
|
||||||
|
db_conn: "LoggingDatabaseConnection",
|
||||||
database_engine: BaseDatabaseEngine,
|
database_engine: BaseDatabaseEngine,
|
||||||
get_first_callback: GetFirstCallbackType,
|
get_first_callback: GetFirstCallbackType,
|
||||||
sequence_name: str,
|
sequence_name: str,
|
||||||
|
table: Optional[str],
|
||||||
|
id_column: Optional[str],
|
||||||
|
stream_name: Optional[str] = None,
|
||||||
|
positive: bool = True,
|
||||||
) -> SequenceGenerator:
|
) -> SequenceGenerator:
|
||||||
"""Get the best impl of SequenceGenerator available
|
"""Get the best impl of SequenceGenerator available
|
||||||
|
|
||||||
|
@ -265,8 +270,23 @@ def build_sequence_generator(
|
||||||
get_first_callback: a callback which gets the next sequence ID. Used if
|
get_first_callback: a callback which gets the next sequence ID. Used if
|
||||||
we're on sqlite.
|
we're on sqlite.
|
||||||
sequence_name: the name of a postgres sequence to use.
|
sequence_name: the name of a postgres sequence to use.
|
||||||
|
table, id_column, stream_name, positive: If set then `check_consistency`
|
||||||
|
is called on the created sequence. See docstring for
|
||||||
|
`check_consistency` details.
|
||||||
"""
|
"""
|
||||||
if isinstance(database_engine, PostgresEngine):
|
if isinstance(database_engine, PostgresEngine):
|
||||||
return PostgresSequenceGenerator(sequence_name)
|
seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator
|
||||||
else:
|
else:
|
||||||
return LocalSequenceGenerator(get_first_callback)
|
seq = LocalSequenceGenerator(get_first_callback)
|
||||||
|
|
||||||
|
if table:
|
||||||
|
assert id_column
|
||||||
|
seq.check_consistency(
|
||||||
|
db_conn=db_conn,
|
||||||
|
table=table,
|
||||||
|
id_column=id_column,
|
||||||
|
stream_name=stream_name,
|
||||||
|
positive=positive,
|
||||||
|
)
|
||||||
|
|
||||||
|
return seq
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Set, TypeVar
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -40,6 +40,7 @@ class ResponseCache(Generic[T]):
|
||||||
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
|
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
|
||||||
# Requests that haven't finished yet.
|
# Requests that haven't finished yet.
|
||||||
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
|
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
|
||||||
|
self.pending_conditionals = {} # type: Dict[T, Set[Callable[[Any], bool]]]
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.timeout_sec = timeout_ms / 1000.0
|
self.timeout_sec = timeout_ms / 1000.0
|
||||||
|
@ -101,7 +102,11 @@ class ResponseCache(Generic[T]):
|
||||||
self.pending_result_cache[key] = result
|
self.pending_result_cache[key] = result
|
||||||
|
|
||||||
def remove(r):
|
def remove(r):
|
||||||
if self.timeout_sec:
|
should_cache = all(
|
||||||
|
func(r) for func in self.pending_conditionals.pop(key, [])
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.timeout_sec and should_cache:
|
||||||
self.clock.call_later(
|
self.clock.call_later(
|
||||||
self.timeout_sec, self.pending_result_cache.pop, key, None
|
self.timeout_sec, self.pending_result_cache.pop, key, None
|
||||||
)
|
)
|
||||||
|
@ -112,6 +117,31 @@ class ResponseCache(Generic[T]):
|
||||||
result.addBoth(remove)
|
result.addBoth(remove)
|
||||||
return result.observe()
|
return result.observe()
|
||||||
|
|
||||||
|
def add_conditional(self, key: T, conditional: Callable[[Any], bool]):
|
||||||
|
self.pending_conditionals.setdefault(key, set()).add(conditional)
|
||||||
|
|
||||||
|
def wrap_conditional(
|
||||||
|
self,
|
||||||
|
key: T,
|
||||||
|
should_cache: Callable[[Any], bool],
|
||||||
|
callback: "Callable[..., Any]",
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> defer.Deferred:
|
||||||
|
"""The same as wrap(), but adds a conditional to the final execution.
|
||||||
|
|
||||||
|
When the final execution completes, *all* conditionals need to return True for it to properly cache,
|
||||||
|
else it'll not be cached in a timed fashion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# See if there's already a result on this key that hasn't yet completed. Due to the single-threaded nature of
|
||||||
|
# python, adding a key immediately in the same execution thread will not cause a race condition.
|
||||||
|
result = self.get(key)
|
||||||
|
if not result or isinstance(result, defer.Deferred) and not result.called:
|
||||||
|
self.add_conditional(key, should_cache)
|
||||||
|
|
||||||
|
return self.wrap(key, callback, *args, **kwargs)
|
||||||
|
|
||||||
def wrap(
|
def wrap(
|
||||||
self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
|
self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
|
||||||
) -> defer.Deferred:
|
) -> defer.Deferred:
|
||||||
|
|
|
@ -21,6 +21,7 @@ import pkg_resources
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.rest.client.v1 import login, room
|
from synapse.rest.client.v1 import login, room
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
@ -100,12 +101,19 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(self.access_token)
|
self.hs.get_datastore().get_user_by_access_token(self.access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple.token_id
|
self.token_id = user_tuple.token_id
|
||||||
|
|
||||||
|
# We need to add email to account before we can create a pusher.
|
||||||
|
self.get_success(
|
||||||
|
hs.get_datastore().user_add_threepid(
|
||||||
|
self.user_id, "email", "a@example.com", 0, 0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.pusher = self.get_success(
|
self.pusher = self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
access_token=token_id,
|
access_token=self.token_id,
|
||||||
kind="email",
|
kind="email",
|
||||||
app_id="m.email",
|
app_id="m.email",
|
||||||
app_display_name="Email Notifications",
|
app_display_name="Email Notifications",
|
||||||
|
@ -116,6 +124,28 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_need_validated_email(self):
|
||||||
|
"""Test that we can only add an email pusher if the user has validated
|
||||||
|
their email.
|
||||||
|
"""
|
||||||
|
with self.assertRaises(SynapseError) as cm:
|
||||||
|
self.get_success_or_raise(
|
||||||
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
user_id=self.user_id,
|
||||||
|
access_token=self.token_id,
|
||||||
|
kind="email",
|
||||||
|
app_id="m.email",
|
||||||
|
app_display_name="Email Notifications",
|
||||||
|
device_display_name="b@example.com",
|
||||||
|
pushkey="b@example.com",
|
||||||
|
lang=None,
|
||||||
|
data={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(400, cm.exception.code)
|
||||||
|
self.assertEqual(Codes.THREEPID_NOT_FOUND, cm.exception.errcode)
|
||||||
|
|
||||||
def test_simple_sends_email(self):
|
def test_simple_sends_email(self):
|
||||||
# Create a simple room with two users
|
# Create a simple room with two users
|
||||||
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
||||||
|
|
|
@ -24,7 +24,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
|
||||||
# enable federation sending on the worker
|
# enable federation sending on the worker
|
||||||
config = super()._get_worker_hs_config()
|
config = super()._get_worker_hs_config()
|
||||||
# TODO: make it so we don't need both of these
|
# TODO: make it so we don't need both of these
|
||||||
config["send_federation"] = True
|
config["send_federation"] = False
|
||||||
config["worker_app"] = "synapse.app.federation_sender"
|
config["worker_app"] = "synapse.app.federation_sender"
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ class FederationAckTestCase(HomeserverTestCase):
|
||||||
def default_config(self) -> dict:
|
def default_config(self) -> dict:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["worker_app"] = "synapse.app.federation_sender"
|
config["worker_app"] = "synapse.app.federation_sender"
|
||||||
config["send_federation"] = True
|
config["send_federation"] = False
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
|
@ -49,7 +49,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
|
||||||
self.make_worker_hs(
|
self.make_worker_hs(
|
||||||
"synapse.app.federation_sender",
|
"synapse.app.federation_sender",
|
||||||
{"send_federation": True},
|
{"send_federation": False},
|
||||||
federation_http_client=mock_client,
|
federation_http_client=mock_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -95,7 +95,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
|
||||||
self.make_worker_hs(
|
self.make_worker_hs(
|
||||||
"synapse.app.pusher",
|
"synapse.app.pusher",
|
||||||
{"start_pushers": True},
|
{"start_pushers": False},
|
||||||
proxied_blacklisted_http_client=http_client_mock,
|
proxied_blacklisted_http_client=http_client_mock,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ import hmac
|
||||||
import json
|
import json
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from binascii import unhexlify
|
from binascii import unhexlify
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import devices, sync
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.server import FakeSite, make_request
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
|
|
||||||
|
@ -1954,6 +1955,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
self.media_repo = hs.get_media_repository_resource()
|
self.media_repo = hs.get_media_repository_resource()
|
||||||
|
|
||||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||||
|
@ -2024,7 +2026,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
number_media = 20
|
number_media = 20
|
||||||
other_user_tok = self.login("user", "pass")
|
other_user_tok = self.login("user", "pass")
|
||||||
self._create_media(other_user_tok, number_media)
|
self._create_media_for_user(other_user_tok, number_media)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
|
@ -2045,7 +2047,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
number_media = 20
|
number_media = 20
|
||||||
other_user_tok = self.login("user", "pass")
|
other_user_tok = self.login("user", "pass")
|
||||||
self._create_media(other_user_tok, number_media)
|
self._create_media_for_user(other_user_tok, number_media)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
|
@ -2066,7 +2068,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
number_media = 20
|
number_media = 20
|
||||||
other_user_tok = self.login("user", "pass")
|
other_user_tok = self.login("user", "pass")
|
||||||
self._create_media(other_user_tok, number_media)
|
self._create_media_for_user(other_user_tok, number_media)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
|
@ -2080,11 +2082,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(channel.json_body["media"]), 10)
|
self.assertEqual(len(channel.json_body["media"]), 10)
|
||||||
self._check_fields(channel.json_body["media"])
|
self._check_fields(channel.json_body["media"])
|
||||||
|
|
||||||
def test_limit_is_negative(self):
|
def test_invalid_parameter(self):
|
||||||
"""
|
"""
|
||||||
Testing that a negative limit parameter returns a 400
|
If parameters are invalid, an error is returned.
|
||||||
"""
|
"""
|
||||||
|
# unkown order_by
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.url + "?order_by=bar",
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
# invalid search order
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
self.url + "?dir=bar",
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
# negative limit
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
self.url + "?limit=-5",
|
self.url + "?limit=-5",
|
||||||
|
@ -2094,11 +2116,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||||
|
|
||||||
def test_from_is_negative(self):
|
# negative from
|
||||||
"""
|
|
||||||
Testing that a negative from parameter returns a 400
|
|
||||||
"""
|
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
self.url + "?from=-5",
|
self.url + "?from=-5",
|
||||||
|
@ -2115,7 +2133,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
number_media = 20
|
number_media = 20
|
||||||
other_user_tok = self.login("user", "pass")
|
other_user_tok = self.login("user", "pass")
|
||||||
self._create_media(other_user_tok, number_media)
|
self._create_media_for_user(other_user_tok, number_media)
|
||||||
|
|
||||||
# `next_token` does not appear
|
# `next_token` does not appear
|
||||||
# Number of results is the number of entries
|
# Number of results is the number of entries
|
||||||
|
@ -2193,7 +2211,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
number_media = 5
|
number_media = 5
|
||||||
other_user_tok = self.login("user", "pass")
|
other_user_tok = self.login("user", "pass")
|
||||||
self._create_media(other_user_tok, number_media)
|
self._create_media_for_user(other_user_tok, number_media)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
|
@ -2207,11 +2225,118 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertNotIn("next_token", channel.json_body)
|
self.assertNotIn("next_token", channel.json_body)
|
||||||
self._check_fields(channel.json_body["media"])
|
self._check_fields(channel.json_body["media"])
|
||||||
|
|
||||||
def _create_media(self, user_token, number_media):
|
def test_order_by(self):
|
||||||
|
"""
|
||||||
|
Testing order list with parameter `order_by`
|
||||||
|
"""
|
||||||
|
|
||||||
|
other_user_tok = self.login("user", "pass")
|
||||||
|
|
||||||
|
# Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
|
||||||
|
image_data1 = unhexlify(
|
||||||
|
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
|
||||||
|
b"0000001f15c4890000000a49444154789c63000100000500010d"
|
||||||
|
b"0a2db40000000049454e44ae426082"
|
||||||
|
)
|
||||||
|
# Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B
|
||||||
|
image_data2 = unhexlify(
|
||||||
|
b"47494638376101000100800100000000"
|
||||||
|
b"ffffff2c00000000010001000002024c"
|
||||||
|
b"01003b"
|
||||||
|
)
|
||||||
|
# Resolution: 1×1, MIME type: image/bmp, Extension: bmp, Size: 54 B
|
||||||
|
image_data3 = unhexlify(
|
||||||
|
b"424d3a0000000000000036000000280000000100000001000000"
|
||||||
|
b"0100180000000000040000000000000000000000000000000000"
|
||||||
|
b"0000"
|
||||||
|
)
|
||||||
|
|
||||||
|
# create media and make sure they do not have the same timestamp
|
||||||
|
media1 = self._create_media_and_access(other_user_tok, image_data1, "image.png")
|
||||||
|
self.pump(1.0)
|
||||||
|
media2 = self._create_media_and_access(other_user_tok, image_data2, "image.gif")
|
||||||
|
self.pump(1.0)
|
||||||
|
media3 = self._create_media_and_access(other_user_tok, image_data3, "image.bmp")
|
||||||
|
self.pump(1.0)
|
||||||
|
|
||||||
|
# Mark one media as safe from quarantine.
|
||||||
|
self.get_success(self.store.mark_local_media_as_safe(media2))
|
||||||
|
# Quarantine one media
|
||||||
|
self.get_success(
|
||||||
|
self.store.quarantine_media_by_id("test", media3, self.admin_user)
|
||||||
|
)
|
||||||
|
|
||||||
|
# order by default ("created_ts")
|
||||||
|
# default is backwards
|
||||||
|
self._order_test([media3, media2, media1], None)
|
||||||
|
self._order_test([media1, media2, media3], None, "f")
|
||||||
|
self._order_test([media3, media2, media1], None, "b")
|
||||||
|
|
||||||
|
# sort by media_id
|
||||||
|
sorted_media = sorted([media1, media2, media3], reverse=False)
|
||||||
|
sorted_media_reverse = sorted(sorted_media, reverse=True)
|
||||||
|
|
||||||
|
# order by media_id
|
||||||
|
self._order_test(sorted_media, "media_id")
|
||||||
|
self._order_test(sorted_media, "media_id", "f")
|
||||||
|
self._order_test(sorted_media_reverse, "media_id", "b")
|
||||||
|
|
||||||
|
# order by upload_name
|
||||||
|
self._order_test([media3, media2, media1], "upload_name")
|
||||||
|
self._order_test([media3, media2, media1], "upload_name", "f")
|
||||||
|
self._order_test([media1, media2, media3], "upload_name", "b")
|
||||||
|
|
||||||
|
# order by media_type
|
||||||
|
# result is ordered by media_id
|
||||||
|
# because of uploaded media_type is always 'application/json'
|
||||||
|
self._order_test(sorted_media, "media_type")
|
||||||
|
self._order_test(sorted_media, "media_type", "f")
|
||||||
|
self._order_test(sorted_media, "media_type", "b")
|
||||||
|
|
||||||
|
# order by media_length
|
||||||
|
self._order_test([media2, media3, media1], "media_length")
|
||||||
|
self._order_test([media2, media3, media1], "media_length", "f")
|
||||||
|
self._order_test([media1, media3, media2], "media_length", "b")
|
||||||
|
|
||||||
|
# order by created_ts
|
||||||
|
self._order_test([media1, media2, media3], "created_ts")
|
||||||
|
self._order_test([media1, media2, media3], "created_ts", "f")
|
||||||
|
self._order_test([media3, media2, media1], "created_ts", "b")
|
||||||
|
|
||||||
|
# order by last_access_ts
|
||||||
|
self._order_test([media1, media2, media3], "last_access_ts")
|
||||||
|
self._order_test([media1, media2, media3], "last_access_ts", "f")
|
||||||
|
self._order_test([media3, media2, media1], "last_access_ts", "b")
|
||||||
|
|
||||||
|
# order by quarantined_by
|
||||||
|
# one media is in quarantine, others are ordered by media_ids
|
||||||
|
|
||||||
|
# Different sort order of SQlite and PostreSQL
|
||||||
|
# If a media is not in quarantine `quarantined_by` is NULL
|
||||||
|
# SQLite considers NULL to be smaller than any other value.
|
||||||
|
# PostreSQL considers NULL to be larger than any other value.
|
||||||
|
|
||||||
|
# self._order_test(sorted([media1, media2]) + [media3], "quarantined_by")
|
||||||
|
# self._order_test(sorted([media1, media2]) + [media3], "quarantined_by", "f")
|
||||||
|
# self._order_test([media3] + sorted([media1, media2]), "quarantined_by", "b")
|
||||||
|
|
||||||
|
# order by safe_from_quarantine
|
||||||
|
# one media is safe from quarantine, others are ordered by media_ids
|
||||||
|
self._order_test(sorted([media1, media3]) + [media2], "safe_from_quarantine")
|
||||||
|
self._order_test(
|
||||||
|
sorted([media1, media3]) + [media2], "safe_from_quarantine", "f"
|
||||||
|
)
|
||||||
|
self._order_test(
|
||||||
|
[media2] + sorted([media1, media3]), "safe_from_quarantine", "b"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_media_for_user(self, user_token: str, number_media: int):
|
||||||
"""
|
"""
|
||||||
Create a number of media for a specific user
|
Create a number of media for a specific user
|
||||||
|
Args:
|
||||||
|
user_token: Access token of the user
|
||||||
|
number_media: Number of media to be created for the user
|
||||||
"""
|
"""
|
||||||
upload_resource = self.media_repo.children[b"upload"]
|
|
||||||
for i in range(number_media):
|
for i in range(number_media):
|
||||||
# file size is 67 Byte
|
# file size is 67 Byte
|
||||||
image_data = unhexlify(
|
image_data = unhexlify(
|
||||||
|
@ -2220,13 +2345,60 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
b"0a2db40000000049454e44ae426082"
|
b"0a2db40000000049454e44ae426082"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Upload some media into the room
|
self._create_media_and_access(user_token, image_data)
|
||||||
self.helper.upload_media(
|
|
||||||
upload_resource, image_data, tok=user_token, expect_code=200
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_fields(self, content):
|
def _create_media_and_access(
|
||||||
"""Checks that all attributes are present in content"""
|
self,
|
||||||
|
user_token: str,
|
||||||
|
image_data: bytes,
|
||||||
|
filename: str = "image1.png",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Create one media for a specific user, access and returns `media_id`
|
||||||
|
Args:
|
||||||
|
user_token: Access token of the user
|
||||||
|
image_data: binary data of image
|
||||||
|
filename: The filename of the media to be uploaded
|
||||||
|
Returns:
|
||||||
|
The ID of the newly created media.
|
||||||
|
"""
|
||||||
|
upload_resource = self.media_repo.children[b"upload"]
|
||||||
|
download_resource = self.media_repo.children[b"download"]
|
||||||
|
|
||||||
|
# Upload some media into the room
|
||||||
|
response = self.helper.upload_media(
|
||||||
|
upload_resource, image_data, user_token, filename, expect_code=200
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract media ID from the response
|
||||||
|
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
|
||||||
|
media_id = server_and_media_id.split("/")[1]
|
||||||
|
|
||||||
|
# Try to access a media and to create `last_access_ts`
|
||||||
|
channel = make_request(
|
||||||
|
self.reactor,
|
||||||
|
FakeSite(download_resource),
|
||||||
|
"GET",
|
||||||
|
server_and_media_id,
|
||||||
|
shorthand=False,
|
||||||
|
access_token=user_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
200,
|
||||||
|
channel.code,
|
||||||
|
msg=(
|
||||||
|
"Expected to receive a 200 on accessing media: %s" % server_and_media_id
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return media_id
|
||||||
|
|
||||||
|
def _check_fields(self, content: JsonDict):
|
||||||
|
"""Checks that the expected user attributes are present in content
|
||||||
|
Args:
|
||||||
|
content: List that is checked for content
|
||||||
|
"""
|
||||||
for m in content:
|
for m in content:
|
||||||
self.assertIn("media_id", m)
|
self.assertIn("media_id", m)
|
||||||
self.assertIn("media_type", m)
|
self.assertIn("media_type", m)
|
||||||
|
@ -2237,6 +2409,38 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertIn("quarantined_by", m)
|
self.assertIn("quarantined_by", m)
|
||||||
self.assertIn("safe_from_quarantine", m)
|
self.assertIn("safe_from_quarantine", m)
|
||||||
|
|
||||||
|
def _order_test(
|
||||||
|
self,
|
||||||
|
expected_media_list: List[str],
|
||||||
|
order_by: Optional[str],
|
||||||
|
dir: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Request the list of media in a certain order. Assert that order is what
|
||||||
|
we expect
|
||||||
|
Args:
|
||||||
|
expected_media_list: The list of media_ids in the order we expect to get
|
||||||
|
back from the server
|
||||||
|
order_by: The type of ordering to give the server
|
||||||
|
dir: The direction of ordering to give the server
|
||||||
|
"""
|
||||||
|
|
||||||
|
url = self.url + "?"
|
||||||
|
if order_by is not None:
|
||||||
|
url += "order_by=%s&" % (order_by,)
|
||||||
|
if dir is not None and dir in ("b", "f"):
|
||||||
|
url += "dir=%s" % (dir,)
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
url.encode("ascii"),
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
|
self.assertEqual(channel.json_body["total"], len(expected_media_list))
|
||||||
|
|
||||||
|
returned_order = [row["media_id"] for row in channel.json_body["media"]]
|
||||||
|
self.assertEqual(expected_media_list, returned_order)
|
||||||
|
self._check_fields(channel.json_body["media"])
|
||||||
|
|
||||||
|
|
||||||
class UserTokenRestTestCase(unittest.HomeserverTestCase):
|
class UserTokenRestTestCase(unittest.HomeserverTestCase):
|
||||||
"""Test for /_synapse/admin/v1/users/<user>/login"""
|
"""Test for /_synapse/admin/v1/users/<user>/login"""
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
@ -47,8 +47,14 @@ except ImportError:
|
||||||
HAS_JWT = False
|
HAS_JWT = False
|
||||||
|
|
||||||
|
|
||||||
# public_base_url used in some tests
|
# synapse server name: used to populate public_baseurl in some tests
|
||||||
BASE_URL = "https://synapse/"
|
SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
|
||||||
|
|
||||||
|
# public_baseurl for some tests. It uses an http:// scheme because
|
||||||
|
# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
|
||||||
|
# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
|
||||||
|
# https://....
|
||||||
|
BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
|
||||||
|
|
||||||
# CAS server used in some tests
|
# CAS server used in some tests
|
||||||
CAS_SERVER = "https://fake.test"
|
CAS_SERVER = "https://fake.test"
|
||||||
|
@ -480,11 +486,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
def test_multi_sso_redirect(self):
|
def test_multi_sso_redirect(self):
|
||||||
"""/login/sso/redirect should redirect to an identity picker"""
|
"""/login/sso/redirect should redirect to an identity picker"""
|
||||||
# first hit the redirect url, which should redirect to our idp picker
|
# first hit the redirect url, which should redirect to our idp picker
|
||||||
channel = self.make_request(
|
channel = self._make_sso_redirect_request(False, None)
|
||||||
"GET",
|
|
||||||
"/_matrix/client/r0/login/sso/redirect?redirectUrl="
|
|
||||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
|
|
||||||
)
|
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
uri = channel.headers.getRawHeaders("Location")[0]
|
uri = channel.headers.getRawHeaders("Location")[0]
|
||||||
|
|
||||||
|
@ -628,34 +630,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_client_idp_redirect_msc2858_disabled(self):
|
def test_client_idp_redirect_msc2858_disabled(self):
|
||||||
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
|
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
|
||||||
channel = self.make_request(
|
channel = self._make_sso_redirect_request(True, "oidc")
|
||||||
"GET",
|
|
||||||
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
|
|
||||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
|
|
||||||
)
|
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
|
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc2858_enabled": True}})
|
@override_config({"experimental_features": {"msc2858_enabled": True}})
|
||||||
def test_client_idp_redirect_to_unknown(self):
|
def test_client_idp_redirect_to_unknown(self):
|
||||||
"""If the client tries to pick an unknown IdP, return a 404"""
|
"""If the client tries to pick an unknown IdP, return a 404"""
|
||||||
channel = self.make_request(
|
channel = self._make_sso_redirect_request(True, "xxx")
|
||||||
"GET",
|
|
||||||
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
|
|
||||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
|
|
||||||
)
|
|
||||||
self.assertEqual(channel.code, 404, channel.result)
|
self.assertEqual(channel.code, 404, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc2858_enabled": True}})
|
@override_config({"experimental_features": {"msc2858_enabled": True}})
|
||||||
def test_client_idp_redirect_to_oidc(self):
|
def test_client_idp_redirect_to_oidc(self):
|
||||||
"""If the client pick a known IdP, redirect to it"""
|
"""If the client pick a known IdP, redirect to it"""
|
||||||
channel = self.make_request(
|
channel = self._make_sso_redirect_request(True, "oidc")
|
||||||
"GET",
|
|
||||||
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
|
|
||||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
oidc_uri = channel.headers.getRawHeaders("Location")[0]
|
oidc_uri = channel.headers.getRawHeaders("Location")[0]
|
||||||
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
|
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
|
||||||
|
@ -663,6 +652,30 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
# it should redirect us to the auth page of the OIDC server
|
# it should redirect us to the auth page of the OIDC server
|
||||||
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
|
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
|
||||||
|
|
||||||
|
def _make_sso_redirect_request(
|
||||||
|
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""Send a request to /_matrix/client/r0/login/sso/redirect
|
||||||
|
|
||||||
|
... or the unstable equivalent
|
||||||
|
|
||||||
|
... possibly specifying an IDP provider
|
||||||
|
"""
|
||||||
|
endpoint = (
|
||||||
|
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect"
|
||||||
|
if unstable_endpoint
|
||||||
|
else "/_matrix/client/r0/login/sso/redirect"
|
||||||
|
)
|
||||||
|
if idp_prov is not None:
|
||||||
|
endpoint += "/" + idp_prov
|
||||||
|
endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
||||||
|
|
||||||
|
return self.make_request(
|
||||||
|
"GET",
|
||||||
|
endpoint,
|
||||||
|
custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)],
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
|
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
|
||||||
prefix = key + " = "
|
prefix = key + " = "
|
||||||
|
|
|
@ -542,13 +542,30 @@ class RestHelper:
|
||||||
if client_redirect_url:
|
if client_redirect_url:
|
||||||
params["redirectUrl"] = client_redirect_url
|
params["redirectUrl"] = client_redirect_url
|
||||||
|
|
||||||
# hit the redirect url (which will issue a cookie and state)
|
# hit the redirect url (which should redirect back to the redirect url. This
|
||||||
|
# is the easiest way of figuring out what the Host header ought to be set to
|
||||||
|
# to keep Synapse happy.
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.hs.get_reactor(),
|
||||||
self.site,
|
self.site,
|
||||||
"GET",
|
"GET",
|
||||||
"/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
|
"/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
|
||||||
)
|
)
|
||||||
|
assert channel.code == 302
|
||||||
|
|
||||||
|
# hit the redirect url again with the right Host header, which should now issue
|
||||||
|
# a cookie and redirect to the SSO provider.
|
||||||
|
location = channel.headers.getRawHeaders("Location")[0]
|
||||||
|
parts = urllib.parse.urlsplit(location)
|
||||||
|
channel = make_request(
|
||||||
|
self.hs.get_reactor(),
|
||||||
|
self.site,
|
||||||
|
"GET",
|
||||||
|
urllib.parse.urlunsplit(("", "") + parts[2:]),
|
||||||
|
custom_headers=[
|
||||||
|
("Host", parts[1]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
assert channel.code == 302
|
assert channel.code == 302
|
||||||
channel.extract_cookies(cookies)
|
channel.extract_cookies(cookies)
|
||||||
|
|
|
@ -161,7 +161,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self):
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["public_baseurl"] = "https://synapse.test"
|
|
||||||
|
# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
|
||||||
|
# False, so synapse will see the requested uri as http://..., so using http in
|
||||||
|
# the public_baseurl stops Synapse trying to redirect to https.
|
||||||
|
config["public_baseurl"] = "http://synapse.test"
|
||||||
|
|
||||||
if HAS_OIDC:
|
if HAS_OIDC:
|
||||||
# we enable OIDC as a way of testing SSO flows
|
# we enable OIDC as a way of testing SSO flows
|
||||||
|
|
|
@ -54,61 +54,62 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
|
||||||
A room should show up in the shared list of rooms between two users
|
A room should show up in the shared list of rooms between two users
|
||||||
if it is public.
|
if it is public.
|
||||||
"""
|
"""
|
||||||
u1 = self.register_user("user1", "pass")
|
self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
|
||||||
u1_token = self.login(u1, "pass")
|
|
||||||
u2 = self.register_user("user2", "pass")
|
|
||||||
u2_token = self.login(u2, "pass")
|
|
||||||
|
|
||||||
room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
|
|
||||||
self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
|
|
||||||
self.helper.join(room, user=u2, tok=u2_token)
|
|
||||||
|
|
||||||
channel = self._get_shared_rooms(u1_token, u2)
|
|
||||||
self.assertEquals(200, channel.code, channel.result)
|
|
||||||
self.assertEquals(len(channel.json_body["joined"]), 1)
|
|
||||||
self.assertEquals(channel.json_body["joined"][0], room)
|
|
||||||
|
|
||||||
def test_shared_room_list_private(self):
|
def test_shared_room_list_private(self):
|
||||||
"""
|
"""
|
||||||
A room should show up in the shared list of rooms between two users
|
A room should show up in the shared list of rooms between two users
|
||||||
if it is private.
|
if it is private.
|
||||||
"""
|
"""
|
||||||
u1 = self.register_user("user1", "pass")
|
self._check_shared_rooms_with(
|
||||||
u1_token = self.login(u1, "pass")
|
room_one_is_public=False, room_two_is_public=False
|
||||||
u2 = self.register_user("user2", "pass")
|
)
|
||||||
u2_token = self.login(u2, "pass")
|
|
||||||
|
|
||||||
room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
|
|
||||||
self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
|
|
||||||
self.helper.join(room, user=u2, tok=u2_token)
|
|
||||||
|
|
||||||
channel = self._get_shared_rooms(u1_token, u2)
|
|
||||||
self.assertEquals(200, channel.code, channel.result)
|
|
||||||
self.assertEquals(len(channel.json_body["joined"]), 1)
|
|
||||||
self.assertEquals(channel.json_body["joined"][0], room)
|
|
||||||
|
|
||||||
def test_shared_room_list_mixed(self):
|
def test_shared_room_list_mixed(self):
|
||||||
"""
|
"""
|
||||||
The shared room list between two users should contain both public and private
|
The shared room list between two users should contain both public and private
|
||||||
rooms.
|
rooms.
|
||||||
"""
|
"""
|
||||||
|
self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
|
||||||
|
|
||||||
|
def _check_shared_rooms_with(
|
||||||
|
self, room_one_is_public: bool, room_two_is_public: bool
|
||||||
|
):
|
||||||
|
"""Checks that shared public or private rooms between two users appear in
|
||||||
|
their shared room lists
|
||||||
|
"""
|
||||||
u1 = self.register_user("user1", "pass")
|
u1 = self.register_user("user1", "pass")
|
||||||
u1_token = self.login(u1, "pass")
|
u1_token = self.login(u1, "pass")
|
||||||
u2 = self.register_user("user2", "pass")
|
u2 = self.register_user("user2", "pass")
|
||||||
u2_token = self.login(u2, "pass")
|
u2_token = self.login(u2, "pass")
|
||||||
|
|
||||||
room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
|
# Create a room. user1 invites user2, who joins
|
||||||
room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token)
|
room_id_one = self.helper.create_room_as(
|
||||||
self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token)
|
u1, is_public=room_one_is_public, tok=u1_token
|
||||||
self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token)
|
)
|
||||||
self.helper.join(room_public, user=u2, tok=u2_token)
|
self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token)
|
||||||
self.helper.join(room_private, user=u1, tok=u1_token)
|
self.helper.join(room_id_one, user=u2, tok=u2_token)
|
||||||
|
|
||||||
|
# Check shared rooms from user1's perspective.
|
||||||
|
# We should see the one room in common
|
||||||
|
channel = self._get_shared_rooms(u1_token, u2)
|
||||||
|
self.assertEquals(200, channel.code, channel.result)
|
||||||
|
self.assertEquals(len(channel.json_body["joined"]), 1)
|
||||||
|
self.assertEquals(channel.json_body["joined"][0], room_id_one)
|
||||||
|
|
||||||
|
# Create another room and invite user2 to it
|
||||||
|
room_id_two = self.helper.create_room_as(
|
||||||
|
u1, is_public=room_two_is_public, tok=u1_token
|
||||||
|
)
|
||||||
|
self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token)
|
||||||
|
self.helper.join(room_id_two, user=u2, tok=u2_token)
|
||||||
|
|
||||||
|
# Check shared rooms again. We should now see both rooms.
|
||||||
channel = self._get_shared_rooms(u1_token, u2)
|
channel = self._get_shared_rooms(u1_token, u2)
|
||||||
self.assertEquals(200, channel.code, channel.result)
|
self.assertEquals(200, channel.code, channel.result)
|
||||||
self.assertEquals(len(channel.json_body["joined"]), 2)
|
self.assertEquals(len(channel.json_body["joined"]), 2)
|
||||||
self.assertTrue(room_public in channel.json_body["joined"])
|
for room_id_id in channel.json_body["joined"]:
|
||||||
self.assertTrue(room_private in channel.json_body["joined"])
|
self.assertIn(room_id_id, [room_id_one, room_id_two])
|
||||||
|
|
||||||
def test_shared_room_list_after_leave(self):
|
def test_shared_room_list_after_leave(self):
|
||||||
"""
|
"""
|
||||||
|
@ -132,6 +133,12 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.helper.leave(room, user=u1, tok=u1_token)
|
self.helper.leave(room, user=u1, tok=u1_token)
|
||||||
|
|
||||||
|
# Check user1's view of shared rooms with user2
|
||||||
|
channel = self._get_shared_rooms(u1_token, u2)
|
||||||
|
self.assertEquals(200, channel.code, channel.result)
|
||||||
|
self.assertEquals(len(channel.json_body["joined"]), 0)
|
||||||
|
|
||||||
|
# Check user2's view of shared rooms with user1
|
||||||
channel = self._get_shared_rooms(u2_token, u1)
|
channel = self._get_shared_rooms(u2_token, u1)
|
||||||
self.assertEquals(200, channel.code, channel.result)
|
self.assertEquals(200, channel.code, channel.result)
|
||||||
self.assertEquals(len(channel.json_body["joined"]), 0)
|
self.assertEquals(len(channel.json_body["joined"]), 0)
|
||||||
|
|
|
@ -124,7 +124,11 @@ class FakeChannel:
|
||||||
return address.IPv4Address("TCP", self._ip, 3423)
|
return address.IPv4Address("TCP", self._ip, 3423)
|
||||||
|
|
||||||
def getHost(self):
|
def getHost(self):
|
||||||
return None
|
# this is called by Request.__init__ to configure Request.host.
|
||||||
|
return address.IPv4Address("TCP", "127.0.0.1", 8888)
|
||||||
|
|
||||||
|
def isSecure(self):
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def transport(self):
|
def transport(self):
|
||||||
|
|
|
@ -114,7 +114,6 @@ def default_config(name, parse=False):
|
||||||
"server_name": name,
|
"server_name": name,
|
||||||
"send_federation": False,
|
"send_federation": False,
|
||||||
"media_store_path": "media",
|
"media_store_path": "media",
|
||||||
"uploads_path": "uploads",
|
|
||||||
# the test signing key is just an arbitrary ed25519 key to keep the config
|
# the test signing key is just an arbitrary ed25519 key to keep the config
|
||||||
# parser happy
|
# parser happy
|
||||||
"signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
|
"signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
|
||||||
|
|
Loading…
Reference in New Issue