Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

michaelkaye/remove_warning
Richard van der Hoff 2021-02-26 14:05:40 +00:00
commit fdbccc1e74
78 changed files with 1161 additions and 315 deletions

View File

@ -1,10 +1,26 @@
Synapse 1.28.0rc1 (2021-02-19) Synapse 1.xx.0
============================== ==============
Note that synapse now expects an `X-Forwarded-Proto` header when used with a reverse proxy. Please see [UPGRADE.rst](UPGRADE.rst#upgrading-to-v1290) for more details on this change.
Synapse 1.28.0 (2021-02-25)
===========================
Note that this release drops support for ARMv7 in the official Docker images, due to repeated problems building for ARMv7 (and the associated maintenance burden this entails).
This release also fixes the documentation included in v1.27.0 around the callback URI for SAML2 identity providers. If your server is configured to use single sign-on via a SAML2 IdP, you may need to make configuration changes. Please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
Internal Changes
----------------
- Revert change in v1.28.0rc1 to remove the deprecated SAML endpoint. ([\#9474](https://github.com/matrix-org/synapse/issues/9474))
Synapse 1.28.0rc1 (2021-02-19)
==============================
Removal warning
---------------
@ -31,7 +47,7 @@ Bugfixes
--------
- Fix long-standing bug where sending email notifications would fail for rooms that the server had since left. ([\#9257](https://github.com/matrix-org/synapse/issues/9257))
- Fix bug introduced in Synapse 1.27.0rc1 which meant the "session expired" error page during SSO registration was badly formatted. ([\#9296](https://github.com/matrix-org/synapse/issues/9296))
- Assert a maximum length for some parameters for spec compliance. ([\#9321](https://github.com/matrix-org/synapse/issues/9321), [\#9393](https://github.com/matrix-org/synapse/issues/9393))
- Fix additional errors when previewing URLs: "AttributeError 'NoneType' object has no attribute 'xpath'" and "ValueError: Unicode strings with encoding declaration are not supported. Please use bytes input or XML fragments without declaration.". ([\#9333](https://github.com/matrix-org/synapse/issues/9333))
- Fix a bug causing Synapse to impose the wrong type constraints on fields when processing responses from appservices to `/_matrix/app/v1/thirdparty/user/{protocol}`. ([\#9361](https://github.com/matrix-org/synapse/issues/9361))

View File

@ -85,6 +85,26 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.29.0
====================
Requirement for X-Forwarded-Proto header
----------------------------------------
When using Synapse with a reverse proxy (in particular, when using the
`x_forwarded` option on an HTTP listener), Synapse now expects to receive an
`X-Forwarded-Proto` header on incoming HTTP requests. If it is not set, Synapse
will log a warning on each received request.
To avoid the warning, administrators using a reverse proxy should ensure that
the reverse proxy sets `X-Forwarded-Proto` header to `https` or `http` to
indicate the protocol used by the client. See the [reverse proxy
documentation](docs/reverse_proxy.md), where the example configurations have
been updated to show how to set this header.
(Users of `Caddy <https://caddyserver.com/>`_ are unaffected, since we believe it
sets `X-Forwarded-Proto` by default.)
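For most proxies this is a one-line addition; for example, with nginx it can be set alongside `X-Forwarded-For`, as in the updated example in [docs/reverse_proxy.md](docs/reverse_proxy.md) (adapt to your own `location` block):

    proxy_set_header X-Forwarded-For $remote_addr;
    proxy_set_header X-Forwarded-Proto $scheme;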
Upgrading to v1.27.0
====================

changelog.d/8978.feature (new file)
View File

@ -0,0 +1 @@
Add `order_by` to the admin API `GET /_synapse/admin/v1/users/<user_id>/media`. Contributed by @dklimpel.

changelog.d/9285.bugfix (new file)
View File

@ -0,0 +1 @@
Fix a bug where users' pushers were not all deleted when they deactivated their account.

changelog.d/9358.misc (new file)
View File

@ -0,0 +1 @@
Added a fix that invalidates cache for empty timed-out sync responses.

changelog.d/9416.bugfix (new file)
View File

@ -0,0 +1 @@
Fix a bug that caused multiple calls to the experimental `shared_rooms` endpoint to return stale results.

changelog.d/9436.bugfix (new file)
View File

@ -0,0 +1 @@
Fix a bug in single sign-on which could cause a "No session cookie found" error.

changelog.d/9449.bugfix (new file)
View File

@ -0,0 +1 @@
Fix a bug introduced in v1.26.0 where some sequences were not properly configured when running `synapse_port_db`.

changelog.d/9462.misc (new file)
View File

@ -0,0 +1 @@
Remove vestiges of `uploads_path` configuration setting.

changelog.d/9463.doc (new file)
View File

@ -0,0 +1 @@
Update the example systemd config to propagate reloads to individual units.

changelog.d/9464.misc (new file)
View File

@ -0,0 +1 @@
Add a comment about systemd-python.

changelog.d/9465.bugfix (new file)
View File

@ -0,0 +1 @@
Fix deleting pushers when using sharded pushers.

changelog.d/9466.bugfix (new file)
View File

@ -0,0 +1 @@
Fix deleting pushers when using sharded pushers.

changelog.d/9470.bugfix (new file)
View File

@ -0,0 +1 @@
Fix missing startup checks for the consistency of certain PostgreSQL sequences.

changelog.d/9472.feature (new file)
View File

@ -0,0 +1 @@
Add support for `X-Forwarded-Proto` header when using a reverse proxy.

changelog.d/9479.bugfix (new file)
View File

@ -0,0 +1 @@
Fix deleting pushers when using sharded pushers.

changelog.d/9496.misc (new file)
View File

@ -0,0 +1 @@
Test that we require validated email for email pushers.

changelog.d/9501.feature (new file)
View File

@ -0,0 +1 @@
Add support for `X-Forwarded-Proto` header when using a reverse proxy.

debian/changelog (vendored)
View File

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.28.0) stable; urgency=medium
* New synapse release 1.28.0.
-- Synapse Packaging team <packages@matrix.org> Thu, 25 Feb 2021 10:21:57 +0000
matrix-synapse-py3 (1.27.0) stable; urgency=medium
[ Dan Callahan ]

View File

@ -11,7 +11,6 @@ The image also does *not* provide a TURN server.
By default, the image expects a single volume, located at ``/data``, that will hold:
* configuration files;
* temporary files during uploads;
* uploaded media and thumbnails;
* the SQLite database if you do not configure postgres;
* the appservices configuration.

View File

@ -89,7 +89,6 @@ federation_rc_concurrent: 3
## Files ##
media_store_path: "/data/media"
uploads_path: "/data/uploads"
max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "50M" }}"
max_image_pixels: "32M"
dynamic_thumbnails: false

View File

@ -379,11 +379,12 @@ The following fields are returned in the JSON response body:
- ``total`` - Number of rooms.
List media of a user
====================
Gets a list of all local media that a specific ``user_id`` has created.
By default, the response is ordered by descending creation date and ascending media ID.
The newest media is on top. You can change the order with parameters
``order_by`` and ``dir``.
The API is::
@ -440,6 +441,35 @@ The following parameters should be set in the URL:
denoting the offset in the returned results. This should be treated as an opaque value and
not explicitly set to anything other than the return value of ``next_token`` from a previous call.
Defaults to ``0``.
- ``order_by`` - The method by which to sort the returned list of media.
If the ordered field has duplicates, the second order is always by ascending ``media_id``,
which guarantees a stable ordering. Valid values are:
- ``media_id`` - Media are ordered alphabetically by ``media_id``.
- ``upload_name`` - Media are ordered alphabetically by the name the media was uploaded with.
- ``created_ts`` - Media are ordered by when the content was uploaded in ms.
Smallest to largest. This is the default.
- ``last_access_ts`` - Media are ordered by when the content was last accessed in ms.
Smallest to largest.
- ``media_length`` - Media are ordered by length of the media in bytes.
Smallest to largest.
- ``media_type`` - Media are ordered alphabetically by MIME-type.
- ``quarantined_by`` - Media are ordered alphabetically by the user ID that
initiated the quarantine request for this media.
- ``safe_from_quarantine`` - Media are ordered by whether the media has been marked as safe
from quarantine.
- ``dir`` - Direction of media order. Either ``f`` for forwards or ``b`` for backwards.
Setting this value to ``b`` will reverse the above sort order. Defaults to ``f``.
If neither ``order_by`` nor ``dir`` is set, the default order is newest media on top
(corresponds to ``order_by`` = ``created_ts`` and ``dir`` = ``b``).
Caution. The database only has indexes on the columns ``media_id``,
``user_id`` and ``created_ts``. This means that if a different sort order is used
(``upload_name``, ``last_access_ts``, ``media_length``, ``media_type``,
``quarantined_by`` or ``safe_from_quarantine``), this can cause a large load on the
database, especially for large environments.
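For example, a request listing a user's media ordered by file size, largest first, would look like
(an illustrative request built from the parameters above; ``<user_id>`` is a placeholder)::

    GET /_synapse/admin/v1/users/<user_id>/media?order_by=media_length&dir=b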
**Response**

View File

@ -9,23 +9,23 @@ of doing so is that it means that you can expose the default https port
(443) to Matrix clients without needing to run Synapse with root
privileges.
You should configure your reverse proxy to forward requests to `/_matrix` or
`/_synapse/client` to Synapse, and have it set the `X-Forwarded-For` and
`X-Forwarded-Proto` request headers.
You should remember that Matrix clients and other Matrix servers do not
necessarily need to connect to your server via the same server name or
port. Indeed, clients will use port 443 by default, whereas servers default to
port 8448. Where these are different, we refer to the 'client port' and the
'federation port'. See [the Matrix
specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names)
for more details of the algorithm used for federation connections, and
[delegate.md](<delegate.md>) for instructions on setting up delegation.
**NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
the requested URI in any way (for example, by decoding `%xx` escapes).
Beware that Apache *will* canonicalise URIs unless you specify
`nocanon`.
Let's assume that we expect clients to connect to our server at
`https://matrix.example.com`, and other servers to connect at
@ -52,6 +52,7 @@ server {
location ~* ^(\/_matrix|\/_synapse\/client) {
proxy_pass http://localhost:8008;
proxy_set_header X-Forwarded-For $remote_addr;
proxy_set_header X-Forwarded-Proto $scheme;
# Nginx by default only allows file uploads up to 1M in size
# Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
client_max_body_size 50M;
@ -102,6 +103,7 @@ example.com:8448 {
SSLEngine on
ServerName matrix.example.com;
RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME}
AllowEncodedSlashes NoDecode
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
@ -113,6 +115,7 @@ example.com:8448 {
SSLEngine on
ServerName example.com;
RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME}
AllowEncodedSlashes NoDecode
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
@ -134,6 +137,9 @@ example.com:8448 {
```
frontend https
bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
http-request set-header X-Forwarded-Proto https if { ssl_fc }
http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
http-request set-header X-Forwarded-For %[src]
# Matrix client traffic
acl matrix-host hdr(host) -i matrix.example.com
@ -144,6 +150,10 @@ frontend https
frontend matrix-federation
bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
http-request set-header X-Forwarded-Proto https if { ssl_fc }
http-request set-header X-Forwarded-Proto http if !{ ssl_fc }
http-request set-header X-Forwarded-For %[src]
default_backend matrix
backend matrix
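Caddy users generally need no extra configuration here: as noted in the upgrade notes, Caddy's `reverse_proxy` is believed to set `X-Forwarded-Proto` (and `X-Forwarded-For`) itself. An illustrative Caddy 2 snippet, not taken from this commit, would be roughly:

    matrix.example.com {
        reverse_proxy /_matrix/* localhost:8008
        reverse_proxy /_synapse/client/* localhost:8008
    }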

View File

@ -25,7 +25,7 @@ well as some specific methods:
* `check_username_for_spam`
* `check_registration_for_spam`
The details of each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class.
The `ModuleApi` class provides a way for the custom spam checker class to

View File

@ -4,6 +4,7 @@ AssertPathExists=/etc/matrix-synapse/workers/%i.yaml
# This service should be restarted when the synapse target is restarted.
PartOf=matrix-synapse.target
ReloadPropagatedFrom=matrix-synapse.target
# if this is started at the same time as the main, let the main process start
# first, to initialise the database schema.

View File

@ -3,6 +3,7 @@ Description=Synapse master
# This service should be restarted when the synapse target is restarted.
PartOf=matrix-synapse.target
ReloadPropagatedFrom=matrix-synapse.target
[Service]
Type=notify

View File

@ -220,10 +220,6 @@ Asks the server for the current position of all streams.
Acknowledge receipt of some federation data
#### REMOVE_PUSHER (C)
Inform the server a pusher should be removed
### REMOTE_SERVER_UP (S, C)
Inform other processes that a remote server may have come back online.

View File

@ -22,7 +22,7 @@ import logging
import sys import sys
import time import time
import traceback import traceback
from typing import Dict, Optional, Set from typing import Dict, Iterable, Optional, Set
import yaml import yaml
@ -629,7 +629,13 @@ class Porter(object):
await self._setup_state_group_id_seq() await self._setup_state_group_id_seq()
await self._setup_user_id_seq() await self._setup_user_id_seq()
await self._setup_events_stream_seqs() await self._setup_events_stream_seqs()
await self._setup_device_inbox_seq() await self._setup_sequence(
"device_inbox_sequence", ("device_inbox", "device_federation_outbox")
)
await self._setup_sequence(
"account_data_sequence", ("room_account_data", "room_tags_revisions", "account_data"))
await self._setup_sequence("receipts_sequence", ("receipts_linearized", ))
await self._setup_auth_chain_sequence()
# Step 3. Get tables. # Step 3. Get tables.
self.progress.set_state("Fetching tables") self.progress.set_state("Fetching tables")
@ -854,7 +860,7 @@ class Porter(object):
return done, remaining + done return done, remaining + done
async def _setup_state_group_id_seq(self): async def _setup_state_group_id_seq(self) -> None:
curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol( curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
) )
@ -868,7 +874,7 @@ class Porter(object):
await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r) await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
async def _setup_user_id_seq(self): async def _setup_user_id_seq(self) -> None:
curr_id = await self.sqlite_store.db_pool.runInteraction( curr_id = await self.sqlite_store.db_pool.runInteraction(
"setup_user_id_seq", find_max_generated_user_id_localpart "setup_user_id_seq", find_max_generated_user_id_localpart
) )
@ -877,9 +883,9 @@ class Porter(object):
next_id = curr_id + 1 next_id = curr_id + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
async def _setup_events_stream_seqs(self): async def _setup_events_stream_seqs(self) -> None:
"""Set the event stream sequences to the correct values. """Set the event stream sequences to the correct values.
""" """
@ -908,35 +914,46 @@ class Porter(object):
(curr_backward_id + 1,), (curr_backward_id + 1,),
) )
return await self.postgres_store.db_pool.runInteraction( await self.postgres_store.db_pool.runInteraction(
"_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos, "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
) )
async def _setup_device_inbox_seq(self): async def _setup_sequence(self, sequence_name: str, stream_id_tables: Iterable[str]) -> None:
"""Set the device inbox sequence to the correct value. """Set a sequence to the correct value.
""" """
curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol( current_stream_ids = []
table="device_inbox", for stream_id_table in stream_id_tables:
keyvalues={}, max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
retcol="COALESCE(MAX(stream_id), 1)", table=stream_id_table,
allow_none=True, keyvalues={},
) retcol="COALESCE(MAX(stream_id), 1)",
allow_none=True,
)
current_stream_ids.append(max_stream_id)
curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol( next_id = max(current_stream_ids) + 1
table="device_federation_outbox",
keyvalues={},
retcol="COALESCE(MAX(stream_id), 1)",
allow_none=True,
)
next_id = max(curr_local_id, curr_federation_id) + 1 def r(txn):
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name, )
txn.execute(sql + " %s", (next_id, ))
await self.postgres_store.db_pool.runInteraction("_setup_%s" % (sequence_name,), r)
async def _setup_auth_chain_sequence(self) -> None:
curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="event_auth_chains", keyvalues={}, retcol="MAX(chain_id)", allow_none=True
)
def r(txn): def r(txn):
txn.execute( txn.execute(
"ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,) "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
(curr_chain_id,),
) )
return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r) await self.postgres_store.db_pool.runInteraction(
"_setup_event_auth_chain_id", r,
)
############################################## ##############################################

View File

@ -48,7 +48,7 @@ try:
except ImportError: except ImportError:
pass pass
__version__ = "1.28.0rc1" __version__ = "1.28.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when # We import here so that we don't have to install a bunch of deps when

View File

@ -210,7 +210,9 @@ def start(config_options):
config.update_user_directory = False config.update_user_directory = False
config.run_background_tasks = False config.run_background_tasks = False
config.start_pushers = False config.start_pushers = False
config.pusher_shard_config.instances = []
config.send_federation = False config.send_federation = False
config.federation_shard_config.instances = []
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts

View File

@ -645,9 +645,6 @@ class GenericWorkerServer(HomeServer):
self.get_tcp_replication().start_replication(self) self.get_tcp_replication().start_replication(self)
async def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
@cache_in_self @cache_in_self
def get_replication_data_handler(self): def get_replication_data_handler(self):
return GenericWorkerReplicationHandler(self) return GenericWorkerReplicationHandler(self)
@ -922,22 +919,6 @@ def start(config_options):
# For other worker types we force this to off. # For other worker types we force this to off.
config.appservice.notify_appservices = False config.appservice.notify_appservices = False
if config.worker_app == "synapse.app.pusher":
if config.server.start_pushers:
sys.stderr.write(
"\nThe pushers must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``start_pushers: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.server.start_pushers = True
else:
# For other worker types we force this to off.
config.server.start_pushers = False
if config.worker_app == "synapse.app.user_dir": if config.worker_app == "synapse.app.user_dir":
if config.server.update_user_directory: if config.server.update_user_directory:
sys.stderr.write( sys.stderr.write(
@ -954,22 +935,6 @@ def start(config_options):
# For other worker types we force this to off. # For other worker types we force this to off.
config.server.update_user_directory = False config.server.update_user_directory = False
if config.worker_app == "synapse.app.federation_sender":
if config.worker.send_federation:
sys.stderr.write(
"\nThe send_federation must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``send_federation: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.worker.send_federation = True
else:
# For other worker types we force this to off.
config.worker.send_federation = False
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
hs = GenericWorkerServer( hs = GenericWorkerServer(

View File

@ -844,22 +844,23 @@ class ShardedWorkerHandlingConfig:
def should_handle(self, instance_name: str, key: str) -> bool: def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.""" """Whether this instance is responsible for handling the given key."""
# If multiple instances are not defined we always return true # If no instances are defined we assume some other worker is handling
if not self.instances or len(self.instances) == 1: # this.
return True if not self.instances:
return False
return self.get_instance(key) == instance_name return self._get_instance(key) == instance_name
def get_instance(self, key: str) -> str: def _get_instance(self, key: str) -> str:
"""Get the instance responsible for handling the given key. """Get the instance responsible for handling the given key.
Note: For things like federation sending the config for which instance Note: For federation sending and pushers the config for which instance
is sending is known only to the sender instance if there is only one. is sending is known only to the sender instance, so we don't expose this
Therefore `should_handle` should be used where possible. method by default.
""" """
if not self.instances: if not self.instances:
return "master" raise Exception("Unknown worker")
if len(self.instances) == 1: if len(self.instances) == 1:
return self.instances[0] return self.instances[0]
@ -876,4 +877,21 @@ class ShardedWorkerHandlingConfig:
return self.instances[remainder] return self.instances[remainder]
@attr.s
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
"""A version of `ShardedWorkerHandlingConfig` that is used for config
options where all instances know which instances are responsible for the
sharded work.
"""
def __attrs_post_init__(self):
# We require that `self.instances` is non-empty.
if not self.instances:
raise Exception("Got empty list of instances for shard config")
def get_instance(self, key: str) -> str:
"""Get the instance responsible for handling the given key."""
return self._get_instance(key)
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
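The instance lookup in `_get_instance` is a deterministic assignment of keys to the configured instances (the excerpt above ends with `self.instances[remainder]`). A simplified standalone sketch of the idea, not the class itself, using a SHA-256 digest as the hash:

    import hashlib

    def pick_instance(key: str, instances: list) -> str:
        # Deterministically map a sharding key (e.g. a user ID) to one of the
        # configured worker instances, in the spirit of ShardedWorkerHandlingConfig.
        if len(instances) == 1:
            return instances[0]
        digest = int(hashlib.sha256(key.encode("utf8")).hexdigest(), 16)
        return instances[digest % len(instances)]

    # The same key always lands on the same instance:
    assert pick_instance("@alice:example.com", ["pusher1", "pusher2"]) == \
        pick_instance("@alice:example.com", ["pusher1", "pusher2"])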

View File

@ -149,4 +149,6 @@ class ShardedWorkerHandlingConfig:
instances: List[str] instances: List[str]
def __init__(self, instances: List[str]) -> None: ... def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ... def should_handle(self, instance_name: str, key: str) -> bool: ...
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
def get_instance(self, key: str) -> str: ... def get_instance(self, key: str) -> str: ...

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import Config, ShardedWorkerHandlingConfig from ._base import Config
class PushConfig(Config): class PushConfig(Config):
@ -27,9 +27,6 @@ class PushConfig(Config):
"group_unread_count_by_room", True "group_unread_count_by_room", True
) )
pusher_instances = config.get("pusher_instances") or []
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
# There was a a 'redact_content' setting but mistakenly read from the # There was a a 'redact_content' setting but mistakenly read from the
# 'email'section'. Check for the flag in the 'push' section, and log, # 'email'section'. Check for the flag in the 'push' section, and log,
# but do not honour it to avoid nasty surprises when people upgrade. # but do not honour it to avoid nasty surprises when people upgrade.

View File

@ -206,7 +206,6 @@ class ContentRepositoryConfig(Config):
def generate_config_section(self, data_dir_path, **kwargs): def generate_config_section(self, data_dir_path, **kwargs):
media_store = os.path.join(data_dir_path, "media_store") media_store = os.path.join(data_dir_path, "media_store")
uploads_path = os.path.join(data_dir_path, "uploads")
formatted_thumbnail_sizes = "".join( formatted_thumbnail_sizes = "".join(
THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES

View File

@ -397,7 +397,6 @@ class ServerConfig(Config):
if self.public_baseurl is not None: if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/": if self.public_baseurl[-1] != "/":
self.public_baseurl += "/" self.public_baseurl += "/"
self.start_pushers = config.get("start_pushers", True)
# (undocumented) option for torturing the worker-mode replication a bit, # (undocumented) option for torturing the worker-mode replication a bit,
# for testing. The value defines the number of milliseconds to pause before # for testing. The value defines the number of milliseconds to pause before

View File

@ -17,9 +17,28 @@ from typing import List, Union
import attr import attr
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from ._base import (
Config,
ConfigError,
RoutableShardedWorkerHandlingConfig,
ShardedWorkerHandlingConfig,
)
from .server import ListenerConfig, parse_listener_def from .server import ListenerConfig, parse_listener_def
_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR = """
The send_federation config option must be disabled in the main
synapse process before they can be run in a separate worker.
Please add ``send_federation: false`` to the main config
"""
_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR = """
The start_pushers config option must be disabled in the main
synapse process before they can be run in a separate worker.
Please add ``start_pushers: false`` to the main config
"""
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
"""Helper for allowing parsing a string or list of strings to a config """Helper for allowing parsing a string or list of strings to a config
@ -103,6 +122,7 @@ class WorkerConfig(Config):
self.worker_replication_secret = config.get("worker_replication_secret", None) self.worker_replication_secret = config.get("worker_replication_secret", None)
self.worker_name = config.get("worker_name", self.worker_app) self.worker_name = config.get("worker_name", self.worker_app)
self.instance_name = self.worker_name or "master"
self.worker_main_http_uri = config.get("worker_main_http_uri", None) self.worker_main_http_uri = config.get("worker_main_http_uri", None)
@ -118,12 +138,41 @@ class WorkerConfig(Config):
) )
) )
# Whether to send federation traffic out in this process. This only # Handle federation sender configuration.
# applies to some federation traffic, and so shouldn't be used to #
# "disable" federation # There are two ways of configuring which instances handle federation
self.send_federation = config.get("send_federation", True) # sending:
# 1. The old way where "send_federation" is set to false and running a
# `synapse.app.federation_sender` worker app.
# 2. Specifying the workers sending federation in
# `federation_sender_instances`.
#
federation_sender_instances = config.get("federation_sender_instances") or [] send_federation = config.get("send_federation", True)
federation_sender_instances = config.get("federation_sender_instances")
if federation_sender_instances is None:
# Default to an empty list, which means "another, unknown, worker is
# responsible for it".
federation_sender_instances = []
# If no federation sender instances are set we check if
# `send_federation` is set, which means use master
if send_federation:
federation_sender_instances = ["master"]
if self.worker_app == "synapse.app.federation_sender":
if send_federation:
# If we're running federation senders, and not using
# `federation_sender_instances`, then we should have
# explicitly set `send_federation` to false.
raise ConfigError(
_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR
)
federation_sender_instances = [self.worker_name]
self.send_federation = self.instance_name in federation_sender_instances
self.federation_shard_config = ShardedWorkerHandlingConfig( self.federation_shard_config = ShardedWorkerHandlingConfig(
federation_sender_instances federation_sender_instances
) )
@ -164,7 +213,37 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `receipts` messages." "Must only specify one instance to handle `receipts` messages."
) )
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) if len(self.writers.events) == 0:
raise ConfigError("Must specify at least one instance to handle `events`.")
self.events_shard_config = RoutableShardedWorkerHandlingConfig(
self.writers.events
)
# Handle sharded push
start_pushers = config.get("start_pushers", True)
pusher_instances = config.get("pusher_instances")
if pusher_instances is None:
# Default to an empty list, which means "another, unknown, worker is
# responsible for it".
pusher_instances = []
# If no pushers instances are set we check if `start_pushers` is
# set, which means use master
if start_pushers:
pusher_instances = ["master"]
if self.worker_app == "synapse.app.pusher":
if start_pushers:
# If we're running pushers, and not using
# `pusher_instances`, then we should have explicitly set
# `start_pushers` to false.
raise ConfigError(_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR)
pusher_instances = [self.instance_name]
self.start_pushers = self.instance_name in pusher_instances
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
# Whether this worker should run background tasks or not. # Whether this worker should run background tasks or not.
# #
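In the newer configuration style described in the comments above, the sharded responsibilities are listed directly in the shared configuration rather than toggled per worker app. A hedged example of what that looks like in `homeserver.yaml` (the worker names here are made up):

    federation_sender_instances:
      - federation_sender1
      - federation_sender2
    pusher_instances:
      - pusher1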

View File

@ -120,6 +120,11 @@ class DeactivateAccountHandler(BaseHandler):
await self.store.user_set_password_hash(user_id, None) await self.store.user_set_password_hash(user_id, None)
# Most of the pushers will have been deleted when we logged out the
# associated devices above, but we still need to delete pushers not
# associated with devices, e.g. email pushers.
await self.store.delete_all_pushers_for_user(user_id)
# Add the user to a table of users pending deactivation (ie. # Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of) # removal from all the rooms they're a member of)
await self.store.add_user_pending_deactivation(user_id) await self.store.add_user_pending_deactivation(user_id)

View File

@ -278,8 +278,9 @@ class SyncHandler:
user_id = sync_config.user.to_string() user_id = sync_config.user.to_string()
await self.auth.check_auth_blocking(requester=requester) await self.auth.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap( res = await self.response_cache.wrap_conditional(
sync_config.request_key, sync_config.request_key,
lambda result: since_token != result.next_batch,
self._wait_for_sync_for_user, self._wait_for_sync_for_user,
sync_config, sync_config,
since_token, since_token,
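The switch to `wrap_conditional` means the computed `/sync` response is only kept in the cache when the supplied condition holds, so an empty timed-out response (where `next_batch` equals `since_token`) is not served again. A minimal standalone sketch of that idea (not Synapse's actual `ResponseCache`):

    import asyncio

    async def wrap_conditional(cache: dict, key, should_cache, callback, *args):
        # Return a cached result if present; otherwise compute it and only
        # store it when should_cache(result) is true.
        if key in cache:
            return cache[key]
        result = await callback(*args)
        if should_cache(result):
            cache[key] = result
        return result

    async def main():
        cache = {}

        async def fake_sync(since_token):
            return {"next_batch": since_token}  # an "empty" timed-out response

        res = await wrap_conditional(
            cache, "req-key", lambda r: r["next_batch"] != "s1", fake_sync, "s1"
        )
        assert res["next_batch"] == "s1" and "req-key" not in cache

    asyncio.run(main())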

View File

@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re import re
from typing import Union
from twisted.internet import task from twisted.internet import address, task
from twisted.web.client import FileBodyProducer from twisted.web.client import FileBodyProducer
from twisted.web.iweb import IRequest from twisted.web.iweb import IRequest
@ -53,6 +54,40 @@ class QuieterFileBodyProducer(FileBodyProducer):
pass pass
def get_request_uri(request: IRequest) -> bytes:
"""Return the full URI that was requested by the client"""
return b"%s://%s%s" % (
b"https" if request.isSecure() else b"http",
_get_requested_host(request),
# despite its name, "request.uri" is only the path and query-string.
request.uri,
)
def _get_requested_host(request: IRequest) -> bytes:
hostname = request.getHeader(b"host")
if hostname:
return hostname
# no Host header, use the address/port that the request arrived on
host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address]
hostname = host.host.encode("ascii")
if request.isSecure() and host.port == 443:
# default port for https
return hostname
if not request.isSecure() and host.port == 80:
# default port for http
return hostname
return b"%s:%i" % (
hostname,
host.port,
)
def get_request_user_agent(request: IRequest, default: str = "") -> str: def get_request_user_agent(request: IRequest, default: str = "") -> str:
"""Return the last User-Agent header, or the given default.""" """Return the last User-Agent header, or the given default."""
# There could be raw utf-8 bytes in the User-Agent header. # There could be raw utf-8 bytes in the User-Agent header.

View File

@ -16,6 +16,10 @@ import logging
import time import time
from typing import Optional, Union from typing import Optional, Union
import attr
from zope.interface import implementer
from twisted.internet.interfaces import IAddress
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
@ -333,27 +337,78 @@ class SynapseRequest(Request):
class XForwardedForRequest(SynapseRequest): class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw): """Request object which honours proxy headers
SynapseRequest.__init__(self, *args, **kw)
""" Extends SynapseRequest to replace getClientIP, getClientAddress, and isSecure with
Add a layer on top of another request that only uses the value of an information from request headers.
X-Forwarded-For header as the result of C{getClientIP}.
""" """
def getClientIP(self): # the client IP and ssl flag, as extracted from the headers.
""" _forwarded_for = None # type: Optional[_XForwardedForAddress]
@return: The client address (the first address) in the value of the _forwarded_https = False # type: bool
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}. def requestReceived(self, command, path, version):
""" # this method is called by the Channel once the full request has been
return ( # received, to dispatch the request to a resource.
self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0] # We can use it to set the IP address and protocol according to the
.split(b",")[0] # headers.
.strip() self._process_forwarded_headers()
.decode("ascii") return super().requestReceived(command, path, version)
def _process_forwarded_headers(self):
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers:
return
# for now, we just use the first x-forwarded-for header. Really, we ought
# to start from the client IP address, and check whether it is trusted; if it
# is, work backwards through the headers until we find an untrusted address.
# see https://github.com/matrix-org/synapse/issues/9471
self._forwarded_for = _XForwardedForAddress(
headers[0].split(b",")[0].strip().decode("ascii")
) )
# if we got an x-forwarded-for header, also look for an x-forwarded-proto header
header = self.getHeader(b"x-forwarded-proto")
if header is not None:
self._forwarded_https = header.lower() == b"https"
else:
# this is done largely for backwards-compatibility so that people that
# haven't set an x-forwarded-proto header don't get a redirect loop.
logger.warning(
"forwarded request lacks an x-forwarded-proto header: assuming https"
)
self._forwarded_https = True
def isSecure(self):
if self._forwarded_https:
return True
return super().isSecure()
def getClientIP(self) -> str:
"""
Return the IP address of the client who submitted this request.
This method is deprecated. Use getClientAddress() instead.
"""
if self._forwarded_for is not None:
return self._forwarded_for.host
return super().getClientIP()
def getClientAddress(self) -> IAddress:
"""
Return the address of the client who submitted this request.
"""
if self._forwarded_for is not None:
return self._forwarded_for
return super().getClientAddress()
@implementer(IAddress)
@attr.s(frozen=True, slots=True)
class _XForwardedForAddress:
host = attr.ib(type=str)
class SynapseSite(Site): class SynapseSite(Site):
""" """

View File

@ -74,6 +74,7 @@ class HttpPusher(Pusher):
self.timed_call = None self.timed_call = None
self._is_processing = False self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
self._pusherpool = hs.get_pusherpool()
self.data = pusher_config.data self.data = pusher_config.data
if self.data is None: if self.data is None:
@ -304,7 +305,7 @@ class HttpPusher(Pusher):
) )
else: else:
logger.info("Pushkey %s was rejected: removing", pk) logger.info("Pushkey %s was rejected: removing", pk)
await self.hs.remove_pusher(self.app_id, pk, self.user_id) await self._pusherpool.remove_pusher(self.app_id, pk, self.user_id)
return True return True
async def _build_notification_dict( async def _build_notification_dict(

View File

@ -19,12 +19,14 @@ from typing import TYPE_CHECKING, Dict, Iterable, Optional
from prometheus_client import Gauge from prometheus_client import Gauge
from synapse.api.errors import Codes, SynapseError
from synapse.metrics.background_process_metrics import ( from synapse.metrics.background_process_metrics import (
run_as_background_process, run_as_background_process,
wrap_as_background_process, wrap_as_background_process,
) )
from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.push.pusher import PusherFactory from synapse.push.pusher import PusherFactory
from synapse.replication.http.push import ReplicationRemovePusherRestServlet
from synapse.types import JsonDict, RoomStreamToken from synapse.types import JsonDict, RoomStreamToken
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
@ -58,7 +60,6 @@ class PusherPool:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.pusher_factory = PusherFactory(hs) self.pusher_factory = PusherFactory(hs)
self._should_start_pushers = hs.config.start_pushers
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
@ -67,6 +68,16 @@ class PusherPool:
# We shard the handling of push notifications by user ID. # We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self._should_start_pushers = (
self._instance_name in self._pusher_shard_config.instances
)
# We can only delete pushers on master.
self._remove_pusher_client = None
if hs.config.worker.worker_app:
self._remove_pusher_client = ReplicationRemovePusherRestServlet.make_client(
hs
)
# Record the last stream ID that we were poked about so we can get # Record the last stream ID that we were poked about so we can get
# changes since then. We set this to the current max stream ID on # changes since then. We set this to the current max stream ID on
@ -103,6 +114,11 @@ class PusherPool:
The newly created pusher. The newly created pusher.
""" """
if kind == "email":
email_owner = await self.store.get_user_id_by_threepid("email", pushkey)
if email_owner != user_id:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
time_now_msec = self.clock.time_msec() time_now_msec = self.clock.time_msec()
# create the pusher setting last_stream_ordering to the current maximum # create the pusher setting last_stream_ordering to the current maximum
@ -175,9 +191,6 @@ class PusherPool:
user_id: user to remove pushers for user_id: user to remove pushers for
access_tokens: access token *ids* to remove pushers for access_tokens: access token *ids* to remove pushers for
""" """
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
tokens = set(access_tokens) tokens = set(access_tokens)
for p in await self.store.get_pushers_by_user_id(user_id): for p in await self.store.get_pushers_by_user_id(user_id):
if p.access_token in tokens: if p.access_token in tokens:
@ -380,6 +393,12 @@ class PusherPool:
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
await self.store.delete_pusher_by_app_id_pushkey_user_id( # We can only delete pushers on master.
app_id, pushkey, user_id if self._remove_pusher_client:
) await self._remove_pusher_client(
app_id=app_id, pushkey=pushkey, user_id=user_id
)
else:
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id
)

View File

@ -106,6 +106,9 @@ CONDITIONAL_REQUIREMENTS = {
"pysaml2>=4.5.0;python_version>='3.6'", "pysaml2>=4.5.0;python_version>='3.6'",
], ],
"oidc": ["authlib>=0.14.0"], "oidc": ["authlib>=0.14.0"],
# systemd-python is necessary for logging to the systemd journal via
# `systemd.journal.JournalHandler`, as is documented in
# `contrib/systemd/log_config.yaml`.
"systemd": ["systemd-python>=231"], "systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"], "url_preview": ["lxml>=3.5.0"],
"sentry": ["sentry-sdk>=0.7.2"], "sentry": ["sentry-sdk>=0.7.2"],

View File

@ -21,6 +21,7 @@ from synapse.replication.http import (
login, login,
membership, membership,
presence, presence,
push,
register, register,
send_event, send_event,
streams, streams,
@ -42,6 +43,7 @@ class ReplicationRestResource(JsonResource):
membership.register_servlets(hs, self) membership.register_servlets(hs, self)
streams.register_servlets(hs, self) streams.register_servlets(hs, self)
account_data.register_servlets(hs, self) account_data.register_servlets(hs, self)
push.register_servlets(hs, self)
# The following can't currently be instantiated on workers. # The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None: if hs.config.worker.worker_app is None:

View File

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
"""Deletes the given pusher.
Request format:
POST /_synapse/replication/remove_pusher/:user_id
{
"app_id": "<some_id>",
"pushkey": "<some_key>"
}
"""
NAME = "add_user_account_data"
PATH_ARGS = ("user_id",)
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.pusher_pool = hs.get_pusherpool()
@staticmethod
async def _serialize_payload(app_id, pushkey, user_id):
payload = {
"app_id": app_id,
"pushkey": pushkey,
}
return payload
async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
app_id = content["app_id"]
pushkey = content["pushkey"]
await self.pusher_pool.remove_pusher(app_id, pushkey, user_id)
return 200, {}
def register_servlets(hs, http_server):
ReplicationRemovePusherRestServlet(hs).register(http_server)

View File

@ -325,31 +325,6 @@ class FederationAckCommand(Command):
return "%s %s" % (self.instance_name, self.token) return "%s %s" % (self.instance_name, self.token)
class RemovePusherCommand(Command):
"""Sent by the client to request the master remove the given pusher.
Format::
REMOVE_PUSHER <app_id> <push_key> <user_id>
"""
NAME = "REMOVE_PUSHER"
def __init__(self, app_id, push_key, user_id):
self.user_id = user_id
self.app_id = app_id
self.push_key = push_key
@classmethod
def from_line(cls, line):
app_id, push_key, user_id = line.split(" ", 2)
return cls(app_id, push_key, user_id)
def to_line(self):
return " ".join((self.app_id, self.push_key, self.user_id))
class UserIpCommand(Command): class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client. """Sent periodically when a worker sees activity from a client.
@ -416,7 +391,6 @@ _COMMANDS = (
ReplicateCommand, ReplicateCommand,
UserSyncCommand, UserSyncCommand,
FederationAckCommand, FederationAckCommand,
RemovePusherCommand,
UserIpCommand, UserIpCommand,
RemoteServerUpCommand, RemoteServerUpCommand,
ClearUserSyncsCommand, ClearUserSyncsCommand,
@ -443,7 +417,6 @@ VALID_CLIENT_COMMANDS = (
UserSyncCommand.NAME, UserSyncCommand.NAME,
ClearUserSyncsCommand.NAME, ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME, FederationAckCommand.NAME,
RemovePusherCommand.NAME,
UserIpCommand.NAME, UserIpCommand.NAME,
ErrorCommand.NAME, ErrorCommand.NAME,
RemoteServerUpCommand.NAME, RemoteServerUpCommand.NAME,

View File

@ -44,7 +44,6 @@ from synapse.replication.tcp.commands import (
PositionCommand, PositionCommand,
RdataCommand, RdataCommand,
RemoteServerUpCommand, RemoteServerUpCommand,
RemovePusherCommand,
ReplicateCommand, ReplicateCommand,
UserIpCommand, UserIpCommand,
UserSyncCommand, UserSyncCommand,
@ -373,23 +372,6 @@ class ReplicationCommandHandler:
if self._federation_sender: if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token) self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
def on_REMOVE_PUSHER(
self, conn: AbstractConnection, cmd: RemovePusherCommand
) -> Optional[Awaitable[None]]:
remove_pusher_counter.inc()
if self._is_master:
return self._handle_remove_pusher(cmd)
else:
return None
async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
await self._store.delete_pusher_by_app_id_pushkey_user_id(
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
)
self._notifier.on_new_replication_data()
def on_USER_IP( def on_USER_IP(
self, conn: AbstractConnection, cmd: UserIpCommand self, conn: AbstractConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]: ) -> Optional[Awaitable[None]]:
@ -684,11 +666,6 @@ class ReplicationCommandHandler:
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
) )
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
"""Poke the master to remove a pusher for a user"""
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_user_ip( def send_user_ip(
self, self,
user_id: str, user_id: str,

View File

@ -35,6 +35,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin, assert_user_is_admin,
) )
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.databases.main.media_repository import MediaSortOrder
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING: if TYPE_CHECKING:
@ -832,8 +833,33 @@ class UserMediaRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
# If neither `order_by` nor `dir` is set, set the default order
# to newest media is on top for backward compatibility.
if b"order_by" not in request.args and b"dir" not in request.args:
order_by = MediaSortOrder.CREATED_TS.value
direction = "b"
else:
order_by = parse_string(
request,
"order_by",
default=MediaSortOrder.CREATED_TS.value,
allowed_values=(
MediaSortOrder.MEDIA_ID.value,
MediaSortOrder.UPLOAD_NAME.value,
MediaSortOrder.CREATED_TS.value,
MediaSortOrder.LAST_ACCESS_TS.value,
MediaSortOrder.MEDIA_LENGTH.value,
MediaSortOrder.MEDIA_TYPE.value,
MediaSortOrder.QUARANTINED_BY.value,
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
),
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")
)
media, total = await self.store.get_local_media_by_user_paginate( media, total = await self.store.get_local_media_by_user_paginate(
start, limit, user_id start, limit, user_id, order_by, direction
) )
ret = {"media": media, "total": total} ret = {"media": media, "total": total}

View File

@ -20,6 +20,7 @@ from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.handlers.sso import SsoIdentityProvider from synapse.handlers.sso import SsoIdentityProvider
from synapse.http import get_request_uri
from synapse.http.server import HttpServer, finish_request from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -354,6 +355,7 @@ class SsoRedirectServlet(RestServlet):
hs.get_oidc_handler() hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.experimental.msc2858_enabled self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._public_baseurl = hs.config.public_baseurl
def register(self, http_server: HttpServer) -> None: def register(self, http_server: HttpServer) -> None:
super().register(http_server) super().register(http_server)
@ -373,6 +375,32 @@ class SsoRedirectServlet(RestServlet):
async def on_GET( async def on_GET(
self, request: SynapseRequest, idp_id: Optional[str] = None self, request: SynapseRequest, idp_id: Optional[str] = None
) -> None: ) -> None:
if not self._public_baseurl:
raise SynapseError(400, "SSO requires a valid public_baseurl")
# if this isn't the expected hostname, redirect to the right one, so that we
# get our cookies back.
requested_uri = get_request_uri(request)
baseurl_bytes = self._public_baseurl.encode("utf-8")
if not requested_uri.startswith(baseurl_bytes):
# swap out the incorrect base URL for the right one.
#
# The idea here is to redirect from
# https://foo.bar/whatever/_matrix/...
# to
# https://public.baseurl/_matrix/...
#
i = requested_uri.index(b"/_matrix")
new_uri = baseurl_bytes[:-1] + requested_uri[i:]
logger.info(
"Requested URI %s is not canonical: redirecting to %s",
requested_uri.decode("utf-8", errors="replace"),
new_uri.decode("utf-8", errors="replace"),
)
request.redirect(new_uri)
finish_request(request)
return
client_redirect_url = parse_string( client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None request, "redirectUrl", required=True, encoding=None
) )
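As a standalone illustration of the splice performed above (a sketch, not code from the change; the example URLs are made up):

# Mirrors the canonicalisation above: swap an incorrect base URL for the
# configured public_baseurl, keeping everything from "/_matrix" onwards.
def canonicalise(requested_uri: bytes, public_baseurl: str) -> bytes:
    baseurl_bytes = public_baseurl.encode("utf-8")
    if requested_uri.startswith(baseurl_bytes):
        return requested_uri
    i = requested_uri.index(b"/_matrix")
    return baseurl_bytes[:-1] + requested_uri[i:]

assert canonicalise(
    b"https://foo.bar/whatever/_matrix/client/r0/login/sso/redirect",
    "https://public.baseurl/",
) == b"https://public.baseurl/_matrix/client/r0/login/sso/redirect"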
View File
@ -54,7 +54,12 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc
if hs.config.saml2_enabled: if hs.config.saml2_enabled:
from synapse.rest.synapse.client.saml2 import SAML2Resource from synapse.rest.synapse.client.saml2 import SAML2Resource
resources["/_synapse/client/saml2"] = SAML2Resource(hs) res = SAML2Resource(hs)
resources["/_synapse/client/saml2"] = res
# This is also mounted under '/_matrix' for backwards-compatibility.
# To be removed in Synapse v1.32.0.
resources["/_matrix/saml2"] = res
return resources return resources
View File
@ -248,7 +248,7 @@ class HomeServer(metaclass=abc.ABCMeta):
self.start_time = None # type: Optional[int] self.start_time = None # type: Optional[int]
self._instance_id = random_string(5) self._instance_id = random_string(5)
self._instance_name = config.worker_name or "master" self._instance_name = config.worker.instance_name
self.version_string = version_string self.version_string = version_string
@ -758,12 +758,6 @@ class HomeServer(metaclass=abc.ABCMeta):
reconnect=True, reconnect=True,
) )
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
def should_send_federation(self) -> bool: def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?" "Should this server be sending federation traffic directly?"
return self.config.send_federation and ( return self.config.send_federation
not self.config.worker_app
or self.config.worker_app == "synapse.app.federation_sender"
)
View File
@ -49,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor from synapse.storage.types import Connection, Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection from synapse.types import Collection
# python 3 does not have a maximum int value # python 3 does not have a maximum int value
@ -381,7 +380,10 @@ class DatabasePool:
_TXN_ID = 0 _TXN_ID = 0
def __init__( def __init__(
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine self,
hs,
database_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
): ):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
@ -420,16 +422,6 @@ class DatabasePool:
self._check_safe_to_upsert, self._check_safe_to_upsert,
) )
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
def get_chain_id_txn(txn):
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
return txn.fetchone()[0]
self.event_chain_id_gen = build_sequence_generator(
engine, get_chain_id_txn, "event_auth_chain_id"
)
def is_running(self) -> bool: def is_running(self) -> bool:
"""Is the database pool currently running""" """Is the database pool currently running"""
return self._db_pool.running return self._db_pool.running
View File
@ -79,7 +79,7 @@ class Databases:
# If we're on a process that can persist events also # If we're on a process that can persist events also
# instantiate a `PersistEventsStore` # instantiate a `PersistEventsStore`
if hs.get_instance_name() in hs.config.worker.writers.events: if hs.get_instance_name() in hs.config.worker.writers.events:
persist_events = PersistEventsStore(hs, database, main) persist_events = PersistEventsStore(hs, database, main, db_conn)
if "state" in database_config.databases: if "state" in database_config.databases:
logger.info( logger.info(
View File
@ -42,7 +42,9 @@ from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically from synapse.util.iterutils import batch_iter, sorted_topologically
@ -90,7 +92,11 @@ class PersistEventsStore:
""" """
def __init__( def __init__(
self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore" self,
hs: "HomeServer",
db: DatabasePool,
main_data_store: "DataStore",
db_conn: Connection,
): ):
self.hs = hs self.hs = hs
self.db_pool = db self.db_pool = db
@ -474,6 +480,7 @@ class PersistEventsStore:
self._add_chain_cover_index( self._add_chain_cover_index(
txn, txn,
self.db_pool, self.db_pool,
self.store.event_chain_id_gen,
event_to_room_id, event_to_room_id,
event_to_types, event_to_types,
event_to_auth_chain, event_to_auth_chain,
@ -484,6 +491,7 @@ class PersistEventsStore:
cls, cls,
txn, txn,
db_pool: DatabasePool, db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str], event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]], event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]], event_to_auth_chain: Dict[str, List[str]],
@ -630,6 +638,7 @@ class PersistEventsStore:
new_chain_tuples = cls._allocate_chain_ids( new_chain_tuples = cls._allocate_chain_ids(
txn, txn,
db_pool, db_pool,
event_chain_id_gen,
event_to_room_id, event_to_room_id,
event_to_types, event_to_types,
event_to_auth_chain, event_to_auth_chain,
@ -768,6 +777,7 @@ class PersistEventsStore:
def _allocate_chain_ids( def _allocate_chain_ids(
txn, txn,
db_pool: DatabasePool, db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str], event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]], event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]], event_to_auth_chain: Dict[str, List[str]],
@ -880,7 +890,7 @@ class PersistEventsStore:
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1] chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
# Generate new chain IDs for all unallocated chain IDs. # Generate new chain IDs for all unallocated chain IDs.
newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn( newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
txn, len(unallocated_chain_ids) txn, len(unallocated_chain_ids)
) )
View File
@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
PersistEventsStore._add_chain_cover_index( PersistEventsStore._add_chain_cover_index(
txn, txn,
self.db_pool, self.db_pool,
self.event_chain_id_gen,
event_to_room_id, event_to_room_id,
event_to_types, event_to_types,
event_to_auth_chain, event_to_auth_chain,
View File
@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection, JsonDict, get_domain_from_id from synapse.types import Collection, JsonDict, get_domain_from_id
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_list = [] self._event_fetch_list = []
self._event_fetch_ongoing = 0 self._event_fetch_ongoing = 0
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
def get_chain_id_txn(txn):
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
return txn.fetchone()[0]
self.event_chain_id_gen = build_sequence_generator(
db_conn,
database.engine,
get_chain_id_txn,
"event_auth_chain_id",
table="event_auth_chains",
id_column="chain_id",
)
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME: if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token) self._stream_id_gen.advance(instance_name, token)
View File
@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
@ -23,6 +24,22 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
) )
class MediaSortOrder(Enum):
"""
Enum to define the sorting method used when returning media with
get_local_media_by_user_paginate
"""
MEDIA_ID = "media_id"
UPLOAD_NAME = "upload_name"
CREATED_TS = "created_ts"
LAST_ACCESS_TS = "last_access_ts"
MEDIA_LENGTH = "media_length"
MEDIA_TYPE = "media_type"
QUARANTINED_BY = "quarantined_by"
SAFE_FROM_QUARANTINE = "safe_from_quarantine"
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
@ -118,7 +135,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
) )
async def get_local_media_by_user_paginate( async def get_local_media_by_user_paginate(
self, start: int, limit: int, user_id: str self,
start: int,
limit: int,
user_id: str,
order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value,
direction: str = "f",
) -> Tuple[List[Dict[str, Any]], int]: ) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media """Get a paginated list of metadata for a local piece of media
which an user_id has uploaded which an user_id has uploaded
@ -127,6 +149,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
start: offset in the list start: offset in the list
limit: maximum amount of media_ids to retrieve limit: maximum amount of media_ids to retrieve
user_id: fully-qualified user id user_id: fully-qualified user id
order_by: the sort order of the returned list
direction: sort ascending or descending
Returns: Returns:
A paginated list of all metadata of user's media, A paginated list of all metadata of user's media,
plus the total count of all the user's media plus the total count of all the user's media
@ -134,6 +158,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(txn): def get_local_media_by_user_paginate_txn(txn):
# Set ordering
order_by_column = MediaSortOrder(order_by).value
if direction == "b":
order = "DESC"
else:
order = "ASC"
args = [user_id] args = [user_id]
sql = """ sql = """
SELECT COUNT(*) as total_media SELECT COUNT(*) as total_media
@ -155,9 +187,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"safe_from_quarantine" "safe_from_quarantine"
FROM local_media_repository FROM local_media_repository
WHERE user_id = ? WHERE user_id = ?
ORDER BY created_ts DESC, media_id DESC ORDER BY {order_by_column} {order}, media_id ASC
LIMIT ? OFFSET ? LIMIT ? OFFSET ?
""" """.format(
order_by_column=order_by_column,
order=order,
)
args += [limit, start] args += [limit, start]
txn.execute(sql, args) txn.execute(sql, args)
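To make the formatting above concrete, a small sketch (with assumed example values, not an extra query in the change) of what a request for the largest media first produces:

from synapse.storage.databases.main.media_repository import MediaSortOrder  # import shown earlier in this diff

order_by_column = MediaSortOrder("media_length").value   # -> "media_length"
order = "DESC"                                            # because direction == "b"
# The paginated SELECT above therefore ends with:
#   ORDER BY media_length DESC, media_id ASC
#   LIMIT ? OFFSET ?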
View File
@ -373,3 +373,46 @@ class PusherStore(PusherWorkerStore):
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id "delete_pusher", delete_pusher_txn, stream_id
) )
async def delete_all_pushers_for_user(self, user_id: str) -> None:
"""Delete all pushers associated with an account."""
# We want to generate a row in `deleted_pushers` for each pusher we're
# deleting, so we fetch the list now so we can generate the appropriate
# number of stream IDs.
#
# Note: technically there could be a race here between adding/deleting
# pushers, but a) the worst case is that we don't stop a pusher until the
# next restart and b) this is only called when we're deactivating an
# account.
pushers = list(await self.get_pushers_by_user_id(user_id))
def delete_pushers_txn(txn, stream_ids):
self._invalidate_cache_and_stream( # type: ignore
txn, self.get_if_user_has_pusher, (user_id,)
)
self.db_pool.simple_delete_txn(
txn,
table="pushers",
keyvalues={"user_name": user_id},
)
self.db_pool.simple_insert_many_txn(
txn,
table="deleted_pushers",
values=[
{
"stream_id": stream_id,
"app_id": pusher.app_id,
"pushkey": pusher.pushkey,
"user_id": user_id,
}
for stream_id, pusher in zip(stream_ids, pushers)
],
)
async with self._pushers_id_gen.get_next_mult(len(pushers)) as stream_ids:
await self.db_pool.runInteraction(
"delete_all_pushers_for_user", delete_pushers_txn, stream_ids
)
View File
@ -23,7 +23,7 @@ import attr
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor from synapse.storage.types import Connection, Cursor
@ -70,7 +70,12 @@ class TokenLookupResult:
class RegistrationWorkerStore(CacheInvalidationWorkerStore): class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config
@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# call `find_max_generated_user_id_localpart` each time, which is # call `find_max_generated_user_id_localpart` each time, which is
# expensive if there are many entries. # expensive if there are many entries.
self._user_id_seq = build_sequence_generator( self._user_id_seq = build_sequence_generator(
db_conn,
database.engine, database.engine,
find_max_generated_user_id_localpart, find_max_generated_user_id_localpart,
"user_id_seq", "user_id_seq",
table=None,
id_column=None,
) )
self._account_validity = hs.config.account_validity self._account_validity = hs.config.account_validity
@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._clock = hs.get_clock() self._clock = hs.get_clock()
View File
@ -0,0 +1,21 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- We may not have deleted all pushers for deactivated accounts. Do so now.
--
-- Note: We don't bother updating the `deleted_pushers` table as it's just used
-- to stop pushers on workers, and that will happen when they get next restarted.
DELETE FROM pushers WHERE user_name IN (SELECT name FROM users WHERE deactivated = 1);
View File
@ -0,0 +1,19 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Delete all pushers associated with deleted devices. This is to clear up after
-- a bug where they weren't correctly deleted when using workers.
DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens);
View File
@ -497,8 +497,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
async def add_users_in_public_rooms( async def add_users_in_public_rooms(
self, room_id: str, user_ids: Iterable[str] self, room_id: str, user_ids: Iterable[str]
) -> None: ) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first """Insert entries into the users_in_public_rooms table.
user should be a local user.
Args: Args:
room_id room_id
@ -670,7 +669,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows) users.update(rows)
return list(users) return list(users)
@cached()
async def get_shared_rooms_for_users( async def get_shared_rooms_for_users(
self, user_id: str, other_user_id: str self, user_id: str, other_user_id: str
) -> Set[str]: ) -> Set[str]:
View File
@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return txn.fetchone()[0] return txn.fetchone()[0]
self._state_group_seq_gen = build_sequence_generator( self._state_group_seq_gen = build_sequence_generator(
self.database_engine, get_max_state_group_txn, "state_group_id_seq" db_conn,
) self.database_engine,
self._state_group_seq_gen.check_consistency( get_max_state_group_txn,
db_conn, table="state_groups", id_column="id" "state_group_id_seq",
table="state_groups",
id_column="id",
) )
@cached(max_entries=10000, iterable=True) @cached(max_entries=10000, iterable=True)
View File
@ -251,9 +251,14 @@ class LocalSequenceGenerator(SequenceGenerator):
def build_sequence_generator( def build_sequence_generator(
db_conn: "LoggingDatabaseConnection",
database_engine: BaseDatabaseEngine, database_engine: BaseDatabaseEngine,
get_first_callback: GetFirstCallbackType, get_first_callback: GetFirstCallbackType,
sequence_name: str, sequence_name: str,
table: Optional[str],
id_column: Optional[str],
stream_name: Optional[str] = None,
positive: bool = True,
) -> SequenceGenerator: ) -> SequenceGenerator:
"""Get the best impl of SequenceGenerator available """Get the best impl of SequenceGenerator available
@ -265,8 +270,23 @@ def build_sequence_generator(
get_first_callback: a callback which gets the next sequence ID. Used if get_first_callback: a callback which gets the next sequence ID. Used if
we're on sqlite. we're on sqlite.
sequence_name: the name of a postgres sequence to use. sequence_name: the name of a postgres sequence to use.
table, id_column, stream_name, positive: If set then `check_consistency`
is called on the created sequence. See docstring for
`check_consistency` details.
""" """
if isinstance(database_engine, PostgresEngine): if isinstance(database_engine, PostgresEngine):
return PostgresSequenceGenerator(sequence_name) seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator
else: else:
return LocalSequenceGenerator(get_first_callback) seq = LocalSequenceGenerator(get_first_callback)
if table:
assert id_column
seq.check_consistency(
db_conn=db_conn,
table=table,
id_column=id_column,
stream_name=stream_name,
positive=positive,
)
return seq
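A sketch of a call under the extended signature, mirroring the event_auth_chains call site added elsewhere in this diff (here `db_conn` and `database` stand for the connection and DatabasePool available in the surrounding store; passing `table`/`id_column` is what triggers the consistency check):

def get_chain_id_txn(txn):
    txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
    return txn.fetchone()[0]

event_chain_id_gen = build_sequence_generator(
    db_conn,                 # newly required first argument
    database.engine,
    get_chain_id_txn,
    "event_auth_chain_id",
    table="event_auth_chains",
    id_column="chain_id",
)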
View File
@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Set, TypeVar
from twisted.internet import defer from twisted.internet import defer
@ -40,6 +40,7 @@ class ResponseCache(Generic[T]):
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
# Requests that haven't finished yet. # Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred] self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
self.pending_conditionals = {} # type: Dict[T, Set[Callable[[Any], bool]]]
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.0 self.timeout_sec = timeout_ms / 1000.0
@ -101,7 +102,11 @@ class ResponseCache(Generic[T]):
self.pending_result_cache[key] = result self.pending_result_cache[key] = result
def remove(r): def remove(r):
if self.timeout_sec: should_cache = all(
func(r) for func in self.pending_conditionals.pop(key, [])
)
if self.timeout_sec and should_cache:
self.clock.call_later( self.clock.call_later(
self.timeout_sec, self.pending_result_cache.pop, key, None self.timeout_sec, self.pending_result_cache.pop, key, None
) )
@ -112,6 +117,31 @@ class ResponseCache(Generic[T]):
result.addBoth(remove) result.addBoth(remove)
return result.observe() return result.observe()
def add_conditional(self, key: T, conditional: Callable[[Any], bool]):
self.pending_conditionals.setdefault(key, set()).add(conditional)
def wrap_conditional(
self,
key: T,
should_cache: Callable[[Any], bool],
callback: "Callable[..., Any]",
*args: Any,
**kwargs: Any
) -> defer.Deferred:
"""The same as wrap(), but adds a conditional to the final execution.
When the final execution completes, *all* conditionals need to return True for the result to be cached;
otherwise it is evicted from the cache immediately rather than after the usual timeout.
"""
# See if there's already a result on this key that hasn't yet completed. Due to the single-threaded nature of
# python, adding a key immediately in the same execution thread will not cause a race condition.
result = self.get(key)
if not result or isinstance(result, defer.Deferred) and not result.called:
self.add_conditional(key, should_cache)
return self.wrap(key, callback, *args, **kwargs)
def wrap( def wrap(
self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
) -> defer.Deferred: ) -> defer.Deferred:
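A hedged sketch of how a caller might use wrap_conditional; the handler method and the 200-status check are illustrative assumptions, not taken from this diff:

# Cache the result only if the callback reported success; all registered
# conditionals must return True, otherwise the entry is dropped on completion.
result = await self.response_cache.wrap_conditional(
    request_key,
    lambda r: r[0] == 200,        # should_cache: only keep successful responses
    self._handle_request,         # assumed expensive callback
    request_key,
)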
View File
@ -21,6 +21,7 @@ import pkg_resources
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -100,12 +101,19 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token) self.hs.get_datastore().get_user_by_access_token(self.access_token)
) )
token_id = user_tuple.token_id self.token_id = user_tuple.token_id
# We need to add email to account before we can create a pusher.
self.get_success(
hs.get_datastore().user_add_threepid(
self.user_id, "email", "a@example.com", 0, 0
)
)
self.pusher = self.get_success( self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
user_id=self.user_id, user_id=self.user_id,
access_token=token_id, access_token=self.token_id,
kind="email", kind="email",
app_id="m.email", app_id="m.email",
app_display_name="Email Notifications", app_display_name="Email Notifications",
@ -116,6 +124,28 @@ class EmailPusherTests(HomeserverTestCase):
) )
) )
def test_need_validated_email(self):
"""Test that we can only add an email pusher if the user has validated
their email.
"""
with self.assertRaises(SynapseError) as cm:
self.get_success_or_raise(
self.hs.get_pusherpool().add_pusher(
user_id=self.user_id,
access_token=self.token_id,
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
device_display_name="b@example.com",
pushkey="b@example.com",
lang=None,
data={},
)
)
self.assertEqual(400, cm.exception.code)
self.assertEqual(Codes.THREEPID_NOT_FOUND, cm.exception.errcode)
def test_simple_sends_email(self): def test_simple_sends_email(self):
# Create a simple room with two users # Create a simple room with two users
room = self.helper.create_room_as(self.user_id, tok=self.access_token) room = self.helper.create_room_as(self.user_id, tok=self.access_token)
View File
@ -24,7 +24,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
# enable federation sending on the worker # enable federation sending on the worker
config = super()._get_worker_hs_config() config = super()._get_worker_hs_config()
# TODO: make it so we don't need both of these # TODO: make it so we don't need both of these
config["send_federation"] = True config["send_federation"] = False
config["worker_app"] = "synapse.app.federation_sender" config["worker_app"] = "synapse.app.federation_sender"
return config return config
View File
@ -27,7 +27,7 @@ class FederationAckTestCase(HomeserverTestCase):
def default_config(self) -> dict: def default_config(self) -> dict:
config = super().default_config() config = super().default_config()
config["worker_app"] = "synapse.app.federation_sender" config["worker_app"] = "synapse.app.federation_sender"
config["send_federation"] = True config["send_federation"] = False
return config return config
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
View File
@ -49,7 +49,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs( self.make_worker_hs(
"synapse.app.federation_sender", "synapse.app.federation_sender",
{"send_federation": True}, {"send_federation": False},
federation_http_client=mock_client, federation_http_client=mock_client,
) )
View File
@ -95,7 +95,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs( self.make_worker_hs(
"synapse.app.pusher", "synapse.app.pusher",
{"start_pushers": True}, {"start_pushers": False},
proxied_blacklisted_http_client=http_client_mock, proxied_blacklisted_http_client=http_client_mock,
) )
View File
@ -18,7 +18,7 @@ import hmac
import json import json
import urllib.parse import urllib.parse
from binascii import unhexlify from binascii import unhexlify
from typing import Optional from typing import List, Optional
from mock import Mock from mock import Mock
@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import devices, sync
from synapse.types import JsonDict from synapse.types import JsonDict
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
@ -1954,6 +1955,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.media_repo = hs.get_media_repository_resource() self.media_repo = hs.get_media_repository_resource()
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
@ -2024,7 +2026,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
number_media = 20 number_media = 20
other_user_tok = self.login("user", "pass") other_user_tok = self.login("user", "pass")
self._create_media(other_user_tok, number_media) self._create_media_for_user(other_user_tok, number_media)
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -2045,7 +2047,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
number_media = 20 number_media = 20
other_user_tok = self.login("user", "pass") other_user_tok = self.login("user", "pass")
self._create_media(other_user_tok, number_media) self._create_media_for_user(other_user_tok, number_media)
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -2066,7 +2068,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
number_media = 20 number_media = 20
other_user_tok = self.login("user", "pass") other_user_tok = self.login("user", "pass")
self._create_media(other_user_tok, number_media) self._create_media_for_user(other_user_tok, number_media)
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -2080,11 +2082,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["media"]), 10) self.assertEqual(len(channel.json_body["media"]), 10)
self._check_fields(channel.json_body["media"]) self._check_fields(channel.json_body["media"])
def test_limit_is_negative(self): def test_invalid_parameter(self):
""" """
Testing that a negative limit parameter returns a 400 If parameters are invalid, an error is returned.
""" """
# unknown order_by
channel = self.make_request(
"GET",
self.url + "?order_by=bar",
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid search order
channel = self.make_request(
"GET",
self.url + "?dir=bar",
access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# negative limit
channel = self.make_request( channel = self.make_request(
"GET", "GET",
self.url + "?limit=-5", self.url + "?limit=-5",
@ -2094,11 +2116,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_from_is_negative(self): # negative from
"""
Testing that a negative from parameter returns a 400
"""
channel = self.make_request( channel = self.make_request(
"GET", "GET",
self.url + "?from=-5", self.url + "?from=-5",
@ -2115,7 +2133,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
number_media = 20 number_media = 20
other_user_tok = self.login("user", "pass") other_user_tok = self.login("user", "pass")
self._create_media(other_user_tok, number_media) self._create_media_for_user(other_user_tok, number_media)
# `next_token` does not appear # `next_token` does not appear
# Number of results is the number of entries # Number of results is the number of entries
@ -2193,7 +2211,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
number_media = 5 number_media = 5
other_user_tok = self.login("user", "pass") other_user_tok = self.login("user", "pass")
self._create_media(other_user_tok, number_media) self._create_media_for_user(other_user_tok, number_media)
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -2207,11 +2225,118 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["media"]) self._check_fields(channel.json_body["media"])
def _create_media(self, user_token, number_media): def test_order_by(self):
"""
Testing order list with parameter `order_by`
"""
other_user_tok = self.login("user", "pass")
# Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
image_data1 = unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
b"0000001f15c4890000000a49444154789c63000100000500010d"
b"0a2db40000000049454e44ae426082"
)
# Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B
image_data2 = unhexlify(
b"47494638376101000100800100000000"
b"ffffff2c00000000010001000002024c"
b"01003b"
)
# Resolution: 1×1, MIME type: image/bmp, Extension: bmp, Size: 54 B
image_data3 = unhexlify(
b"424d3a0000000000000036000000280000000100000001000000"
b"0100180000000000040000000000000000000000000000000000"
b"0000"
)
# create media and make sure they do not have the same timestamp
media1 = self._create_media_and_access(other_user_tok, image_data1, "image.png")
self.pump(1.0)
media2 = self._create_media_and_access(other_user_tok, image_data2, "image.gif")
self.pump(1.0)
media3 = self._create_media_and_access(other_user_tok, image_data3, "image.bmp")
self.pump(1.0)
# Mark one media as safe from quarantine.
self.get_success(self.store.mark_local_media_as_safe(media2))
# Quarantine one media
self.get_success(
self.store.quarantine_media_by_id("test", media3, self.admin_user)
)
# order by default ("created_ts")
# default is backwards
self._order_test([media3, media2, media1], None)
self._order_test([media1, media2, media3], None, "f")
self._order_test([media3, media2, media1], None, "b")
# sort by media_id
sorted_media = sorted([media1, media2, media3], reverse=False)
sorted_media_reverse = sorted(sorted_media, reverse=True)
# order by media_id
self._order_test(sorted_media, "media_id")
self._order_test(sorted_media, "media_id", "f")
self._order_test(sorted_media_reverse, "media_id", "b")
# order by upload_name
self._order_test([media3, media2, media1], "upload_name")
self._order_test([media3, media2, media1], "upload_name", "f")
self._order_test([media1, media2, media3], "upload_name", "b")
# order by media_type
# result is ordered by media_id
# because the uploaded media_type is always 'application/json'
self._order_test(sorted_media, "media_type")
self._order_test(sorted_media, "media_type", "f")
self._order_test(sorted_media, "media_type", "b")
# order by media_length
self._order_test([media2, media3, media1], "media_length")
self._order_test([media2, media3, media1], "media_length", "f")
self._order_test([media1, media3, media2], "media_length", "b")
# order by created_ts
self._order_test([media1, media2, media3], "created_ts")
self._order_test([media1, media2, media3], "created_ts", "f")
self._order_test([media3, media2, media1], "created_ts", "b")
# order by last_access_ts
self._order_test([media1, media2, media3], "last_access_ts")
self._order_test([media1, media2, media3], "last_access_ts", "f")
self._order_test([media3, media2, media1], "last_access_ts", "b")
# order by quarantined_by
# one media is in quarantine, others are ordered by media_ids
# Different sort order of SQLite and PostgreSQL
# If a media is not in quarantine `quarantined_by` is NULL
# SQLite considers NULL to be smaller than any other value.
# PostgreSQL considers NULL to be larger than any other value.
# self._order_test(sorted([media1, media2]) + [media3], "quarantined_by")
# self._order_test(sorted([media1, media2]) + [media3], "quarantined_by", "f")
# self._order_test([media3] + sorted([media1, media2]), "quarantined_by", "b")
# order by safe_from_quarantine
# one media is safe from quarantine, others are ordered by media_ids
self._order_test(sorted([media1, media3]) + [media2], "safe_from_quarantine")
self._order_test(
sorted([media1, media3]) + [media2], "safe_from_quarantine", "f"
)
self._order_test(
[media2] + sorted([media1, media3]), "safe_from_quarantine", "b"
)
def _create_media_for_user(self, user_token: str, number_media: int):
""" """
Create a number of media for a specific user Create a number of media for a specific user
Args:
user_token: Access token of the user
number_media: Number of media to be created for the user
""" """
upload_resource = self.media_repo.children[b"upload"]
for i in range(number_media): for i in range(number_media):
# file size is 67 Byte # file size is 67 Byte
image_data = unhexlify( image_data = unhexlify(
@ -2220,13 +2345,60 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
b"0a2db40000000049454e44ae426082" b"0a2db40000000049454e44ae426082"
) )
# Upload some media into the room self._create_media_and_access(user_token, image_data)
self.helper.upload_media(
upload_resource, image_data, tok=user_token, expect_code=200
)
def _check_fields(self, content): def _create_media_and_access(
"""Checks that all attributes are present in content""" self,
user_token: str,
image_data: bytes,
filename: str = "image1.png",
) -> str:
"""
Create one media for a specific user, access and returns `media_id`
Args:
user_token: Access token of the user
image_data: binary data of image
filename: The filename of the media to be uploaded
Returns:
The ID of the newly created media.
"""
upload_resource = self.media_repo.children[b"upload"]
download_resource = self.media_repo.children[b"download"]
# Upload some media into the room
response = self.helper.upload_media(
upload_resource, image_data, user_token, filename, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
media_id = server_and_media_id.split("/")[1]
# Try to access a media and to create `last_access_ts`
channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=user_token,
)
self.assertEqual(
200,
channel.code,
msg=(
"Expected to receive a 200 on accessing media: %s" % server_and_media_id
),
)
return media_id
def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
Args:
content: List that is checked for content
"""
for m in content: for m in content:
self.assertIn("media_id", m) self.assertIn("media_id", m)
self.assertIn("media_type", m) self.assertIn("media_type", m)
@ -2237,6 +2409,38 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertIn("quarantined_by", m) self.assertIn("quarantined_by", m)
self.assertIn("safe_from_quarantine", m) self.assertIn("safe_from_quarantine", m)
def _order_test(
self,
expected_media_list: List[str],
order_by: Optional[str],
dir: Optional[str] = None,
):
"""Request the list of media in a certain order. Assert that order is what
we expect
Args:
expected_media_list: The list of media_ids in the order we expect to get
back from the server
order_by: The type of ordering to give the server
dir: The direction of ordering to give the server
"""
url = self.url + "?"
if order_by is not None:
url += "order_by=%s&" % (order_by,)
if dir is not None and dir in ("b", "f"):
url += "dir=%s" % (dir,)
channel = self.make_request(
"GET",
url.encode("ascii"),
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_media_list))
returned_order = [row["media_id"] for row in channel.json_body["media"]]
self.assertEqual(expected_media_list, returned_order)
self._check_fields(channel.json_body["media"])
class UserTokenRestTestCase(unittest.HomeserverTestCase): class UserTokenRestTestCase(unittest.HomeserverTestCase):
"""Test for /_synapse/admin/v1/users/<user>/login""" """Test for /_synapse/admin/v1/users/<user>/login"""
View File
@ -15,7 +15,7 @@
import time import time
import urllib.parse import urllib.parse
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlencode from urllib.parse import urlencode
from mock import Mock from mock import Mock
@ -47,8 +47,14 @@ except ImportError:
HAS_JWT = False HAS_JWT = False
# public_base_url used in some tests # synapse server name: used to populate public_baseurl in some tests
BASE_URL = "https://synapse/" SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
# public_baseurl for some tests. It uses an http:// scheme because
# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
# https://....
BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
# CAS server used in some tests # CAS server used in some tests
CAS_SERVER = "https://fake.test" CAS_SERVER = "https://fake.test"
@ -480,11 +486,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_multi_sso_redirect(self): def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker""" """/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker # first hit the redirect url, which should redirect to our idp picker
channel = self.make_request( channel = self._make_sso_redirect_request(False, None)
"GET",
"/_matrix/client/r0/login/sso/redirect?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
uri = channel.headers.getRawHeaders("Location")[0] uri = channel.headers.getRawHeaders("Location")[0]
@ -628,34 +630,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_client_idp_redirect_msc2858_disabled(self): def test_client_idp_redirect_msc2858_disabled(self):
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
channel = self.make_request( channel = self._make_sso_redirect_request(True, "oidc")
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
@override_config({"experimental_features": {"msc2858_enabled": True}}) @override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_unknown(self): def test_client_idp_redirect_to_unknown(self):
"""If the client tries to pick an unknown IdP, return a 404""" """If the client tries to pick an unknown IdP, return a 404"""
channel = self.make_request( channel = self._make_sso_redirect_request(True, "xxx")
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
@override_config({"experimental_features": {"msc2858_enabled": True}}) @override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_oidc(self): def test_client_idp_redirect_to_oidc(self):
"""If the client pick a known IdP, redirect to it""" """If the client pick a known IdP, redirect to it"""
channel = self.make_request( channel = self._make_sso_redirect_request(True, "oidc")
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0] oidc_uri = channel.headers.getRawHeaders("Location")[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
@ -663,6 +652,30 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# it should redirect us to the auth page of the OIDC server # it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
def _make_sso_redirect_request(
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
):
"""Send a request to /_matrix/client/r0/login/sso/redirect
... or the unstable equivalent
... possibly specifying an IDP provider
"""
endpoint = (
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect"
if unstable_endpoint
else "/_matrix/client/r0/login/sso/redirect"
)
if idp_prov is not None:
endpoint += "/" + idp_prov
endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
return self.make_request(
"GET",
endpoint,
custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)],
)
@staticmethod @staticmethod
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
prefix = key + " = " prefix = key + " = "
View File
@ -542,13 +542,30 @@ class RestHelper:
if client_redirect_url: if client_redirect_url:
params["redirectUrl"] = client_redirect_url params["redirectUrl"] = client_redirect_url
# hit the redirect url (which will issue a cookie and state) # hit the redirect url (which should redirect back to the redirect url. This
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.hs.get_reactor(),
self.site, self.site,
"GET", "GET",
"/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
) )
assert channel.code == 302
# hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider.
location = channel.headers.getRawHeaders("Location")[0]
parts = urllib.parse.urlsplit(location)
channel = make_request(
self.hs.get_reactor(),
self.site,
"GET",
urllib.parse.urlunsplit(("", "") + parts[2:]),
custom_headers=[
("Host", parts[1]),
],
)
assert channel.code == 302 assert channel.code == 302
channel.extract_cookies(cookies) channel.extract_cookies(cookies)
View File
@ -161,7 +161,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
def default_config(self): def default_config(self):
config = super().default_config() config = super().default_config()
config["public_baseurl"] = "https://synapse.test"
# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
# False, so synapse will see the requested uri as http://..., so using http in
# the public_baseurl stops Synapse trying to redirect to https.
config["public_baseurl"] = "http://synapse.test"
if HAS_OIDC: if HAS_OIDC:
# we enable OIDC as a way of testing SSO flows # we enable OIDC as a way of testing SSO flows
View File
@ -54,61 +54,62 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
A room should show up in the shared list of rooms between two users A room should show up in the shared list of rooms between two users
if it is public. if it is public.
""" """
u1 = self.register_user("user1", "pass") self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
u1_token = self.login(u1, "pass")
u2 = self.register_user("user2", "pass")
u2_token = self.login(u2, "pass")
room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
self.helper.join(room, user=u2, tok=u2_token)
channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 1)
self.assertEquals(channel.json_body["joined"][0], room)
def test_shared_room_list_private(self): def test_shared_room_list_private(self):
""" """
A room should show up in the shared list of rooms between two users A room should show up in the shared list of rooms between two users
if it is private. if it is private.
""" """
u1 = self.register_user("user1", "pass") self._check_shared_rooms_with(
u1_token = self.login(u1, "pass") room_one_is_public=False, room_two_is_public=False
u2 = self.register_user("user2", "pass") )
u2_token = self.login(u2, "pass")
room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
self.helper.join(room, user=u2, tok=u2_token)
channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 1)
self.assertEquals(channel.json_body["joined"][0], room)
def test_shared_room_list_mixed(self): def test_shared_room_list_mixed(self):
""" """
The shared room list between two users should contain both public and private The shared room list between two users should contain both public and private
rooms. rooms.
""" """
self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
def _check_shared_rooms_with(
self, room_one_is_public: bool, room_two_is_public: bool
):
"""Checks that shared public or private rooms between two users appear in
their shared room lists
"""
u1 = self.register_user("user1", "pass") u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass") u1_token = self.login(u1, "pass")
u2 = self.register_user("user2", "pass") u2 = self.register_user("user2", "pass")
u2_token = self.login(u2, "pass") u2_token = self.login(u2, "pass")
room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token) # Create a room. user1 invites user2, who joins
room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token) room_id_one = self.helper.create_room_as(
self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token) u1, is_public=room_one_is_public, tok=u1_token
self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token) )
self.helper.join(room_public, user=u2, tok=u2_token) self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token)
self.helper.join(room_private, user=u1, tok=u1_token) self.helper.join(room_id_one, user=u2, tok=u2_token)
# Check shared rooms from user1's perspective.
# We should see the one room in common
channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 1)
self.assertEquals(channel.json_body["joined"][0], room_id_one)
# Create another room and invite user2 to it
room_id_two = self.helper.create_room_as(
u1, is_public=room_two_is_public, tok=u1_token
)
self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token)
self.helper.join(room_id_two, user=u2, tok=u2_token)
# Check shared rooms again. We should now see both rooms.
channel = self._get_shared_rooms(u1_token, u2) channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result) self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 2) self.assertEquals(len(channel.json_body["joined"]), 2)
self.assertTrue(room_public in channel.json_body["joined"]) for room_id_id in channel.json_body["joined"]:
self.assertTrue(room_private in channel.json_body["joined"]) self.assertIn(room_id_id, [room_id_one, room_id_two])
def test_shared_room_list_after_leave(self): def test_shared_room_list_after_leave(self):
""" """
@ -132,6 +133,12 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.leave(room, user=u1, tok=u1_token) self.helper.leave(room, user=u1, tok=u1_token)
# Check user1's view of shared rooms with user2
channel = self._get_shared_rooms(u1_token, u2)
self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 0)
# Check user2's view of shared rooms with user1
channel = self._get_shared_rooms(u2_token, u1) channel = self._get_shared_rooms(u2_token, u1)
self.assertEquals(200, channel.code, channel.result) self.assertEquals(200, channel.code, channel.result)
self.assertEquals(len(channel.json_body["joined"]), 0) self.assertEquals(len(channel.json_body["joined"]), 0)
View File
@ -124,7 +124,11 @@ class FakeChannel:
return address.IPv4Address("TCP", self._ip, 3423) return address.IPv4Address("TCP", self._ip, 3423)
def getHost(self): def getHost(self):
return None # this is called by Request.__init__ to configure Request.host.
return address.IPv4Address("TCP", "127.0.0.1", 8888)
def isSecure(self):
return False
@property @property
def transport(self): def transport(self):
View File
@ -114,7 +114,6 @@ def default_config(name, parse=False):
"server_name": name, "server_name": name,
"send_federation": False, "send_federation": False,
"media_store_path": "media", "media_store_path": "media",
"uploads_path": "uploads",
# the test signing key is just an arbitrary ed25519 key to keep the config # the test signing key is just an arbitrary ed25519 key to keep the config
# parser happy # parser happy
"signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg", "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",