Merge branch 'develop' of github.com:matrix-org/synapse into anoa/remote_public_rooms_list_errors
* 'develop' of github.com:matrix-org/synapse: (79 commits)
  Reduce the number of calls to `resource.getrusage` (#7183)
  Remove some `run_in_background` calls in replication code (#7203)
  Revert "Revert "Merge pull request #7153 from matrix-org/babolivier/sso_whitelist_login_fallback""
  Revert "Revert "Improve the UX of the login fallback when using SSO (#7152)""
  Revert "Merge pull request #7153 from matrix-org/babolivier/sso_whitelist_login_fallback"
  Revert "Improve the UX of the login fallback when using SSO (#7152)"
  tweak changelog
  1.12.3
  Update docstring per review comments
  Fix device list update stream ids going backward (#7158)
  Fix the debian build in a better way. (#7212)
  Fix changelog wording
  1.12.2
  Pin Pillow>=4.3.0,<7.1.0 to fix dep issue
  1.12.1
  review comment
  1.12.1
  Support SAML in the user interactive authentication workflow. (#7102)
  Allow admins to create aliases when they are not in the room (#7191)
  Update postgres.md (#7119)
  ...
commit a6c3a619e9

CHANGES.md
@@ -1,3 +1,44 @@
+Next version
+============
+
+* A new template (`sso_auth_confirm.html`) was added to Synapse. If your Synapse
+  is configured to use SSO and a custom `sso_redirect_confirm_template_dir`
+  configuration then this template will need to be duplicated into that
+  directory.
+
+Synapse 1.12.3 (2020-04-03)
+===========================
+
+- Remove the pin to Pillow 7.0 which was introduced in Synapse 1.12.2, and
+  correctly fix the issue with building the Debian packages. ([\#7212](https://github.com/matrix-org/synapse/issues/7212))
+
+Synapse 1.12.2 (2020-04-02)
+===========================
+
+This release works around [an
+issue](https://github.com/matrix-org/synapse/issues/7208) with building the
+Debian packages.
+
+No other significant changes since 1.12.1.
+
+>>>>>>> master
+
+Synapse 1.12.1 (2020-04-02)
+===========================
+
+No significant changes since 1.12.1rc1.
+
+
+Synapse 1.12.1rc1 (2020-03-31)
+==============================
+
+Bugfixes
+--------
+
+- Fix starting workers when federation sending not split out. ([\#7133](https://github.com/matrix-org/synapse/issues/7133)). Introduced in v1.12.0.
+- Avoid importing `sqlite3` when using the postgres backend. Contributed by David Vo. ([\#7155](https://github.com/matrix-org/synapse/issues/7155)). Introduced in v1.12.0rc1.
+- Fix a bug which could cause outbound federation traffic to stop working if a client uploaded an incorrect e2e device signature. ([\#7177](https://github.com/matrix-org/synapse/issues/7177)). Introduced in v1.11.0.
+
 Synapse 1.12.0 (2020-03-23)
 ===========================
 
INSTALL.md
@@ -2,7 +2,6 @@
 - [Installing Synapse](#installing-synapse)
   - [Installing from source](#installing-from-source)
     - [Platform-Specific Instructions](#platform-specific-instructions)
-    - [Troubleshooting Installation](#troubleshooting-installation)
   - [Prebuilt packages](#prebuilt-packages)
 - [Setting up Synapse](#setting-up-synapse)
   - [TLS certificates](#tls-certificates)
@@ -10,6 +9,7 @@
   - [Registering a user](#registering-a-user)
   - [Setting up a TURN server](#setting-up-a-turn-server)
   - [URL previews](#url-previews)
+- [Troubleshooting Installation](#troubleshooting-installation)
 
 # Choosing your server name
 
@@ -36,7 +36,7 @@ that your email address is probably `user@example.com` rather than
 System requirements:
 
 - POSIX-compliant system (tested on Linux & OS X)
-- Python 3.5, 3.6, 3.7 or 3.8.
+- Python 3.5.2 or later, up to Python 3.8.
 - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
 
 Synapse is written in Python but some of the libraries it uses are written in
@@ -70,7 +70,7 @@ pip install -U matrix-synapse
 ```
 
 Before you can start Synapse, you will need to generate a configuration
-file. To do this, run (in your virtualenv, as before)::
+file. To do this, run (in your virtualenv, as before):
 
 ```
 cd ~/synapse
@@ -84,22 +84,24 @@ python -m synapse.app.homeserver \
 ... substituting an appropriate value for `--server-name`.
 
 This command will generate you a config file that you can then customise, but it will
-also generate a set of keys for you. These keys will allow your Home Server to
-identify itself to other Home Servers, so don't lose or delete them. It would be
+also generate a set of keys for you. These keys will allow your homeserver to
+identify itself to other homeservers, so don't lose or delete them. It would be
 wise to back them up somewhere safe. (If, for whatever reason, you do need to
-change your Home Server's keys, you may find that other Home Servers have the
+change your homeserver's keys, you may find that other homeservers have the
 old key cached. If you update the signing key, you should change the name of the
 key in the `<server name>.signing.key` file (the second word) to something
 different. See the
 [spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys)
-for more information on key management.)
+for more information on key management).
 
 To actually run your new homeserver, pick a working directory for Synapse to
-run (e.g. `~/synapse`), and::
+run (e.g. `~/synapse`), and:
 
-    cd ~/synapse
-    source env/bin/activate
-    synctl start
+```
+cd ~/synapse
+source env/bin/activate
+synctl start
+```
 
 ### Platform-Specific Instructions
 
@@ -110,7 +112,7 @@ Installing prerequisites on Ubuntu or Debian:
 ```
 sudo apt-get install build-essential python3-dev libffi-dev \
                      python3-pip python3-setuptools sqlite3 \
-                     libssl-dev python3-virtualenv libjpeg-dev libxslt1-dev
+                     libssl-dev virtualenv libjpeg-dev libxslt1-dev
 ```
 
 #### ArchLinux
 
@@ -188,7 +190,7 @@ doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
 There is currently no port for OpenBSD. Additionally, OpenBSD's security
 settings require a slightly more difficult installation process.
 
-XXX: I suspect this is out of date.
+(XXX: I suspect this is out of date)
 
 1. Create a new directory in `/usr/local` called `_synapse`. Also, create a
    new user called `_synapse` and set that directory as the new user's home.
@@ -196,7 +198,7 @@ XXX: I suspect this is out of date.
    write and execute permissions on the same memory space to be run from
    `/usr/local`.
 2. `su` to the new `_synapse` user and change to their home directory.
-3. Create a new virtualenv: `virtualenv -p python2.7 ~/.synapse`
+3. Create a new virtualenv: `virtualenv -p python3 ~/.synapse`
 4. Source the virtualenv configuration located at
    `/usr/local/_synapse/.synapse/bin/activate`. This is done in `ksh` by
    using the `.` command, rather than `bash`'s `source`.
@@ -217,45 +219,6 @@ be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for
 Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server
 for Windows Server.
 
-### Troubleshooting Installation
-
-XXX a bunch of this is no longer relevant.
-
-Synapse requires pip 8 or later, so if your OS provides too old a version you
-may need to manually upgrade it::
-
-    sudo pip install --upgrade pip
-
-Installing may fail with `Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)`.
-You can fix this by manually upgrading pip and virtualenv::
-
-    sudo pip install --upgrade virtualenv
-
-You can next rerun `virtualenv -p python3 synapse` to update the virtual env.
-
-Installing may fail during installing virtualenv with `InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.`
-You can fix this by manually installing ndg-httpsclient::
-
-    pip install --upgrade ndg-httpsclient
-
-Installing may fail with `mock requires setuptools>=17.1. Aborting installation`.
-You can fix this by upgrading setuptools::
-
-    pip install --upgrade setuptools
-
-If pip crashes mid-installation for reason (e.g. lost terminal), pip may
-refuse to run until you remove the temporary installation directory it
-created. To reset the installation::
-
-    rm -rf /tmp/pip_install_matrix
-
-pip seems to leak *lots* of memory during installation. For instance, a Linux
-host with 512MB of RAM may run out of memory whilst installing Twisted. If this
-happens, you will have to individually install the dependencies which are
-failing, e.g.::
-
-    pip install twisted
-
 ## Prebuilt packages
 
 As an alternative to installing from source, prebuilt packages are available
@@ -314,7 +277,7 @@ For `buster` and `sid`, Synapse is available in the Debian repositories and
 it should be possible to install it with simply:
 
 ```
-    sudo apt install matrix-synapse
+sudo apt install matrix-synapse
 ```
 
 There is also a version of `matrix-synapse` in `stretch-backports`. Please see
@@ -375,8 +338,10 @@ sudo pip install py-bcrypt
 
 Synapse can be found in the void repositories as 'synapse':
 
-    xbps-install -Su
-    xbps-install -S synapse
+```
+xbps-install -Su
+xbps-install -S synapse
+```
 
 ### FreeBSD
 
@@ -420,6 +385,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
   resources:
     - names: [client, federation]
 ```
+
 * You will also need to uncomment the `tls_certificate_path` and
   `tls_private_key_path` lines under the `TLS` section. You can either
   point these settings at an existing certificate and key, or you can
@@ -427,15 +393,15 @@ so, you will need to edit `homeserver.yaml`, as follows:
   for having Synapse automatically provision and renew federation
   certificates through ACME can be found at [ACME.md](docs/ACME.md).
   Note that, as pointed out in that document, this feature will not
  work with installs set up after November 2019.
 
   If you are using your own certificate, be sure to use a `.pem` file that
   includes the full certificate chain including any intermediate certificates
   (for instance, if using certbot, use `fullchain.pem` as your certificate, not
   `cert.pem`).
 
 For a more detailed guide to configuring your server for federation, see
-[federate.md](docs/federate.md)
+[federate.md](docs/federate.md).
 
 
 ## Email
 
@@ -482,7 +448,7 @@ on your server even if `enable_registration` is `false`.
 ## Setting up a TURN server
 
 For reliable VoIP calls to be routed via this homeserver, you MUST configure
 a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
 
 ## URL previews
 
@@ -491,10 +457,24 @@ turn it on you must enable the `url_preview_enabled: True` config parameter
 and explicitly specify the IP ranges that Synapse is not allowed to spider for
 previewing in the `url_preview_ip_range_blacklist` configuration parameter.
 This is critical from a security perspective to stop arbitrary Matrix users
 spidering 'internal' URLs on your network. At the very least we recommend that
 your loopback and RFC1918 IP addresses are blacklisted.
 
-This also requires the optional lxml and netaddr python dependencies to be
-installed. This in turn requires the libxml2 library to be available - on
+This also requires the optional `lxml` and `netaddr` python dependencies to be
+installed. This in turn requires the `libxml2` library to be available - on
 Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for
 your OS.
 
+# Troubleshooting Installation
+
+`pip` seems to leak *lots* of memory during installation. For instance, a Linux
+host with 512MB of RAM may run out of memory whilst installing Twisted. If this
+happens, you will have to individually install the dependencies which are
+failing, e.g.:
+
+```
+pip install twisted
+```
+
 If you have any other problems, feel free to ask in
 [#synapse:matrix.org](https://matrix.to/#/#synapse:matrix.org).
@@ -0,0 +1 @@
+Don't attempt to use an invalid sqlite config if no database configuration is provided. Contributed by @nekatak.
@@ -0,0 +1 @@
+Fix missing field `default` when fetching user-defined push rules.
@@ -0,0 +1 @@
+Update Debian installation instructions to recommend installing the `virtualenv` package instead of `python3-virtualenv`.
@@ -0,0 +1 @@
+Transfer alias mappings on room upgrade.
@@ -0,0 +1 @@
+Move catchup of replication streams logic to worker.
@@ -0,0 +1 @@
+Admin API `POST /_synapse/admin/v1/join/<roomIdOrAlias>` to join users to a room like `auto_join_rooms` for creation of users.
@@ -0,0 +1 @@
+Ensure that a user interactive authentication session is tied to a single request.
@@ -0,0 +1 @@
+Add options to prevent users from changing their profile or associated 3PIDs.
@@ -0,0 +1 @@
+Support SSO in the user interactive authentication workflow.
@@ -0,0 +1 @@
+Allow server admins to define and enforce a password policy (MSC2000).
@@ -0,0 +1 @@
+Update postgres docs with login troubleshooting information.
@@ -0,0 +1 @@
+Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage.
@@ -0,0 +1 @@
+Refactored the CAS authentication logic to a separate class.
@@ -0,0 +1 @@
+Remove nonfunctional `captcha_bypass_secret` option from `homeserver.yaml`.
@@ -0,0 +1 @@
+Clean up INSTALL.md a bit.
@@ -0,0 +1 @@
+Add documentation for running a local CAS server for testing.
@@ -0,0 +1 @@
+Ensure `is_verified` is a boolean in responses to `GET /_matrix/client/r0/room_keys/keys`. Also warn the user if they forgot the `version` query param.
@@ -0,0 +1 @@
+Fix error page being shown when a custom SAML handler attempted to redirect when processing an auth response.
@@ -0,0 +1 @@
+Improve the support for SSO authentication on the login fallback page.
@@ -0,0 +1 @@
+Always whitelist the login fallback in the SSO configuration if `public_baseurl` is set.
@@ -0,0 +1 @@
+Avoid importing `sqlite3` when using the postgres backend. Contributed by David Vo.
@@ -0,0 +1 @@
+Add tests for outbound device pokes.
@@ -0,0 +1 @@
+Fix device list update stream ids going backward.
@@ -0,0 +1 @@
+Fix excessive CPU usage by `prune_old_outbound_device_pokes` job.
@@ -0,0 +1 @@
+Always send users their own device updates.
@@ -0,0 +1 @@
+Improve README.md by being explicit about public IP recommendation for TURN relaying.
@@ -0,0 +1 @@
+Fix a small typo in the `metrics_flags` config option.
@@ -0,0 +1 @@
+Fix a bug which could cause outbound federation traffic to stop working if a client uploaded an incorrect e2e device signature.
@@ -0,0 +1 @@
+Fix a bug which could cause incorrect 'cyclic dependency' error.
@@ -0,0 +1 @@
+Clean up some LoggingContext code.
@@ -0,0 +1 @@
+Clean up some LoggingContext code.
@@ -0,0 +1 @@
+Convert some of synapse.rest.media to async/await.
@@ -0,0 +1 @@
+Only run one background database update at a time.
@@ -0,0 +1 @@
+Admin users are no longer required to be in a room to create an alias for it.
@@ -0,0 +1 @@
+Move catchup of replication streams logic to worker.
@@ -0,0 +1 @@
+Fix some worker-mode replication handling not being correctly recorded in CPU usage stats.
@@ -1,3 +1,26 @@
+matrix-synapse-py3 (1.12.3) stable; urgency=medium
+
+  [ Richard van der Hoff ]
+  * Update the Debian build scripts to handle the new installation paths
+    for the support libraries introduced by Pillow 7.1.1.
+
+  [ Synapse Packaging team ]
+  * New synapse release 1.12.3.
+
+ -- Synapse Packaging team <packages@matrix.org>  Fri, 03 Apr 2020 10:55:03 +0100
+
+matrix-synapse-py3 (1.12.2) stable; urgency=medium
+
+  * New synapse release 1.12.2.
+
+ -- Synapse Packaging team <packages@matrix.org>  Mon, 02 Apr 2020 19:02:17 +0000
+
+matrix-synapse-py3 (1.12.1) stable; urgency=medium
+
+  * New synapse release 1.12.1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Mon, 02 Apr 2020 11:30:47 +0000
+
 matrix-synapse-py3 (1.12.0) stable; urgency=medium
 
   * New synapse release 1.12.0.
@@ -15,17 +15,38 @@ override_dh_installinit:
 # we don't really want to strip the symbols from our object files.
 override_dh_strip:
 
+# dh_shlibdeps calls dpkg-shlibdeps, which finds all the binary files
+# (executables and shared libs) in the package, and looks for the shared
+# libraries that they depend on. It then adds a dependency on the package that
+# contains that library to the package.
+#
+# We make two modifications to that process...
+#
 override_dh_shlibdeps:
-	# make the postgres package's dependencies a recommendation
-	# rather than a hard dependency.
+	# Firstly, postgres is not a hard dependency for us, so we want to make
+	# the things that psycopg2 depends on (such as libpq) be
+	# recommendations rather than hard dependencies. We do so by
+	# running dpkg-shlibdeps manually on psycopg2's libs.
+	#
 	find debian/$(PACKAGE_NAME)/ -path '*/site-packages/psycopg2/*.so' | \
 	    xargs dpkg-shlibdeps -Tdebian/$(PACKAGE_NAME).substvars \
 	        -pshlibs1 -dRecommends
 
-	# all the other dependencies can be normal 'Depends' requirements,
-	# except for PIL's, which is self-contained and which confuses
-	# dpkg-shlibdeps.
-	dh_shlibdeps -X site-packages/PIL/.libs -X site-packages/psycopg2
+	# secondly, we exclude PIL's libraries from the process. They are known
+	# to be self-contained, but they have interdependencies and
+	# dpkg-shlibdeps doesn't know how to resolve them.
+	#
+	# As of Pillow 7.1.0, these libraries are in
+	# site-packages/Pillow.libs. Previously, they were in
+	# site-packages/PIL/.libs.
+	#
+	# (we also need to exclude psycopg2, of course, since we've already
+	# dealt with that.)
+	#
+	dh_shlibdeps \
+	    -X site-packages/PIL/.libs \
+	    -X site-packages/Pillow.libs \
+	    -X site-packages/psycopg2
 
 override_dh_virtualenv:
 	./debian/build_virtualenv
@@ -0,0 +1,34 @@
+# Edit Room Membership API
+
+This API allows an administrator to join a user account with a given `user_id`
+to a room with a given `room_id_or_alias`. You can only modify the membership of
+local users. The server administrator must be in the room and have permission to
+invite users.
+
+## Parameters
+
+The following parameters are available:
+
+* `user_id` - Fully qualified user: for example, `@user:server.com`.
+* `room_id_or_alias` - The room identifier or alias to join: for example,
+  `!636q39766251:server.com`.
+
+## Usage
+
+```
+POST /_synapse/admin/v1/join/<room_id_or_alias>
+
+{
+  "user_id": "@user:server.com"
+}
+```
+
+Including an `access_token` of a server admin.
+
+Response:
+
+```
+{
+  "room_id": "!636q39766251:server.com"
+}
+```
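For example, the request above can be issued from Python's standard library as follows. This is a sketch only: the homeserver URL and the admin access token are placeholders, not values from the diff.

```python
import json
import urllib.parse
import urllib.request

# Placeholder values for illustration.
HOMESERVER = "http://localhost:8008"
ADMIN_TOKEN = "syt_admin_token"  # an access_token of a server admin


def admin_join(user_id: str, room_id_or_alias: str) -> dict:
    """Join a local user to a room via the admin API documented above."""
    url = "%s/_synapse/admin/v1/join/%s" % (
        HOMESERVER,
        urllib.parse.quote(room_id_or_alias),
    )
    req = urllib.request.Request(
        url,
        data=json.dumps({"user_id": user_id}).encode("utf-8"),
        headers={
            "Authorization": "Bearer %s" % ADMIN_TOKEN,
            "Content-Type": "application/json",
        },
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        # e.g. {"room_id": "!636q39766251:server.com"}
        return json.load(resp)


print(admin_join("@user:server.com", "#somealias:server.com"))
```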
@@ -0,0 +1,64 @@
+# How to test CAS as a developer without a server
+
+The [django-mama-cas](https://github.com/jbittel/django-mama-cas) project is an
+easy to run CAS implementation built on top of Django.
+
+## Prerequisites
+
+1. Create a new virtualenv: `python3 -m venv <your virtualenv>`
+2. Activate your virtualenv: `source /path/to/your/virtualenv/bin/activate`
+3. Install Django and django-mama-cas:
+   ```
+   python -m pip install "django<3" "django-mama-cas==2.4.0"
+   ```
+4. Create a Django project in the current directory:
+   ```
+   django-admin startproject cas_test .
+   ```
+5. Follow the [install directions](https://django-mama-cas.readthedocs.io/en/latest/installation.html#configuring) for django-mama-cas
+6. Setup the SQLite database: `python manage.py migrate`
+7. Create a user:
+   ```
+   python manage.py createsuperuser
+   ```
+   1. Use whatever you want as the username and password.
+   2. Leave the other fields blank.
+8. Use the built-in Django test server to serve the CAS endpoints on port 8000:
+   ```
+   python manage.py runserver
+   ```
+
+You should now have a Django project configured to serve CAS authentication with
+a single user created.
+
+## Configure Synapse (and Riot) to use CAS
+
+1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
+   running Django test server:
+   ```yaml
+   cas_config:
+     enabled: true
+     server_url: "http://localhost:8000"
+     service_url: "http://localhost:8081"
+     #displayname_attribute: name
+     #required_attributes:
+     #  name: value
+   ```
+2. Restart Synapse.
+
+Note that the above configuration assumes the homeserver is running on port 8081
+and that the CAS server is on port 8000, both on localhost.
+
+## Testing the configuration
+
+Then in Riot:
+
+1. Visit the login page with a Riot pointing at your homeserver.
+2. Click the Single Sign-On button.
+3. Login using the credentials created with `createsuperuser`.
+4. You should be logged in.
+
+If you want to repeat this process you'll need to manually logout first:
+
+1. http://localhost:8000/admin/
+2. Click "logout" in the top right.
@@ -18,9 +18,13 @@ To make Synapse (and therefore Riot) use it:
    metadata:
      local: ["samling.xml"]
    ```
-5. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure
+5. Ensure that your `homeserver.yaml` has a setting for `public_baseurl`:
+   ```yaml
+   public_baseurl: http://localhost:8080/
+   ```
+6. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure
    the dependencies are installed and ready to go.
-6. Restart Synapse.
+7. Restart Synapse.
 
 Then in Riot:
 
@@ -61,7 +61,33 @@ Note that the PostgreSQL database *must* have the correct encoding set
 
 You may need to enable password authentication so `synapse_user` can
 connect to the database. See
-<https://www.postgresql.org/docs/11/auth-pg-hba-conf.html>.
+<https://www.postgresql.org/docs/current/auth-pg-hba-conf.html>.
+
+If you get an error along the lines of `FATAL: Ident authentication failed for
+user "synapse_user"`, you may need to use an authentication method other than
+`ident`:
+
+* If the `synapse_user` user has a password, add the password to the `database:`
+  section of `homeserver.yaml`. Then add the following to `pg_hba.conf`:
+
+  ```
+  host    synapse     synapse_user    ::1/128     md5  # or `scram-sha-256` instead of `md5` if you use that
+  ```
+
+* If the `synapse_user` user does not have a password, then a password doesn't
+  have to be added to `homeserver.yaml`. But the following does need to be added
+  to `pg_hba.conf`:
+
+  ```
+  host    synapse     synapse_user    ::1/128     trust
+  ```
+
+Note that line order matters in `pg_hba.conf`, so make sure that if you do add a
+new line, it is inserted before:
+
+```
+host    all         all             ::1/128     ident
+```
 
 ### Fixing incorrect `COLLATE` or `CTYPE`
 
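As a quick sanity check of the `pg_hba.conf` rules above, something like the following verifies that `synapse_user` can actually connect. This is a minimal sketch assuming `psycopg2` (the driver Synapse uses) is installed; the password and host values are placeholders.

```python
import psycopg2

try:
    conn = psycopg2.connect(
        dbname="synapse",
        user="synapse_user",
        password="secretpassword",  # omit if using the `trust` rule above
        host="::1",
    )
    with conn.cursor() as cur:
        cur.execute("SHOW SERVER_ENCODING")
        # The encoding *must* be correct, per the note above.
        print("connected, encoding:", cur.fetchone()[0])
    conn.close()
except psycopg2.OperationalError as e:
    # An "Ident authentication failed" message here means pg_hba.conf
    # still selects the `ident` method for this connection.
    print("connection failed:", e)
```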
@@ -872,10 +872,6 @@ media_store_path: "DATADIR/media_store"
 #
 #enable_registration_captcha: false
 
-# A secret key used to bypass the captcha test entirely.
-#
-#captcha_bypass_secret: "YOUR_SECRET_HERE"
-
 # The API endpoint to use for verifying m.login.recaptcha responses.
 #
 #recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
@@ -1090,6 +1086,29 @@ account_threepid_delegates:
 #email: https://example.com     # Delegate email sending to example.com
 #msisdn: http://localhost:8090  # Delegate SMS sending to this local process
 
+# Whether users are allowed to change their displayname after it has
+# been initially set. Useful when provisioning users based on the
+# contents of a third-party directory.
+#
+# Does not apply to server administrators. Defaults to 'true'
+#
+#enable_set_displayname: false
+
+# Whether users are allowed to change their avatar after it has been
+# initially set. Useful when provisioning users based on the contents
+# of a third-party directory.
+#
+# Does not apply to server administrators. Defaults to 'true'
+#
+#enable_set_avatar_url: false
+
+# Whether users can change the 3PIDs associated with their accounts
+# (email address and msisdn).
+#
+# Defaults to 'true'
+#
+#enable_3pid_changes: false
+
 # Users who register on this homeserver will automatically be joined
 # to these rooms
 #
@@ -1125,7 +1144,7 @@ account_threepid_delegates:
 # enabled by default, either for performance reasons or limited use.
 #
 metrics_flags:
-    # Publish synapse_federation_known_servers, a g auge of the number of
+    # Publish synapse_federation_known_servers, a gauge of the number of
     # servers this homeserver knows about, including itself. May cause
     # performance problems on large homeservers.
     #
@@ -1425,6 +1444,10 @@ sso:
     # phishing attacks from evil.site. To avoid this, include a slash after the
     # hostname: "https://my.client/".
     #
+    # If public_baseurl is set, then the login fallback page (used by clients
+    # that don't natively support the required login flows) is whitelisted in
+    # addition to any URLs in this list.
+    #
     # By default, this list is empty.
     #
     #client_whitelist:
@@ -1486,6 +1509,41 @@ password_config:
   #
   #pepper: "EVEN_MORE_SECRET"
 
+  # Define and enforce a password policy. Each parameter is optional.
+  # This is an implementation of MSC2000.
+  #
+  policy:
+     # Whether to enforce the password policy.
+     # Defaults to 'false'.
+     #
+     #enabled: true
+
+     # Minimum accepted length for a password.
+     # Defaults to 0.
+     #
+     #minimum_length: 15
+
+     # Whether a password must contain at least one digit.
+     # Defaults to 'false'.
+     #
+     #require_digit: true
+
+     # Whether a password must contain at least one symbol.
+     # A symbol is any character that's not a number or a letter.
+     # Defaults to 'false'.
+     #
+     #require_symbol: true
+
+     # Whether a password must contain at least one lowercase letter.
+     # Defaults to 'false'.
+     #
+     #require_lowercase: true
+
+     # Whether a password must contain at least one uppercase letter.
+     # Defaults to 'false'.
+     #
+     #require_uppercase: true
+
 
 # Configuration for sending emails from Synapse.
 #
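To make the policy options above concrete, here is a rough sketch of how such checks could be enforced. This is an illustration only, not Synapse's actual handler; the error-code strings mirror the `M_PASSWORD_*` constants added to `synapse/api/errors.py` later in this diff.

```python
import re

# Illustrative policy dict, mirroring the sample config options above.
POLICY = {
    "enabled": True,
    "minimum_length": 15,
    "require_digit": True,
    "require_symbol": True,
    "require_lowercase": True,
    "require_uppercase": True,
}


def check_password(password: str, policy: dict = POLICY) -> None:
    """Raise ValueError with a Matrix-style error code if `password`
    violates the configured policy. (Sketch only.)"""
    if not policy.get("enabled", False):
        return
    if len(password) < policy.get("minimum_length", 0):
        raise ValueError("M_PASSWORD_TOO_SHORT")
    if policy.get("require_digit") and not re.search(r"[0-9]", password):
        raise ValueError("M_PASSWORD_NO_DIGIT")
    # "Symbol" approximated here as any non-alphanumeric ASCII character.
    if policy.get("require_symbol") and not re.search(r"[^a-zA-Z0-9]", password):
        raise ValueError("M_PASSWORD_NO_SYMBOL")
    if policy.get("require_lowercase") and not re.search(r"[a-z]", password):
        raise ValueError("M_PASSWORD_NO_LOWERCASE")
    if policy.get("require_uppercase") and not re.search(r"[A-Z]", password):
        raise ValueError("M_PASSWORD_NO_UPPERCASE")


check_password("Correct-Horse-Battery-9")  # passes all of the checks above
```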
@@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and
 '<' worker to master flows):
 
     > SERVER example.com
-    < REPLICATE events 53
+    < REPLICATE
+    > POSITION events 53
     > RDATA events 54 ["$foo1:bar.com", ...]
     > RDATA events 55 ["$foo4:bar.com", ...]
 
-The example shows the server accepting a new connection and sending its
-identity with the `SERVER` command, followed by the client asking to
-subscribe to the `events` stream from the token `53`. The server then
-periodically sends `RDATA` commands which have the format
-`RDATA <stream_name> <token> <row>`, where the format of `<row>` is
-defined by the individual streams.
+The example shows the server accepting a new connection and sending its identity
+with the `SERVER` command, followed by the client asking the server to respond
+with the position of all streams. The server then periodically sends `RDATA`
+commands which have the format `RDATA <stream_name> <token> <row>`, where the
+format of `<row>` is defined by the individual streams.
 
 Error reporting happens by either the client or server sending an ERROR
 command, and usually the connection will be closed.
@@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually
 connect to the server using a tool like netcat. A few things should be
 noted when manually using the protocol:
 
-- When subscribing to a stream using `REPLICATE`, the special token
-  `NOW` can be used to get all future updates. The special stream name
-  `ALL` can be used with `NOW` to subscribe to all available streams.
 - The federation stream is only available if federation sending has
   been disabled on the main process.
 - The server will only time connections out that have sent a `PING`
@@ -91,9 +88,7 @@ The client:
 - Sends a `NAME` command, allowing the server to associate a human
   friendly name with the connection. This is optional.
 - Sends a `PING` as above
-- For each stream the client wishes to subscribe to it sends a
-  `REPLICATE` with the `stream_name` and token it wants to subscribe
-  from.
+- Sends a `REPLICATE` to get the current position of all streams.
 - On receipt of a `SERVER` command, checks that the server name
   matches the expected server name.
 
@@ -140,9 +135,7 @@ the wire:
 > PING 1490197665618
 < NAME synapse.app.appservice
 < PING 1490197665618
-< REPLICATE events 1
-< REPLICATE backfill 1
-< REPLICATE caches 1
+< REPLICATE
 > POSITION events 1
 > POSITION backfill 1
 > POSITION caches 1
@@ -181,9 +174,9 @@ client (C):
 
 #### POSITION (S)
 
-The position of the stream has been updated. Sent to the client
-after all missing updates for a stream have been sent to the client
-and they're now up to date.
+On receipt of a POSITION command clients should check if they have missed any
+updates, and if so then fetch them out of band. Sent in response to a
+REPLICATE command (but can happen at any time).
 
 #### ERROR (S, C)
 
@@ -199,25 +192,18 @@ client (C):
 
 #### REPLICATE (C)
 
-Asks the server to replicate a given stream. The syntax is:
-
-```
-    REPLICATE <stream_name> <token>
-```
-
-Where `<token>` may be either:
- * a numeric stream_id to stream updates since (exclusive)
- * `NOW` to stream all subsequent updates.
-
-The `<stream_name>` is the name of a replication stream to subscribe
-to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
-of streams). It can also be `ALL` to subscribe to all known streams,
-in which case the `<token>` must be set to `NOW`.
+Asks the server for the current position of all streams.
 
 #### USER_SYNC (C)
 
 A user has started or stopped syncing
 
 #### CLEAR_USER_SYNC (C)
 
 The server should clear all associated user sync data from the worker.
 
 This is used when a worker is shutting down.
 
 #### FEDERATION_ACK (C)
 
 Acknowledge receipt of some federation data
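To make the new handshake concrete, here is a minimal sketch of the client side of the exchange documented above. It is illustrative only: the host and port are placeholders, and a real worker uses Synapse's replication client classes rather than a raw socket.

```python
import socket

# Placeholder address of the master's replication listener.
HOST, PORT = "localhost", 9092

with socket.create_connection((HOST, PORT)) as sock:
    f = sock.makefile("rw", encoding="utf-8", newline="\n")

    f.write("NAME example-client\n")  # optional human-friendly name
    f.write("PING 1490197665618\n")
    f.write("REPLICATE\n")            # ask for the position of all streams
    f.flush()

    for line in f:
        cmd, _, rest = line.rstrip("\n").partition(" ")
        if cmd == "POSITION":
            # e.g. "POSITION events 53": per the POSITION docs above, check
            # for missed updates and fetch them out of band if necessary.
            print("stream position:", rest)
        elif cmd == "RDATA":
            # e.g. 'RDATA events 54 ["$foo1:bar.com", ...]'
            print("row update:", rest)
        elif cmd == "ERROR":
            # Error reporting usually precedes the connection being closed.
            raise RuntimeError(rest)
```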
@@ -11,6 +11,13 @@ TURN server.
 
 The following sections describe how to install [coturn](<https://github.com/coturn/coturn>) (which implements the TURN REST API) and integrate it with synapse.
 
+## Requirements
+
+For TURN relaying with `coturn` to work, it must be hosted on a server/endpoint with a public IP.
+
+Hosting TURN behind a NAT (even with appropriate port forwarding) is known to cause issues
+and to often not work.
+
 ## `coturn` Setup
 
 ### Initial installation
@@ -36,7 +36,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.12.0"
+__version__ = "1.12.3"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
@@ -61,6 +61,7 @@ class LoginType(object):
     MSISDN = "m.login.msisdn"
     RECAPTCHA = "m.login.recaptcha"
     TERMS = "m.login.terms"
+    SSO = "org.matrix.login.sso"
     DUMMY = "m.login.dummy"
 
     # Only for C/S API v1
@@ -64,6 +64,13 @@ class Codes(object):
     INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
     WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
     EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
+    PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
+    PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
+    PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
+    PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
+    PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
+    PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
+    WEAK_PASSWORD = "M_WEAK_PASSWORD"
     INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
     USER_DEACTIVATED = "M_USER_DEACTIVATED"
     BAD_ALIAS = "M_BAD_ALIAS"
@@ -439,6 +446,20 @@ class IncompatibleRoomVersionError(SynapseError):
         return cs_error(self.msg, self.errcode, room_version=self._room_version)
 
 
+class PasswordRefusedError(SynapseError):
+    """A password has been refused, either during password reset/change or registration.
+    """
+
+    def __init__(
+        self,
+        msg="This password doesn't comply with the server's policy",
+        errcode=Codes.WEAK_PASSWORD,
+    ):
+        super(PasswordRefusedError, self).__init__(
+            code=400, msg=msg, errcode=errcode,
+        )
+
+
 class RequestSendFailed(RuntimeError):
     """Sending a HTTP request over federation failed due to not being able to
     talk to the remote server for some reason.
@@ -42,7 +42,7 @@ from synapse.handlers.presence import PresenceHandler, get_interested_parties
 from synapse.http.server import JsonResource
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseSite
-from synapse.logging.context import LoggingContext, run_in_background
+from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
@@ -65,6 +65,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.replication.tcp.commands import ClearUserSyncsCommand
 from synapse.replication.tcp.streams import (
     AccountDataStream,
     DeviceListsStream,
@@ -124,7 +125,6 @@ from synapse.types import ReadReceipt
 from synapse.util.async_helpers import Linearizer
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
-from synapse.util.stringutils import random_string
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger("synapse.app.generic_worker")
@@ -233,6 +233,7 @@ class GenericWorkerPresence(object):
         self.user_to_num_current_syncs = {}
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self.instance_id = hs.get_instance_id()
 
         active_presence = self.store.take_presence_startup_info()
         self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -245,13 +246,24 @@ class GenericWorkerPresence(object):
             self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
         )
 
-        self.process_id = random_string(16)
-        logger.info("Presence process_id is %r", self.process_id)
+        hs.get_reactor().addSystemEventTrigger(
+            "before",
+            "shutdown",
+            run_as_background_process,
+            "generic_presence.on_shutdown",
+            self._on_shutdown,
+        )
+
+    def _on_shutdown(self):
+        if self.hs.config.use_presence:
+            self.hs.get_tcp_replication().send_command(
+                ClearUserSyncsCommand(self.instance_id)
+            )
 
     def send_user_sync(self, user_id, is_syncing, last_sync_ms):
         if self.hs.config.use_presence:
             self.hs.get_tcp_replication().send_user_sync(
-                user_id, is_syncing, last_sync_ms
+                self.instance_id, user_id, is_syncing, last_sync_ms
             )
 
     def mark_as_coming_online(self, user_id):
@@ -401,6 +413,9 @@ class GenericWorkerTyping(object):
             self._room_serials[row.room_id] = token
             self._room_typing[row.room_id] = row.user_ids
 
+    def get_current_token(self) -> int:
+        return self._latest_room_serial
+
 
 class GenericWorkerSlavedStore(
     # FIXME(#3714): We need to add UserDirectoryStore as we write directly
@@ -620,7 +635,7 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
         await super(GenericWorkerReplicationHandler, self).on_rdata(
             stream_name, token, rows
         )
-        run_in_background(self.process_and_notify, stream_name, token, rows)
+        await self.process_and_notify(stream_name, token, rows)
 
     def get_streams_to_replicate(self):
         args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
@@ -635,7 +650,9 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
     async def process_and_notify(self, stream_name, token, rows):
         try:
             if self.send_handler:
-                self.send_handler.process_replication_rows(stream_name, token, rows)
+                await self.send_handler.process_replication_rows(
+                    stream_name, token, rows
+                )
 
             if stream_name == EventsStream.NAME:
                 # We shouldn't get multiple rows per token for events stream, so
@@ -767,12 +784,12 @@ class FederationSenderHandler(object):
     def stream_positions(self):
         return {"federation": self.federation_position}
 
-    def process_replication_rows(self, stream_name, token, rows):
+    async def process_replication_rows(self, stream_name, token, rows):
        # The federation stream contains things that we want to send out, e.g.
        # presence, typing, etc.
        if stream_name == "federation":
            send_queue.process_rows_for_federation(self.federation_sender, rows)
-            run_in_background(self.update_token, token)
+            await self.update_token(token)
 
        # We also need to poke the federation sender when new events happen
        elif stream_name == "events":
@@ -780,9 +797,7 @@ class FederationSenderHandler(object):
 
        # ... and when new receipts happen
        elif stream_name == ReceiptsStream.NAME:
-            run_as_background_process(
-                "process_receipts_for_federation", self._on_new_receipts, rows
-            )
+            await self._on_new_receipts(rows)
 
        # ... as well as device updates and messages
        elif stream_name == DeviceListsStream.NAME:
@@ -24,7 +24,6 @@ class CaptchaConfig(Config):
         self.enable_registration_captcha = config.get(
             "enable_registration_captcha", False
         )
-        self.captcha_bypass_secret = config.get("captcha_bypass_secret")
         self.recaptcha_siteverify_api = config.get(
             "recaptcha_siteverify_api",
             "https://www.recaptcha.net/recaptcha/api/siteverify",
@@ -49,10 +48,6 @@ class CaptchaConfig(Config):
         #
         #enable_registration_captcha: false
 
-        # A secret key used to bypass the captcha test entirely.
-        #
-        #captcha_bypass_secret: "YOUR_SECRET_HERE"
-
         # The API endpoint to use for verifying m.login.recaptcha responses.
         #
         #recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
@@ -20,6 +20,11 @@ from synapse.config._base import Config, ConfigError
 
 logger = logging.getLogger(__name__)
 
+NON_SQLITE_DATABASE_PATH_WARNING = """\
+Ignoring 'database_path' setting: not using a sqlite3 database.
+--------------------------------------------------------------------------------
+"""
+
 DEFAULT_CONFIG = """\
 ## Database ##
 
@@ -105,6 +110,11 @@ class DatabaseConnectionConfig:
 class DatabaseConfig(Config):
     section = "database"
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.databases = []
+
     def read_config(self, config, **kwargs):
         self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
 
@@ -125,12 +135,13 @@ class DatabaseConfig(Config):
 
         multi_database_config = config.get("databases")
         database_config = config.get("database")
+        database_path = config.get("database_path")
 
         if multi_database_config and database_config:
             raise ConfigError("Can't specify both 'database' and 'datbases' in config")
 
         if multi_database_config:
-            if config.get("database_path"):
+            if database_path:
                 raise ConfigError("Can't specify 'database_path' with 'databases'")
 
             self.databases = [
@@ -138,13 +149,17 @@ class DatabaseConfig(Config):
                 for name, db_conf in multi_database_config.items()
             ]
 
-        else:
-            if database_config is None:
-                database_config = {"name": "sqlite3", "args": {}}
-
+        if database_config:
             self.databases = [DatabaseConnectionConfig("master", database_config)]
 
-        self.set_databasepath(config.get("database_path"))
+        if database_path:
+            if self.databases and self.databases[0].name != "sqlite3":
+                logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
+                return
+
+            database_config = {"name": "sqlite3", "args": {}}
+            self.databases = [DatabaseConnectionConfig("master", database_config)]
+            self.set_databasepath(database_path)
 
     def generate_config_section(self, data_dir_path, **kwargs):
         return DEFAULT_CONFIG % {
@@ -152,27 +167,37 @@ class DatabaseConfig(Config):
         }
 
     def read_arguments(self, args):
-        self.set_databasepath(args.database_path)
+        """
+        Cases for the cli input:
+          - If no databases are configured and no database_path is set, raise.
+          - No databases and only database_path available ==> sqlite3 db.
+          - If there are multiple databases and a database_path raise an error.
+          - If the database set in the config file is sqlite then
+            overwrite with the command line argument.
+        """
+
+        if args.database_path is None:
+            if not self.databases:
+                raise ConfigError("No database config provided")
+            return
+
+        if len(self.databases) == 0:
+            database_config = {"name": "sqlite3", "args": {}}
+            self.databases = [DatabaseConnectionConfig("master", database_config)]
+            self.set_databasepath(args.database_path)
+            return
+
+        if self.get_single_database().name == "sqlite3":
+            self.set_databasepath(args.database_path)
+        else:
+            logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
 
     def set_databasepath(self, database_path):
-        if database_path is None:
-            return
 
         if database_path != ":memory:":
             database_path = self.abspath(database_path)
 
-        # We only support setting a database path if we have a single sqlite3
-        # database.
-        if len(self.databases) != 1:
-            raise ConfigError("Cannot specify 'database_path' with multiple databases")
-
-        database = self.get_single_database()
-        if database.config["name"] != "sqlite3":
-            # We don't raise here as we haven't done so before for this case.
-            logger.warn("Ignoring 'database_path' for non-sqlite3 database")
-            return
-
-        database.config["args"]["database"] = database_path
+        self.databases[0].config["args"]["database"] = database_path
 
     @staticmethod
     def add_arguments(parser):
@@ -187,7 +212,7 @@ class DatabaseConfig(Config):
     def get_single_database(self) -> DatabaseConnectionConfig:
         """Returns the database if there is only one, useful for e.g. tests
         """
-        if len(self.databases) != 1:
+        if not self.databases:
             raise Exception("More than one database exists")
 
         return self.databases[0]
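The `read_arguments` docstring above amounts to a small decision table. As a standalone illustration of the same precedence rules, here is a simplified sketch (not Synapse's actual code; the dict shapes are a stand-in for `DatabaseConnectionConfig`):

```python
from typing import List, Optional

NON_SQLITE_WARNING = "Ignoring 'database_path' setting: not using a sqlite3 database."


def resolve_databases(configured: List[dict], cli_database_path: Optional[str]) -> List[dict]:
    """Model of the CLI/config precedence described in the docstring above."""
    if cli_database_path is None:
        if not configured:
            raise ValueError("No database config provided")
        return configured  # keep whatever the config file said

    if not configured:
        # Only a CLI path was given: default to a sqlite3 database there.
        return [{"name": "sqlite3", "args": {"database": cli_database_path}}]

    if configured[0]["name"] == "sqlite3":
        # Config file already chose sqlite3: the CLI path overrides it.
        configured[0]["args"]["database"] = cli_database_path
        return configured

    # A non-sqlite database is configured: the CLI path is ignored with a warning.
    print(NON_SQLITE_WARNING)
    return configured
```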
@@ -86,7 +86,7 @@ class MetricsConfig(Config):
         # enabled by default, either for performance reasons or limited use.
         #
         metrics_flags:
-            # Publish synapse_federation_known_servers, a g auge of the number of
+            # Publish synapse_federation_known_servers, a gauge of the number of
             # servers this homeserver knows about, including itself. May cause
             # performance problems on large homeservers.
             #
@@ -31,6 +31,10 @@ class PasswordConfig(Config):
         self.password_localdb_enabled = password_config.get("localdb_enabled", True)
         self.password_pepper = password_config.get("pepper", "")
 
+        # Password policy
+        self.password_policy = password_config.get("policy") or {}
+        self.password_policy_enabled = self.password_policy.get("enabled", False)
+
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
         password_config:
@@ -48,4 +52,39 @@ class PasswordConfig(Config):
            # DO NOT CHANGE THIS AFTER INITIAL SETUP!
            #
            #pepper: "EVEN_MORE_SECRET"
+
+           # Define and enforce a password policy. Each parameter is optional.
+           # This is an implementation of MSC2000.
+           #
+           policy:
+              # Whether to enforce the password policy.
+              # Defaults to 'false'.
+              #
+              #enabled: true
+
+              # Minimum accepted length for a password.
+              # Defaults to 0.
+              #
+              #minimum_length: 15
+
+              # Whether a password must contain at least one digit.
+              # Defaults to 'false'.
+              #
+              #require_digit: true
+
+              # Whether a password must contain at least one symbol.
+              # A symbol is any character that's not a number or a letter.
+              # Defaults to 'false'.
+              #
+              #require_symbol: true
+
+              # Whether a password must contain at least one lowercase letter.
+              # Defaults to 'false'.
+              #
+              #require_lowercase: true
+
+              # Whether a password must contain at least one uppercase letter.
+              # Defaults to 'false'.
+              #
+              #require_uppercase: true
         """
@@ -129,6 +129,10 @@ class RegistrationConfig(Config):
                 raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
         self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
 
+        self.enable_set_displayname = config.get("enable_set_displayname", True)
+        self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
+        self.enable_3pid_changes = config.get("enable_3pid_changes", True)
+
         self.disable_msisdn_registration = config.get(
             "disable_msisdn_registration", False
         )
@@ -330,6 +334,29 @@ class RegistrationConfig(Config):
         #email: https://example.com     # Delegate email sending to example.com
         #msisdn: http://localhost:8090  # Delegate SMS sending to this local process
 
+        # Whether users are allowed to change their displayname after it has
+        # been initially set. Useful when provisioning users based on the
+        # contents of a third-party directory.
+        #
+        # Does not apply to server administrators. Defaults to 'true'
+        #
+        #enable_set_displayname: false
+
+        # Whether users are allowed to change their avatar after it has been
+        # initially set. Useful when provisioning users based on the contents
+        # of a third-party directory.
+        #
+        # Does not apply to server administrators. Defaults to 'true'
+        #
+        #enable_set_avatar_url: false
+
+        # Whether users can change the 3PIDs associated with their accounts
+        # (email address and msisdn).
+        #
+        # Defaults to 'true'
+        #
+        #enable_3pid_changes: false
+
         # Users who register on this homeserver will automatically be joined
         # to these rooms
         #
@@ -39,6 +39,17 @@ class SSOConfig(Config):
 
         self.sso_client_whitelist = sso_config.get("client_whitelist") or []
 
+        # Attempt to also whitelist the server's login fallback, since that fallback sets
+        # the redirect URL to itself (so it can process the login token then return
+        # gracefully to the client). This would make it pointless to ask the user for
+        # confirmation, since the URL the confirmation page would be showing wouldn't be
+        # the client's.
+        # public_baseurl is an optional setting, so we only add the fallback's URL to the
+        # list if it's provided (because we can't figure out what that URL is otherwise).
+        if self.public_baseurl:
+            login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
+            self.sso_client_whitelist.append(login_fallback_url)
+
     def generate_config_section(self, **kwargs):
         return """\
         # Additional settings to use with single-sign on systems such as SAML2 and CAS.
@@ -54,6 +65,10 @@ class SSOConfig(Config):
         # phishing attacks from evil.site. To avoid this, include a slash after the
         # hostname: "https://my.client/".
         #
+        # If public_baseurl is set, then the login fallback page (used by clients
+        # that don't natively support the required login flows) is whitelisted in
+        # addition to any URLs in this list.
+        #
         # By default, this list is empty.
         #
         #client_whitelist:
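For illustration, here is the whitelist the snippet above produces for an example configuration (the `public_baseurl` and client URL values are assumptions, not from the diff):

```python
# Standalone illustration of the SSOConfig whitelist logic above.
public_baseurl = "https://matrix.example.com/"
sso_client_whitelist = ["https://my.client/"]

if public_baseurl:
    # The login fallback page redirects to itself, so it is whitelisted
    # automatically whenever public_baseurl is known.
    login_fallback_url = public_baseurl + "_matrix/static/client/login"
    sso_client_whitelist.append(login_fallback_url)

print(sso_client_whitelist)
# ['https://my.client/', 'https://matrix.example.com/_matrix/static/client/login']
```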
@@ -499,4 +499,13 @@ class FederationSender(object):
         self._get_per_destination_queue(destination).attempt_new_transaction()
 
+    def get_current_token(self) -> int:
+        # Dummy implementation for case where federation sender isn't offloaded
+        # to a worker.
+        return 0
+
+    async def get_replication_rows(
+        self, from_token, to_token, limit, federation_ack=None
+    ):
+        # Dummy implementation for case where federation sender isn't offloaded
+        # to a worker.
+        return []
@ -53,6 +53,31 @@ from ._base import BaseHandler
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SUCCESS_TEMPLATE = """
|
||||
<html>
|
||||
<head>
|
||||
<title>Success!</title>
|
||||
<meta name='viewport' content='width=device-width, initial-scale=1,
|
||||
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
|
||||
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
|
||||
<script>
|
||||
if (window.onAuthDone) {
|
||||
window.onAuthDone();
|
||||
} else if (window.opener && window.opener.postMessage) {
|
||||
window.opener.postMessage("authDone", "*");
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div>
|
||||
<p>Thank you</p>
|
||||
<p>You may now close this window and return to the application</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
class AuthHandler(BaseHandler):
|
||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||
|
||||
|
@ -91,6 +116,7 @@ class AuthHandler(BaseHandler):
|
|||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self._password_enabled = hs.config.password_enabled
|
||||
self._saml2_enabled = hs.config.saml2_enabled
|
||||
|
||||
# we keep this as a list despite the O(N^2) implication so that we can
|
||||
# keep PASSWORD first and avoid confusing clients which pick the first
|
||||
|
@ -106,6 +132,13 @@ class AuthHandler(BaseHandler):
|
|||
if t not in login_types:
|
||||
login_types.append(t)
|
||||
self._supported_login_types = login_types
|
||||
# Login types and UI Auth types have a heavy overlap, but are not
|
||||
# necessarily identical. Login types have SSO (and other login types)
|
||||
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
|
||||
ui_auth_types = login_types.copy()
|
||||
if self._saml2_enabled:
|
||||
ui_auth_types.append(LoginType.SSO)
|
||||
self._supported_ui_auth_types = ui_auth_types
|
||||
|
||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||
# as per `rc_login.failed_attempts`.
|
||||
|
@ -113,10 +146,21 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
self._clock = self.hs.get_clock()
|
||||
|
||||
# Load the SSO redirect confirmation page HTML template
|
||||
# Load the SSO HTML templates.
|
||||
|
||||
# The following template is shown to the user during a client login via SSO,
|
||||
# after the SSO completes and before redirecting them back to their client.
|
||||
# It notifies the user they are about to give access to their matrix account
|
||||
# to the client.
|
||||
self._sso_redirect_confirm_template = load_jinja2_templates(
|
||||
hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
|
||||
)[0]
|
||||
# The following template is shown during user interactive authentication
|
||||
# in the fallback auth scenario. It notifies the user that they are
|
||||
# authenticating for an operation to occur on their account.
|
||||
self._sso_auth_confirm_template = load_jinja2_templates(
|
||||
hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"],
|
||||
)[0]
|
||||
|
||||
self._server_name = hs.config.server_name
|
||||
|
||||
|
@@ -125,7 +169,12 @@ class AuthHandler(BaseHandler):

    @defer.inlineCallbacks
    def validate_user_via_ui_auth(
-       self, requester: Requester, request_body: Dict[str, Any], clientip: str
+       self,
+       requester: Requester,
+       request: SynapseRequest,
+       request_body: Dict[str, Any],
+       clientip: str,
+       description: str,
    ):
        """
        Checks that the user is who they claim to be, via a UI auth.

@@ -137,10 +186,15 @@ class AuthHandler(BaseHandler):
        Args:
            requester: The user, as given by the access token

+           request: The request sent by the client.
+
            request_body: The body of the request sent by the client

            clientip: The IP address of the client.

+           description: A human readable string to be displayed to the user that
+               describes the operation happening on their account.
+
        Returns:
            defer.Deferred[dict]: the parameters for this request (which may
                have been given only in a previous call).
@@ -169,10 +223,12 @@ class AuthHandler(BaseHandler):
        )

        # build a list of supported flows
-       flows = [[login_type] for login_type in self._supported_login_types]
+       flows = [[login_type] for login_type in self._supported_ui_auth_types]

        try:
-           result, params, _ = yield self.check_auth(flows, request_body, clientip)
+           result, params, _ = yield self.check_auth(
+               flows, request, request_body, clientip, description
+           )
        except LoginError:
            # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
            self._failed_uia_attempts_ratelimiter.can_do_action(

@@ -185,7 +241,7 @@ class AuthHandler(BaseHandler):
        raise

        # find the completed login type
-       for login_type in self._supported_login_types:
+       for login_type in self._supported_ui_auth_types:
            if login_type not in result:
                continue
@@ -211,7 +267,12 @@ class AuthHandler(BaseHandler):

    @defer.inlineCallbacks
    def check_auth(
-       self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
+       self,
+       flows: List[List[str]],
+       request: SynapseRequest,
+       clientdict: Dict[str, Any],
+       clientip: str,
+       description: str,
    ):
        """
        Takes a dictionary sent by the client in the login / registration

@@ -231,11 +292,16 @@ class AuthHandler(BaseHandler):
                strings representing auth-types. At least one full
                flow must be completed in order for auth to be successful.

+           request: The request sent by the client.
+
            clientdict: The dictionary from the client root level, not the
                'auth' key: this method prompts for auth if none is sent.

            clientip: The IP address of the client.

+           description: A human readable string to be displayed to the user that
+               describes the operation happening on their account.
+
        Returns:
            defer.Deferred[dict, dict, str]: a deferred tuple of
                (creds, params, session_id).
@@ -270,13 +336,33 @@ class AuthHandler(BaseHandler):
            # email auth link on there). It's probably too open to abuse
            # because it lets unauthenticated clients store arbitrary objects
            # on a homeserver.
-           # Revisit: Assumimg the REST APIs do sensible validation, the data
+           # Revisit: Assuming the REST APIs do sensible validation, the data
            # isn't arbitrary.
            session["clientdict"] = clientdict
            self._save_session(session)
        elif "clientdict" in session:
            clientdict = session["clientdict"]

+       # Ensure that the queried operation does not vary between stages of
+       # the UI authentication session. This is done by generating a stable
+       # comparator based on the URI, method, and body (minus the auth dict)
+       # and storing it during the initial query. Subsequent queries ensure
+       # that this comparator has not changed.
+       comparator = (request.uri, request.method, clientdict)
+       if "ui_auth" not in session:
+           session["ui_auth"] = comparator
+           self._save_session(session)
+       elif session["ui_auth"] != comparator:
+           raise SynapseError(
+               403,
+               "Requested operation has changed during the UI authentication session.",
+           )
+
+       # Add a human readable description to the session.
+       if "description" not in session:
+           session["description"] = description
+           self._save_session(session)
+
        if not authdict:
            raise InteractiveAuthIncompleteError(
                self._auth_dict_for_flows(flows, session)

@@ -322,6 +408,7 @@ class AuthHandler(BaseHandler):
                creds,
                list(clientdict),
            )

            return creds, clientdict, session["id"]

        ret = self._auth_dict_for_flows(flows, session)
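The comparator logic above is the part worth internalising: the same (URI, method, body-minus-auth) triple must be presented at every stage of a UI auth session, otherwise the session is rejected. A minimal standalone sketch of the same idea (the plain-dict session store and all names here are ours, not Synapse's):

    # Sketch of the UI-auth "stable comparator" check above; the session
    # store is a plain dict here, whereas Synapse keeps it in the handler.
    class OperationChanged(Exception):
        pass


    def check_ui_auth_consistent(session, uri, method, body):
        """Raise if (uri, method, body-minus-auth) changed mid-session."""
        clientdict = {k: v for k, v in body.items() if k != "auth"}
        comparator = (uri, method, clientdict)
        if "ui_auth" not in session:
            session["ui_auth"] = comparator  # first stage: pin the operation
        elif session["ui_auth"] != comparator:
            raise OperationChanged("operation changed during UI auth session")


    session = {}
    check_ui_auth_consistent(session, "/delete_devices", "POST", {"devices": ["A"]})
    # A later stage of the *same* operation passes, even with an auth dict added:
    check_ui_auth_consistent(
        session, "/delete_devices", "POST", {"devices": ["A"], "auth": {"session": "x"}}
    )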
@@ -962,6 +1049,56 @@ class AuthHandler(BaseHandler):
        else:
            return defer.succeed(False)

    def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
        """
        Get the HTML for the SSO redirect confirmation page.

        Args:
            redirect_url: The URL to redirect to the SSO provider.
            session_id: The user interactive authentication session ID.

        Returns:
            The HTML to render.
        """
        session = self._get_session_info(session_id)
        # Get the human readable operation of what is occurring, falling back to
        # a generic message if it isn't available for some reason.
        description = session.get("description", "modify your account")
        return self._sso_auth_confirm_template.render(
            description=description, redirect_url=redirect_url,
        )

    def complete_sso_ui_auth(
        self, registered_user_id: str, session_id: str, request: SynapseRequest,
    ):
        """Having figured out a mxid for this user, complete the HTTP request

        Args:
            registered_user_id: The registered user ID to complete SSO login for.
            session_id: The ID of the user interactive authentication session.
            request: The request to complete.
        """
        # Mark the stage of the authentication as successful.
        sess = self._get_session_info(session_id)
        if "creds" not in sess:
            sess["creds"] = {}
        creds = sess["creds"]

        # Save the user who authenticated with SSO, this will be used to ensure
        # that the account being modified is also the person who logged in.
        creds[LoginType.SSO] = registered_user_id
        self._save_session(sess)

        # Render the HTML and return.
        html_bytes = SUCCESS_TEMPLATE.encode("utf8")
        request.setResponseCode(200)
        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
        request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))

        request.write(html_bytes)
        finish_request(request)

    def complete_sso_login(
        self,
        registered_user_id: str,
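For context, the user-interactive auth (UIA) dance that `start_sso_ui_auth` and `complete_sso_ui_auth` plug into looks roughly like this from a client's point of view. This is a hand-written sketch: the endpoint and the shape of the 401 body follow the Matrix client-server spec, everything else (base URL, token) is illustrative:

    # Sketch of the client side of UIA. The first attempt gets a 401 carrying
    # a session ID and the allowed flows; the client completes a stage out of
    # band (e.g. the SSO fallback page above) and retries the same request.
    import requests  # any HTTP client works; requests is just for brevity

    def delete_device(base_url, token, device_id):
        url = "%s/_matrix/client/r0/devices/%s" % (base_url, device_id)
        headers = {"Authorization": "Bearer %s" % token}

        resp = requests.delete(url, headers=headers, json={})
        if resp.status_code == 401:
            session_id = resp.json()["session"]
            # ... user completes a stage, e.g. SSO via the fallback page ...
            resp = requests.delete(
                url, headers=headers, json={"auth": {"session": session_id}}
            )
        resp.raise_for_status()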
@@ -0,0 +1,204 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import xml.etree.ElementTree as ET
from typing import AnyStr, Dict, Optional, Tuple

from six.moves import urllib

from twisted.web.client import PartialDownloadError

from synapse.api.errors import Codes, LoginError
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart

logger = logging.getLogger(__name__)


class CasHandler:
    """
    Utility class to handle the response from a CAS SSO service.

    Args:
        hs (synapse.server.HomeServer)
    """

    def __init__(self, hs):
        self._hostname = hs.hostname
        self._auth_handler = hs.get_auth_handler()
        self._registration_handler = hs.get_registration_handler()

        self._cas_server_url = hs.config.cas_server_url
        self._cas_service_url = hs.config.cas_service_url
        self._cas_displayname_attribute = hs.config.cas_displayname_attribute
        self._cas_required_attributes = hs.config.cas_required_attributes

        self._http_client = hs.get_proxied_http_client()

    def _build_service_param(self, client_redirect_url: AnyStr) -> str:
        return "%s%s?%s" % (
            self._cas_service_url,
            "/_matrix/client/r0/login/cas/ticket",
            urllib.parse.urlencode({"redirectUrl": client_redirect_url}),
        )
    async def _handle_cas_response(
        self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str
    ) -> None:
        """
        Retrieves the user and display name from the CAS response and continues with the authentication.

        Args:
            request: The original client request.
            cas_response_body: The response from the CAS server.
            client_redirect_url: The URL to redirect the client to when
                everything is done.
        """
        user, attributes = self._parse_cas_response(cas_response_body)
        displayname = attributes.pop(self._cas_displayname_attribute, None)

        for required_attribute, required_value in self._cas_required_attributes.items():
            # If required attribute was not in CAS Response - Forbidden
            if required_attribute not in attributes:
                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

            # Also need to check value
            if required_value is not None:
                actual_value = attributes[required_attribute]
                # If required attribute value does not match expected - Forbidden
                if required_value != actual_value:
                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

        await self._on_successful_auth(user, request, client_redirect_url, displayname)
    def _parse_cas_response(
        self, cas_response_body: str
    ) -> Tuple[str, Dict[str, Optional[str]]]:
        """
        Retrieve the user and other parameters from the CAS response.

        Args:
            cas_response_body: The response from the CAS query.

        Returns:
            A tuple of the user and a mapping of other attributes.
        """
        user = None
        attributes = {}
        try:
            root = ET.fromstring(cas_response_body)
            if not root.tag.endswith("serviceResponse"):
                raise Exception("root of CAS response is not serviceResponse")
            success = root[0].tag.endswith("authenticationSuccess")
            for child in root[0]:
                if child.tag.endswith("user"):
                    user = child.text
                if child.tag.endswith("attributes"):
                    for attribute in child:
                        # ElementTree library expands the namespace in
                        # attribute tags to the full URL of the namespace.
                        # We don't care about namespace here and it will always
                        # be encased in curly braces, so we remove them.
                        tag = attribute.tag
                        if "}" in tag:
                            tag = tag.split("}")[1]
                        attributes[tag] = attribute.text
            if user is None:
                raise Exception("CAS response does not contain user")
        except Exception:
            logger.exception("Error parsing CAS response")
            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
        if not success:
            raise LoginError(
                401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
            )
        return user, attributes
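To make the namespace-stripping concrete, here is what a typical (abbreviated) CAS 2.0 success response looks like and what the parser above would pull out of it. The XML shape is standard CAS, not Synapse-specific; the user and attribute values are made up:

    # A typical CAS v2 proxyValidate success body. Feeding it to
    # _parse_cas_response would yield ("bob", {"email": "bob@example.com"}).
    CAS_SUCCESS = """<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
      <cas:authenticationSuccess>
        <cas:user>bob</cas:user>
        <cas:attributes>
          <cas:email>bob@example.com</cas:email>
        </cas:attributes>
      </cas:authenticationSuccess>
    </cas:serviceResponse>"""

    import xml.etree.ElementTree as ET

    root = ET.fromstring(CAS_SUCCESS)
    # ElementTree expands namespaces: the root tag comes back as
    # "{http://www.yale.edu/tp/cas}serviceResponse", hence the endswith()
    # and split("}") dance in the handler above.
    assert root.tag.endswith("serviceResponse")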
    async def _on_successful_auth(
        self,
        username: str,
        request: SynapseRequest,
        client_redirect_url: str,
        user_display_name: Optional[str] = None,
    ) -> None:
        """Called once the user has successfully authenticated with the SSO.

        Registers the user if necessary, and then returns a redirect (with
        a login token) to the client.

        Args:
            username: the remote user id. We'll map this onto
                something sane for a MXID localpart.

            request: the incoming request from the browser. We'll
                respond to it with a redirect.

            client_redirect_url: the redirect_url the client gave us when
                it first started the process.

            user_display_name: if set, and we have to register a new user,
                we will set their displayname to this.
        """
        localpart = map_username_to_mxid_localpart(username)
        user_id = UserID(localpart, self._hostname).to_string()
        registered_user_id = await self._auth_handler.check_user_exists(user_id)
        if not registered_user_id:
            registered_user_id = await self._registration_handler.register_user(
                localpart=localpart, default_display_name=user_display_name
            )

        self._auth_handler.complete_sso_login(
            registered_user_id, request, client_redirect_url
        )
    def handle_redirect_request(self, client_redirect_url: bytes) -> bytes:
        """
        Generates a URL to the CAS server where the client should be redirected.

        Args:
            client_redirect_url: The final URL the client should go to after the
                user has negotiated SSO.

        Returns:
            The URL to redirect to.
        """
        args = urllib.parse.urlencode(
            {"service": self._build_service_param(client_redirect_url)}
        )

        return ("%s/login?%s" % (self._cas_server_url, args)).encode("ascii")

    async def handle_ticket_request(
        self, request: SynapseRequest, client_redirect_url: str, ticket: str
    ) -> None:
        """
        Validates a CAS ticket sent by the client for login/registration.

        On a successful request, writes a redirect to the request.
        """
        uri = self._cas_server_url + "/proxyValidate"
        args = {
            "ticket": ticket,
            "service": self._build_service_param(client_redirect_url),
        }
        try:
            body = await self._http_client.get_raw(uri, args)
        except PartialDownloadError as pde:
            # Twisted raises this error if the connection is closed,
            # even if that's being used old-http style to signal end-of-data
            body = pde.response

        await self._handle_cas_response(request, body, client_redirect_url)
@@ -125,8 +125,14 @@ class DeviceWorkerHandler(BaseHandler):
        users_who_share_room = yield self.store.get_users_who_share_room_with_user(
            user_id
        )

+       tracked_users = set(users_who_share_room)
+
+       # Always tell the user about their own devices
+       tracked_users.add(user_id)
+
        changed = yield self.store.get_users_whose_devices_changed(
-           from_token.device_list_key, users_who_share_room
+           from_token.device_list_key, tracked_users
        )

        # Then work out if any users have since joined
@@ -456,7 +462,11 @@ class DeviceHandler(DeviceWorkerHandler):

        room_ids = yield self.store.get_rooms_for_user(user_id)

-       yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)
+       # specify the user ID too since the user should always get their own device list
+       # updates, even if they aren't in any rooms.
+       yield self.notifier.on_new_event(
+           "device_list_key", position, users=[user_id], rooms=room_ids
+       )

        if hosts:
            logger.info(
@@ -127,7 +127,11 @@ class DirectoryHandler(BaseHandler):
                errcode=Codes.EXCLUSIVE,
            )
        else:
-           if self.require_membership and check_membership:
+           # Server admins are not subject to the same constraints as normal
+           # users when creating an alias (e.g. being in the room).
+           is_admin = yield self.auth.is_server_admin(requester.user)
+
+           if (self.require_membership and check_membership) and not is_admin:
                rooms_for_user = yield self.store.get_rooms_for_user(user_id)
                if room_id not in rooms_for_user:
                    raise AuthError(
@@ -49,6 +49,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
+from synapse.handlers._base import BaseHandler
from synapse.logging.context import (
    make_deferred_yieldable,
    nested_logging_context,

@@ -69,10 +70,9 @@ from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server

-from ._base import BaseHandler

logger = logging.getLogger(__name__)

@@ -93,27 +93,6 @@ class _NewEventInfo:
    auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)


-def shortstr(iterable, maxitems=5):
-    """If iterable has maxitems or fewer, return the stringification of a list
-    containing those items.
-
-    Otherwise, return the stringification of a list with the first maxitems items,
-    followed by "...".
-
-    Args:
-        iterable (Iterable): iterable to truncate
-        maxitems (int): number of items to return before truncating
-
-    Returns:
-        unicode
-    """
-
-    items = list(itertools.islice(iterable, maxitems + 1))
-    if len(items) <= maxitems:
-        return str(items)
-    return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"


class FederationHandler(BaseHandler):
    """Handles events that originated from federation.
    Responsible for:
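The deleted helper survives as `synapse.util.stringutils.shortstr` (note the new import above); per the deleted docstring its behaviour is simply:

    # Behaviour of shortstr, per the docstring of the block removed above:
    # short iterables stringify whole, long ones are truncated with "...".
    import itertools

    def shortstr(iterable, maxitems=5):
        items = list(itertools.islice(iterable, maxitems + 1))
        if len(items) <= maxitems:
            return str(items)
        return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"

    print(shortstr(["a", "b"]))  # ['a', 'b']
    print(shortstr(range(100)))  # [0, 1, 2, 3, 4, ...]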
@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re

from synapse.api.errors import Codes, PasswordRefusedError

logger = logging.getLogger(__name__)


class PasswordPolicyHandler(object):
    def __init__(self, hs):
        self.policy = hs.config.password_policy
        self.enabled = hs.config.password_policy_enabled

        # Regexps for the spec'd policy parameters.
        self.regexp_digit = re.compile("[0-9]")
        self.regexp_symbol = re.compile("[^a-zA-Z0-9]")
        self.regexp_uppercase = re.compile("[A-Z]")
        self.regexp_lowercase = re.compile("[a-z]")

    def validate_password(self, password):
        """Checks whether a given password complies with the server's policy.

        Args:
            password (str): The password to check against the server's policy.

        Raises:
            PasswordRefusedError: The password doesn't comply with the server's policy.
        """

        if not self.enabled:
            return

        minimum_accepted_length = self.policy.get("minimum_length", 0)
        if len(password) < minimum_accepted_length:
            raise PasswordRefusedError(
                msg=(
                    "The password must be at least %d characters long"
                    % minimum_accepted_length
                ),
                errcode=Codes.PASSWORD_TOO_SHORT,
            )

        if (
            self.policy.get("require_digit", False)
            and self.regexp_digit.search(password) is None
        ):
            raise PasswordRefusedError(
                msg="The password must include at least one digit",
                errcode=Codes.PASSWORD_NO_DIGIT,
            )

        if (
            self.policy.get("require_symbol", False)
            and self.regexp_symbol.search(password) is None
        ):
            raise PasswordRefusedError(
                msg="The password must include at least one symbol",
                errcode=Codes.PASSWORD_NO_SYMBOL,
            )

        if (
            self.policy.get("require_uppercase", False)
            and self.regexp_uppercase.search(password) is None
        ):
            raise PasswordRefusedError(
                msg="The password must include at least one uppercase letter",
                errcode=Codes.PASSWORD_NO_UPPERCASE,
            )

        if (
            self.policy.get("require_lowercase", False)
            and self.regexp_lowercase.search(password) is None
        ):
            raise PasswordRefusedError(
                msg="The password must include at least one lowercase letter",
                errcode=Codes.PASSWORD_NO_LOWERCASE,
            )
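A quick illustration of the handler's behaviour with a policy dict of the shape it reads above. The stub homeserver object is ours (in Synapse the values come from the `password_policy` section of the config), and `PasswordPolicyHandler`/`PasswordRefusedError` are assumed to be in scope from the file above:

    # Illustrative only: a stub "hs" carrying the two config attributes the
    # handler reads, then a couple of validate_password calls.
    from types import SimpleNamespace

    config = SimpleNamespace(
        password_policy={"minimum_length": 8, "require_digit": True},
        password_policy_enabled=True,
    )
    hs = SimpleNamespace(config=config)

    handler = PasswordPolicyHandler(hs)
    handler.validate_password("hunter22")      # ok: 8 chars, has a digit
    try:
        handler.validate_password("password")  # no digit -> PasswordRefusedError
    except PasswordRefusedError:
        pass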
@@ -157,6 +157,15 @@ class BaseProfileHandler(BaseHandler):
        if not by_admin and target_user != requester.user:
            raise AuthError(400, "Cannot set another user's displayname")

+       if not by_admin and not self.hs.config.enable_set_displayname:
+           profile = yield self.store.get_profileinfo(target_user.localpart)
+           if profile.display_name:
+               raise SynapseError(
+                   400,
+                   "Changing display name is disabled on this server",
+                   Codes.FORBIDDEN,
+               )
+
        if len(new_displayname) > MAX_DISPLAYNAME_LEN:
            raise SynapseError(
                400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)

@@ -218,6 +227,13 @@ class BaseProfileHandler(BaseHandler):
        if not by_admin and target_user != requester.user:
            raise AuthError(400, "Cannot set another user's avatar_url")

+       if not by_admin and not self.hs.config.enable_set_avatar_url:
+           profile = yield self.store.get_profileinfo(target_user.localpart)
+           if profile.avatar_url:
+               raise SynapseError(
+                   400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
+               )
+
        if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
            raise SynapseError(
                400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
@@ -519,6 +519,9 @@ class RoomMemberHandler(object):
            yield self.store.set_room_is_public(old_room_id, False)
            yield self.store.set_room_is_public(room_id, True)

+       # Transfer alias mappings in the room directory
+       yield self.store.update_aliases_for_room(old_room_id, room_id)
+
        # Check if any groups we own contain the predecessor room
        local_group_ids = yield self.store.get_local_groups_for_room(old_room_id)
        for group_id in local_group_ids:
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import Tuple
+from typing import Optional, Tuple

import attr
import saml2

@@ -26,6 +26,7 @@ from synapse.config import ConfigError
from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi
+from synapse.module_api.errors import RedirectException
from synapse.types import (
    UserID,
    map_username_to_mxid_localpart,

@@ -43,11 +44,15 @@ class Saml2SessionData:

    # time the session was created, in milliseconds
    creation_time = attr.ib()
+   # The user interactive authentication session ID associated with this SAML
+   # session (or None if this SAML session is for an initial login).
+   ui_auth_session_id = attr.ib(type=Optional[str], default=None)


class SamlHandler:
    def __init__(self, hs):
        self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+       self._auth = hs.get_auth()
        self._auth_handler = hs.get_auth_handler()
        self._registration_handler = hs.get_registration_handler()
@@ -76,12 +81,14 @@ class SamlHandler:

        self._error_html_content = hs.config.saml2_error_html_content

-   def handle_redirect_request(self, client_redirect_url):
+   def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None):
        """Handle an incoming request to /login/sso/redirect

        Args:
            client_redirect_url (bytes): the URL that we should redirect the
                client to when everything is done
+           ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or
+               None if this is a login).

        Returns:
            bytes: URL to redirect to

@@ -91,7 +98,9 @@ class SamlHandler:
        )

        now = self._clock.time_msec()
-       self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now)
+       self._outstanding_requests_dict[reqid] = Saml2SessionData(
+           creation_time=now, ui_auth_session_id=ui_auth_session_id,
+       )

        for key, value in info["headers"]:
            if key == "Location":
@@ -118,7 +127,12 @@ class SamlHandler:
        self.expire_sessions()

        try:
-           user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
+           user_id, current_session = await self._map_saml_response_to_user(
+               resp_bytes, relay_state
+           )
        except RedirectException:
            # Raise the exception as per the wishes of the SAML module response
            raise
        except Exception as e:
            # If decoding the response or mapping it to a user failed, then log the
            # error and tell the user that something went wrong.

@@ -133,9 +147,28 @@ class SamlHandler:
            finish_request(request)
            return

-       self._auth_handler.complete_sso_login(user_id, request, relay_state)
+       # Complete the interactive auth session or the login.
+       if current_session and current_session.ui_auth_session_id:
+           self._auth_handler.complete_sso_ui_auth(
+               user_id, current_session.ui_auth_session_id, request
+           )
+       else:
+           self._auth_handler.complete_sso_login(user_id, request, relay_state)

-   async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
+   async def _map_saml_response_to_user(
+       self, resp_bytes: str, client_redirect_url: str
+   ) -> Tuple[str, Optional[Saml2SessionData]]:
        """
        Given a SAML response, retrieve the cached session and user for it.

        Args:
            resp_bytes: The SAML response.
            client_redirect_url: The redirect URL passed in by the client.

        Returns:
            Tuple of the user ID and SAML session associated with this response.
        """
        try:
            saml2_auth = self._saml_client.parse_authn_request_response(
                resp_bytes,
@@ -163,7 +196,9 @@ class SamlHandler:

        logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)

-       self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
+       current_session = self._outstanding_requests_dict.pop(
+           saml2_auth.in_response_to, None
+       )

        remote_user_id = self._user_mapping_provider.get_remote_user_id(
            saml2_auth, client_redirect_url

@@ -184,7 +219,7 @@ class SamlHandler:
        )
        if registered_user_id is not None:
            logger.info("Found existing mapping %s", registered_user_id)
-           return registered_user_id
+           return registered_user_id, current_session

        # backwards-compatibility hack: see if there is an existing user with a
        # suitable mapping from the uid

@@ -209,7 +244,7 @@ class SamlHandler:
        await self._datastore.record_user_external_id(
            self._auth_provider_id, remote_user_id, registered_user_id
        )
-       return registered_user_id
+       return registered_user_id, current_session

        # Map saml response to user attributes using the configured mapping provider
        for i in range(1000):

@@ -256,7 +291,7 @@ class SamlHandler:
        await self._datastore.record_user_external_id(
            self._auth_provider_id, remote_user_id, registered_user_id
        )
-       return registered_user_id
+       return registered_user_id, current_session

    def expire_sessions(self):
        expire_before = self._clock.time_msec() - self._saml2_session_lifetime
@@ -32,6 +32,7 @@ class SetPasswordHandler(BaseHandler):
        super(SetPasswordHandler, self).__init__(hs)
        self._auth_handler = hs.get_auth_handler()
        self._device_handler = hs.get_device_handler()
+       self._password_policy_handler = hs.get_password_policy_handler()

    @defer.inlineCallbacks
    def set_password(

@@ -44,6 +45,7 @@ class SetPasswordHandler(BaseHandler):
        if not self.hs.config.password_localdb_enabled:
            raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)

+       self._password_policy_handler.validate_password(new_password)
        password_hash = yield self._auth_handler.hash(new_password)

        try:
@@ -1143,9 +1143,14 @@ class SyncHandler(object):
            user_id
        )

+       tracked_users = set(users_who_share_room)
+
+       # Always tell the user about their own devices
+       tracked_users.add(user_id)
+
        # Step 1a, check for changes in devices of users we share a room with
        users_that_have_changed = await self.store.get_users_whose_devices_changed(
-           since_token.device_list_key, users_who_share_room
+           since_token.device_list_key, tracked_users
        )

        # Step 1b, check for newly joined rooms
@@ -193,6 +193,12 @@ class SynapseRequest(Request):
        self.finish_time = time.time()
        Request.connectionLost(self, reason)

+       if self.logcontext is None:
+           logger.info(
+               "Connection from %s lost before request headers were read", self.client
+           )
+           return
+
        # we only get here if the connection to the client drops before we send
        # the response.
        #

@@ -236,13 +242,6 @@ class SynapseRequest(Request):
    def _finished_processing(self):
        """Log the completion of this request and update the metrics
        """

-       if self.logcontext is None:
-           # this can happen if the connection closed before we read the
-           # headers (so render was never called). In that case we'll already
-           # have logged a warning, so just bail out.
-           return
-
        usage = self.logcontext.get_resource_usage()

        if self._processing_finished_time is None:
@@ -51,7 +51,7 @@ try:

    is_thread_resource_usage_supported = True

-   def get_thread_resource_usage():
+   def get_thread_resource_usage() -> "Optional[resource._RUsage]":
        return resource.getrusage(RUSAGE_THREAD)


@@ -60,7 +60,7 @@ except Exception:
    # won't track resource usage.
    is_thread_resource_usage_supported = False

-   def get_thread_resource_usage():
+   def get_thread_resource_usage() -> "Optional[resource._RUsage]":
        return None


@@ -201,10 +201,10 @@ class _Sentinel(object):
        record["request"] = None
        record["scope"] = None

-   def start(self):
+   def start(self, rusage: "Optional[resource._RUsage]"):
        pass

-   def stop(self):
+   def stop(self, rusage: "Optional[resource._RUsage]"):
        pass

    def add_database_transaction(self, duration_sec):

@@ -261,7 +261,7 @@ class LoggingContext(object):

        # The thread resource usage when the logcontext became active. None
        # if the context is not currently active.
-       self.usage_start = None
+       self.usage_start = None  # type: Optional[resource._RUsage]

        self.main_thread = get_thread_id()
        self.request = None
@@ -336,7 +336,17 @@ class LoggingContext(object):
        record["request"] = self.request
        record["scope"] = self.scope

-   def start(self) -> None:
+   def start(self, rusage: "Optional[resource._RUsage]") -> None:
+       """
+       Record that this logcontext is currently running.
+
+       This should not be called directly: use set_current_context
+
+       Args:
+           rusage: the resources used by the current thread, at the point of
+               switching to this logcontext. May be None if this platform doesn't
+               support getrusage.
+       """
        if get_thread_id() != self.main_thread:
            logger.warning("Started logcontext %s on different thread", self)
            return
@@ -349,36 +359,48 @@ class LoggingContext(object):
        if self.usage_start:
            logger.warning("Re-starting already-active log context %s", self)
        else:
-           self.usage_start = get_thread_resource_usage()
+           self.usage_start = rusage

-   def stop(self) -> None:
-       if get_thread_id() != self.main_thread:
-           logger.warning("Stopped logcontext %s on different thread", self)
-           return
-
-       # When we stop, let's record the cpu used since we started
-       if not self.usage_start:
-           # Log a warning on platforms that support thread usage tracking
-           if is_thread_resource_usage_supported:
-               logger.warning(
-                   "Called stop on logcontext %s without calling start", self
-               )
-           return
-
-       utime_delta, stime_delta = self._get_cputime()
-       self._resource_usage.ru_utime += utime_delta
-       self._resource_usage.ru_stime += stime_delta
-
-       self.usage_start = None
-
-       # if we have a parent, pass our CPU usage stats on
-       if self.parent_context is not None and hasattr(
-           self.parent_context, "_resource_usage"
-       ):
-           self.parent_context._resource_usage += self._resource_usage
-
-       # reset them in case we get entered again
-       self._resource_usage.reset()
+   def stop(self, rusage: "Optional[resource._RUsage]") -> None:
+       """
+       Record that this logcontext is no longer running.
+
+       This should not be called directly: use set_current_context
+
+       Args:
+           rusage: the resources used by the current thread, at the point of
+               switching away from this logcontext. May be None if this platform
+               doesn't support getrusage.
+       """
+
+       try:
+           if get_thread_id() != self.main_thread:
+               logger.warning("Stopped logcontext %s on different thread", self)
+               return
+
+           if not rusage:
+               return
+
+           # Record the cpu used since we started
+           if not self.usage_start:
+               logger.warning(
+                   "Called stop on logcontext %s without recording a start rusage",
+                   self,
+               )
+               return
+
+           utime_delta, stime_delta = self._get_cputime(rusage)
+           self._resource_usage.ru_utime += utime_delta
+           self._resource_usage.ru_stime += stime_delta
+
+           # if we have a parent, pass our CPU usage stats on
+           if self.parent_context:
+               self.parent_context._resource_usage += self._resource_usage
+
+           # reset them in case we get entered again
+           self._resource_usage.reset()
+       finally:
+           self.usage_start = None

    def get_resource_usage(self) -> ContextResourceUsage:
        """Get resources used by this logcontext so far.
@@ -394,24 +416,24 @@ class LoggingContext(object):
        # can include resource usage so far.
        is_main_thread = get_thread_id() == self.main_thread
        if self.usage_start and is_main_thread:
-           utime_delta, stime_delta = self._get_cputime()
+           rusage = get_thread_resource_usage()
+           assert rusage is not None
+           utime_delta, stime_delta = self._get_cputime(rusage)
            res.ru_utime += utime_delta
            res.ru_stime += stime_delta

        return res

-   def _get_cputime(self) -> Tuple[float, float]:
-       """Get the cpu usage time so far
+   def _get_cputime(self, current: "resource._RUsage") -> Tuple[float, float]:
+       """Get the cpu usage time between start() and the given rusage
+
+       Args:
+           current: the current resource usage
+
        Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
        """
-       assert self.usage_start is not None
-
-       current = get_thread_resource_usage()
+       # Indicate to mypy that we know that self.usage_start is not None.
+       assert self.usage_start is not None

        utime_delta = current.ru_utime - self.usage_start.ru_utime
        stime_delta = current.ru_stime - self.usage_start.ru_stime
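The `resource._RUsage` annotations above refer to the struct returned by `resource.getrusage`, and the delta arithmetic in `_get_cputime` boils down to the following. A Linux-only sketch, since `RUSAGE_THREAD` is a Linux extension (which is exactly why the code above guards it in a try/except):

    # Per-thread CPU accounting as done by _get_cputime: sample rusage at the
    # start and again at the switch-away point, then subtract.
    import resource

    start = resource.getrusage(resource.RUSAGE_THREAD)
    sum(i * i for i in range(100000))  # burn a little CPU
    current = resource.getrusage(resource.RUSAGE_THREAD)

    utime_delta = current.ru_utime - start.ru_utime  # seconds in user mode
    stime_delta = current.ru_stime - start.ru_stime  # seconds in system mode
    print("user: %.4fs, system: %.4fs" % (utime_delta, stime_delta))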
@@ -539,12 +561,19 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
    Returns:
        The context that was previously active
    """
+   # everything blows up if we allow current_context to be set to None, so sanity-check
+   # that now.
+   if context is None:
+       raise TypeError("'context' argument may not be None")
+
    current = current_context()

    if current is not context:
-       current.stop()
+       rusage = get_thread_resource_usage()
+       current.stop(rusage)
        _thread_local.current_context = context
-       context.start()
+       context.start(rusage)

    return current
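The point of threading `rusage` through `stop()` and `start()` is visible here: one `getrusage` call now serves both sides of a context switch, where previously each side sampled on its own. A stripped-down model of the switch (all names here are ours, not Synapse's):

    # Stripped-down model of set_current_context: a single rusage sample is
    # shared by the outgoing context's stop() and the incoming context's
    # start(), halving the getrusage syscalls per switch.
    import resource

    _current = None

    def switch_context(new_ctx):
        global _current
        old = _current
        if old is not new_ctx:
            rusage = resource.getrusage(resource.RUSAGE_THREAD)  # sampled once
            if old is not None:
                old.stop(rusage)  # charge CPU used so far to the old context
            _current = new_ctx
            if new_ctx is not None:
                new_ctx.start(rusage)  # start the new context's clock
        return old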
@@ -21,6 +21,7 @@ from synapse.replication.http import (
    membership,
    register,
    send_event,
+   streams,
)

REPLICATION_PREFIX = "/_synapse/replication"

@@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
        login.register_servlets(hs, self)
        register.register_servlets(hs, self)
        devices.register_servlets(hs, self)
+       streams.register_servlets(hs, self)
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer
from synapse.replication.http._base import ReplicationEndpoint

logger = logging.getLogger(__name__)


class ReplicationGetStreamUpdates(ReplicationEndpoint):
    """Fetches stream updates from a server. Used for streams not persisted to
    the database, e.g. typing notifications.

    The API looks like:

        GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&upto_token=10&limit=100

        200 OK

        {
            updates: [ ... ],
            upto_token: 10,
            limited: False,
        }

    """

    NAME = "get_repl_stream_updates"
    PATH_ARGS = ("stream_name",)
    METHOD = "GET"

    def __init__(self, hs):
        super().__init__(hs)

        # We pull the streams from the replication streamer (if we try and make
        # them ourselves we end up in an import loop).
        self.streams = hs.get_replication_streamer().get_streams()

    @staticmethod
    def _serialize_payload(stream_name, from_token, upto_token, limit):
        return {"from_token": from_token, "upto_token": upto_token, "limit": limit}

    async def _handle_request(self, request, stream_name):
        stream = self.streams.get(stream_name)
        if stream is None:
            raise SynapseError(400, "Unknown stream")

        from_token = parse_integer(request, "from_token", required=True)
        upto_token = parse_integer(request, "upto_token", required=True)
        limit = parse_integer(request, "limit", required=True)

        updates, upto_token, limited = await stream.get_updates_since(
            from_token, upto_token, limit
        )

        return (
            200,
            {"updates": updates, "upto_token": upto_token, "limited": limited},
        )


def register_servlets(hs, http_server):
    ReplicationGetStreamUpdates(hs).register(http_server)
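Given the docstring's API shape, a call to this endpoint reduces to a GET with three integer query parameters. A sketch with plain `requests` for illustration only: the host/port are placeholders, and in practice workers go through `ReplicationEndpoint`'s generated client rather than raw HTTP:

    # Illustrative direct call to the endpoint defined above. "typing" is one
    # of the non-persisted streams this endpoint exists for.
    import requests

    resp = requests.get(
        "http://localhost:8008/_synapse/replication/get_repl_stream_updates/typing",
        params={"from_token": 0, "upto_token": 10, "limit": 100},
    )
    body = resp.json()
    # e.g. {"updates": [...], "upto_token": 10, "limited": False}
    for token, row in body["updates"]:
        print(token, row)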
@@ -18,8 +18,10 @@ from typing import Dict, Optional

import six

from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.data_stores.main.cache import (
+    CURRENT_STATE_CACHE_NAME,
+    CacheInvalidationWorkerStore,
+)
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine

@@ -35,7 +37,7 @@ def __func__(inp):
    return inp.__func__


-class BaseSlavedStore(SQLBaseStore):
+class BaseSlavedStore(CacheInvalidationWorkerStore):
    def __init__(self, database: Database, db_conn, hs):
        super(BaseSlavedStore, self).__init__(database, db_conn, hs)
        if isinstance(self.database_engine, PostgresEngine):

@@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
        pos["caches"] = self._cache_id_gen.get_current_token()
        return pos

+   def get_cache_stream_token(self):
+       if self._cache_id_gen:
+           return self._cache_id_gen.get_current_token()
+       else:
+           return 0
+
    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "caches":
            if self._cache_id_gen:
@@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
        result["pushers"] = self._pushers_id_gen.get_current_token()
        return result

+   def get_pushers_stream_token(self):
+       return self._pushers_id_gen.get_current_token()
+
    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "pushers":
            self._pushers_id_gen.advance(token)
@@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
        self.client_name = client_name
        self.handler = handler
        self.server_name = hs.config.server_name
+       self.hs = hs
        self._clock = hs.get_clock()  # As self.clock is defined in super class

        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)

@@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
    def buildProtocol(self, addr):
        logger.info("Connected to replication: %r", addr)
        return ClientReplicationStreamProtocol(
-           self.client_name, self.server_name, self._clock, self.handler
+           self.hs, self.client_name, self.server_name, self._clock, self.handler,
        )

    def clientConnectionLost(self, connector, reason):

@@ -188,10 +189,12 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
        """
        self.send_command(FederationAckCommand(token))

-   def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+   def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
        """Poke the master that a user has started/stopped syncing.
        """
-       self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
+       self.send_command(
+           UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+       )

    def send_remove_pusher(self, app_id, push_key, user_id):
        """Poke the master to remove a pusher for a user
@@ -136,8 +136,8 @@ class PositionCommand(Command):
    """Sent by the server to tell the client the stream position without
    needing to send an RDATA.

-   Sent to the client after all missing updates for a stream have been sent
-   to the client and they're now up to date.
+   On receipt of a POSITION command clients should check if they have missed
+   any updates, and if so then fetch them out of band.
    """

    NAME = "POSITION"
@@ -179,42 +179,24 @@ class NameCommand(Command):


class ReplicateCommand(Command):
-   """Sent by the client to subscribe to the stream.
+   """Sent by the client to subscribe to streams.

    Format::

-       REPLICATE <stream_name> <token>
-
-   Where <token> may be either:
-       * a numeric stream_id to stream updates from
-       * "NOW" to stream all subsequent updates.
-
-   The <stream_name> can be "ALL" to subscribe to all known streams, in which
-   case the <token> must be set to "NOW", i.e.::
-
-       REPLICATE ALL NOW
+       REPLICATE
    """

    NAME = "REPLICATE"

-   def __init__(self, stream_name, token):
-       self.stream_name = stream_name
-       self.token = token
+   def __init__(self):
+       pass

    @classmethod
    def from_line(cls, line):
-       stream_name, token = line.split(" ", 1)
-       if token in ("NOW", "now"):
-           token = "NOW"
-       else:
-           token = int(token)
-       return cls(stream_name, token)
+       return cls()

    def to_line(self):
-       return " ".join((self.stream_name, str(self.token)))
+       return ""

-   def get_logcontext_id(self):
-       return "REPLICATE-" + self.stream_name


class UserSyncCommand(Command):
@@ -225,30 +207,32 @@ class UserSyncCommand(Command):

    Format::

-       USER_SYNC <user_id> <state> <last_sync_ms>
+       USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>

    Where <state> is either "start" or "end"
    """

    NAME = "USER_SYNC"

-   def __init__(self, user_id, is_syncing, last_sync_ms):
+   def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+       self.instance_id = instance_id
        self.user_id = user_id
        self.is_syncing = is_syncing
        self.last_sync_ms = last_sync_ms

    @classmethod
    def from_line(cls, line):
-       user_id, state, last_sync_ms = line.split(" ", 2)
+       instance_id, user_id, state, last_sync_ms = line.split(" ", 3)

        if state not in ("start", "end"):
            raise Exception("Invalid USER_SYNC state %r" % (state,))

-       return cls(user_id, state == "start", int(last_sync_ms))
+       return cls(instance_id, user_id, state == "start", int(last_sync_ms))

    def to_line(self):
        return " ".join(
            (
+               self.instance_id,
                self.user_id,
                "start" if self.is_syncing else "end",
                str(self.last_sync_ms),

@@ -256,6 +240,30 @@ class UserSyncCommand(Command):
            )
        )


class ClearUserSyncsCommand(Command):
    """Sent by the client to inform the server that it should drop all
    information about syncing users sent by the client.

    Mainly used when client is about to shut down.

    Format::

        CLEAR_USER_SYNC <instance_id>
    """

    NAME = "CLEAR_USER_SYNC"

    def __init__(self, instance_id):
        self.instance_id = instance_id

    @classmethod
    def from_line(cls, line):
        return cls(line)

    def to_line(self):
        return self.instance_id


class FederationAckCommand(Command):
    """Sent by the client when it has processed up to a given point in the
    federation stream. This allows the master to drop in-memory caches of the
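The wire encoding of these commands is just space-separated fields after the command name, so a round-trip through the classes above looks like this (the instance ID and timestamp are made up):

    # Round-trip of the USER_SYNC wire format defined above.
    cmd = UserSyncCommand("worker-1", "@alice:example.com", True, 1585000000000)
    line = cmd.to_line()
    assert line == "worker-1 @alice:example.com start 1585000000000"

    parsed = UserSyncCommand.from_line(line)
    assert parsed.instance_id == "worker-1"
    assert parsed.is_syncing is True

    # And the shutdown counterpart carries only the instance ID:
    assert ClearUserSyncsCommand("worker-1").to_line() == "worker-1"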
@@ -416,6 +424,7 @@ _COMMANDS = (
    InvalidateCacheCommand,
    UserIpCommand,
    RemoteServerUpCommand,
+   ClearUserSyncsCommand,
)  # type: Tuple[Type[Command], ...]

# Map of command name to command type.

@@ -438,6 +447,7 @@ VALID_CLIENT_COMMANDS = (
    ReplicateCommand.NAME,
    PingCommand.NAME,
    UserSyncCommand.NAME,
+   ClearUserSyncsCommand.NAME,
    FederationAckCommand.NAME,
    RemovePusherCommand.NAME,
    InvalidateCacheCommand.NAME,
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
    > PING 1490197665618
    < NAME synapse.app.appservice
    < PING 1490197665618
-   < REPLICATE events 1
-   < REPLICATE backfill 1
-   < REPLICATE caches 1
+   < REPLICATE
    > POSITION events 1
    > POSITION backfill 1
    > POSITION caches 1

@@ -53,17 +51,15 @@ import fcntl
import logging
import struct
from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
+from typing import Any, DefaultDict, Dict, List, Set

-from six import iteritems, iterkeys
+from six import iteritems

from prometheus_client import Counter

-from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure

-from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (

@@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import (
    SyncCommand,
    UserSyncCommand,
)
-from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string

+MYPY = False
+if MYPY:
+    from synapse.server import HomeServer
+

connection_close_counter = Counter(
    "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
        self.server_name = server_name
        self.streamer = streamer

-       # The streams the client has subscribed to and is up to date with
-       self.replication_streams = set()  # type: Set[str]
-
-       # The streams the client is currently subscribing to.
-       self.connecting_streams = set()  # type: Set[str]
-
-       # Map from stream name to list of updates to send once we've finished
-       # subscribing the client to the stream.
-       self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]]
-
    def connectionMade(self):
        self.send_command(ServerCommand(self.server_name))
        BaseReplicationStreamProtocol.connectionMade(self)

@@ -432,25 +423,17 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):

    async def on_USER_SYNC(self, cmd):
        await self.streamer.on_user_sync(
-           self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+           cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
        )

+   async def on_CLEAR_USER_SYNC(self, cmd):
+       await self.streamer.on_clear_user_syncs(cmd.instance_id)
+
    async def on_REPLICATE(self, cmd):
-       stream_name = cmd.stream_name
-       token = cmd.token
-
-       if stream_name == "ALL":
-           # Subscribe to all streams we're publishing to.
-           deferreds = [
-               run_in_background(self.subscribe_to_stream, stream, token)
-               for stream in iterkeys(self.streamer.streams_by_name)
-           ]
-
-           await make_deferred_yieldable(
-               defer.gatherResults(deferreds, consumeErrors=True)
-           )
-       else:
-           await self.subscribe_to_stream(stream_name, token)
+       # Subscribe to all streams we're publishing to.
+       for stream_name in self.streamer.streams_by_name:
+           current_token = self.streamer.get_stream_token(stream_name)
+           self.send_command(PositionCommand(stream_name, current_token))

    async def on_FEDERATION_ACK(self, cmd):
        self.streamer.federation_ack(cmd.token)
@@ -474,87 +457,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
            cmd.last_seen,
        )

-   async def subscribe_to_stream(self, stream_name, token):
-       """Subscribe the remote to a stream.
-
-       This involves checking if they've missed anything and sending those
-       updates down if they have. During that time new updates for the stream
-       are queued and sent once we've sent down any missed updates.
-       """
-       self.replication_streams.discard(stream_name)
-       self.connecting_streams.add(stream_name)
-
-       try:
-           # Get missing updates
-           updates, current_token = await self.streamer.get_stream_updates(
-               stream_name, token
-           )
-
-           # Send all the missing updates
-           for update in updates:
-               token, row = update[0], update[1]
-               self.send_command(RdataCommand(stream_name, token, row))
-
-           # We send a POSITION command to ensure that they have an up to
-           # date token (especially useful if we didn't send any updates
-           # above)
-           self.send_command(PositionCommand(stream_name, current_token))
-
-           # Now we can send any updates that came in while we were subscribing
-           pending_rdata = self.pending_rdata.pop(stream_name, [])
-           updates = []
-           for token, update in pending_rdata:
-               # If the token is null, it is part of a batch update. Batches
-               # are multiple updates that share a single token. To denote
-               # this, the token is set to None for all tokens in the batch
-               # except for the last. If we find a None token, we keep looking
-               # through tokens until we find one that is not None and then
-               # process all previous updates in the batch as if they had the
-               # final token.
-               if token is None:
-                   # Store this update as part of a batch
-                   updates.append(update)
-                   continue
-
-               if token <= current_token:
-                   # This update or batch of updates is older than
-                   # current_token, dismiss it
-                   updates = []
-                   continue
-
-               updates.append(update)
-
-               # Send all updates that are part of this batch with the
-               # found token
-               for update in updates:
-                   self.send_command(RdataCommand(stream_name, token, update))
-
-               # Clear stored updates
-               updates = []
-
-           # They're now fully subscribed
-           self.replication_streams.add(stream_name)
-       except Exception as e:
-           logger.exception("[%s] Failed to handle REPLICATE command", self.id())
-           self.send_error("failed to handle replicate: %r", e)
-       finally:
-           self.connecting_streams.discard(stream_name)
-
    def stream_update(self, stream_name, token, data):
        """Called when a new update is available to stream to clients.
-
-       We need to check if the client is interested in the stream or not
        """
-       if stream_name in self.replication_streams:
-           # The client is subscribed to the stream
-           self.send_command(RdataCommand(stream_name, token, data))
-       elif stream_name in self.connecting_streams:
-           # The client is being subscribed to the stream
-           logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
-           self.pending_rdata.setdefault(stream_name, []).append((token, data))
-       else:
-           # The client isn't subscribed
-           logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
+       self.send_command(RdataCommand(stream_name, token, data))

    def send_sync(self, data):
        self.send_command(SyncCommand(data))
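The None-token batching convention described in the deleted comment still governs the client's `on_RDATA` handling below. Distilled into a standalone helper (names ours):

    # Distilled version of the None-token batching rule from the comments
    # above: rows arriving with token=None belong to the next tokened row's
    # batch, which closes the batch and carries its final token.
    def collect_batches(rdata):
        """rdata: iterable of (token, row); yields (token, [rows]) per batch."""
        pending = []
        for token, row in rdata:
            pending.append(row)
            if token is not None:
                yield token, pending
                pending = []

    batches = list(collect_batches([(None, "a"), (None, "b"), (5, "c"), (6, "d")]))
    assert batches == [(5, ["a", "b", "c"]), (6, ["d"])]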
@@ -638,6 +546,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):

    def __init__(
        self,
+       hs: "HomeServer",
        client_name: str,
        server_name: str,
        clock: Clock,

@@ -645,41 +554,42 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
    ):
        BaseReplicationStreamProtocol.__init__(self, clock)

+       self.instance_id = hs.get_instance_id()
+
        self.client_name = client_name
        self.server_name = server_name
        self.handler = handler

+       self.streams = {
+           stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+       }  # type: Dict[str, Stream]
+
        # Set of stream names that have been subscribed to, but haven't yet
        # caught up with. This is used to track when the client has been fully
        # connected to the remote.
-       self.streams_connecting = set()  # type: Set[str]
+       self.streams_connecting = set(STREAMS_MAP)  # type: Set[str]

        # Map of stream to batched updates. See RdataCommand for info on how
        # batching works.
-       self.pending_batches = {}  # type: Dict[str, Any]
+       self.pending_batches = {}  # type: Dict[str, List[Any]]

    def connectionMade(self):
        self.send_command(NameCommand(self.client_name))
        BaseReplicationStreamProtocol.connectionMade(self)

        # Once we've connected subscribe to the necessary streams
-       for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
-           self.replicate(stream_name, token)
+       self.replicate()

        # Tell the server if we have any users currently syncing (should only
        # happen on synchrotrons)
        currently_syncing = self.handler.get_currently_syncing_users()
        now = self.clock.time_msec()
        for user_id in currently_syncing:
-           self.send_command(UserSyncCommand(user_id, True, now))
+           self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))

        # We've now finished connecting, so inform the client handler
        self.handler.update_connection(self)

-       # This will happen if we don't actually subscribe to any streams
-       if not self.streams_connecting:
-           self.handler.finished_connecting()
-
    async def on_SERVER(self, cmd):
        if cmd.data != self.server_name:
            logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@ -697,7 +607,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
)
|
||||
raise
|
||||
|
||||
if cmd.token is None:
|
||||
if cmd.token is None or stream_name in self.streams_connecting:
|
||||
# I.e. this is part of a batch of updates for this stream. Batch
|
||||
# until we get an update for the stream with a non None token
|
||||
self.pending_batches.setdefault(stream_name, []).append(row)
|
||||
|
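The None-token batching convention described in the comments above is easy to get wrong, so here is a minimal, self-contained sketch of both sides of it. All names here (`batch_rows`, `collect_batches`) are illustrative, not Synapse APIs:

```python
from typing import Any, List, Optional, Tuple


def batch_rows(token: int, rows: List[Any]) -> List[Tuple[Optional[int], Any]]:
    """Sender side: every row in a batch carries a None token except the
    last, which carries the token for the whole batch."""
    return [(None, row) for row in rows[:-1]] + [(token, rows[-1])]


def collect_batches(commands: List[Tuple[Optional[int], Any]]) -> List[Tuple[int, List[Any]]]:
    """Receiver side: buffer rows until a non-None token arrives, then emit
    the buffered rows under that token."""
    batches = []
    pending = []  # type: List[Any]
    for token, row in commands:
        pending.append(row)
        if token is not None:
            batches.append((token, pending))
            pending = []
    return batches


assert collect_batches(batch_rows(7, ["a", "b", "c"])) == [(7, ["a", "b", "c"])]
```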
@@ -707,14 +617,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
            rows.append(row)
        await self.handler.on_rdata(stream_name, cmd.token, rows)

    async def on_POSITION(self, cmd):
        # When we get a `POSITION` command it means we've finished getting
        # missing updates for the given stream, and are now up to date.
    async def on_POSITION(self, cmd: PositionCommand):
        stream = self.streams.get(cmd.stream_name)
        if not stream:
            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
            return

        # Find where we previously streamed up to.
        current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
        if current_token is None:
            logger.warning(
                "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
            )
            return

        # Fetch all updates between then and now.
        limited = True
        while limited:
            updates, current_token, limited = await stream.get_updates_since(
                current_token, cmd.token
            )

            # Check if the connection was closed underneath us, if so we bail
            # rather than risk having concurrent catch ups going on.
            if self.state == ConnectionStates.CLOSED:
                return

            if updates:
                await self.handler.on_rdata(
                    cmd.stream_name,
                    current_token,
                    [stream.parse_row(update[1]) for update in updates],
                )

        # We've now caught up to position sent to us, notify handler.
        await self.handler.on_position(cmd.stream_name, cmd.token)

        self.streams_connecting.discard(cmd.stream_name)
        if not self.streams_connecting:
            self.handler.finished_connecting()

        await self.handler.on_position(cmd.stream_name, cmd.token)
        # Check if the connection was closed underneath us, if so we bail
        # rather than risk having concurrent catch ups going on.
        if self.state == ConnectionStates.CLOSED:
            return

        # Handle any RDATA that came in while we were catching up.
        rows = self.pending_batches.pop(cmd.stream_name, [])
        if rows:
            await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)

    async def on_SYNC(self, cmd):
        self.handler.on_sync(cmd.data)
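The `while limited` loop in `on_POSITION` above is the generic "page until caught up" pattern implied by the new three-element return value. A toy, self-contained illustration of the same loop (`FakeStream` and its page size are hypothetical stand-ins):

```python
import asyncio


class FakeStream:
    """Toy stand-in for a Stream: rows are returned two at a time."""

    def __init__(self, rows, page_size=2):
        self.rows = rows  # list of (token, row) pairs
        self.page_size = page_size

    async def get_updates_since(self, from_token, upto_token):
        page = [r for r in self.rows if from_token < r[0] <= upto_token][: self.page_size]
        # limited=True signals that more rows remain before upto_token.
        limited = len(page) == self.page_size and page[-1][0] < upto_token
        new_token = page[-1][0] if page else upto_token
        return page, new_token, limited


async def catch_up(stream, from_token, upto_token):
    updates = []
    limited = True
    while limited:
        page, from_token, limited = await stream.get_updates_since(from_token, upto_token)
        updates.extend(page)
    return updates


rows = [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")]
assert asyncio.run(catch_up(FakeStream(rows), 0, 5)) == rows
```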
@@ -722,22 +673,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
        self.handler.on_remote_server_up(cmd.data)

    def replicate(self, stream_name, token):
    def replicate(self):
        """Send the subscription request to the server
        """
        if stream_name not in STREAMS_MAP:
            raise Exception("Invalid stream name %r" % (stream_name,))
        logger.info("[%s] Subscribing to replication streams", self.id())

        logger.info(
            "[%s] Subscribing to replication stream: %r from %r",
            self.id(),
            stream_name,
            token,
        )

        self.streams_connecting.add(stream_name)

        self.send_command(ReplicateCommand(stream_name, token))
        self.send_command(ReplicateCommand())

    def on_connection_closed(self):
        BaseReplicationStreamProtocol.on_connection_closed(self)
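The new `streams_connecting = set(STREAMS_MAP)` initialisation above, together with the `discard`/`finished_connecting` calls in `on_POSITION`, amounts to a small "drain a set, then fire once" idiom. A standalone sketch of that idiom (names are illustrative, not Synapse classes):

```python
class ConnectTracker:
    """Fire a callback exactly once, after every named stream has caught up."""

    def __init__(self, stream_names, on_done):
        self.pending = set(stream_names)
        self.on_done = on_done
        self.done = False
        self._maybe_finish()  # handles the "no streams to wait for" case

    def stream_caught_up(self, name):
        self.pending.discard(name)
        self._maybe_finish()

    def _maybe_finish(self):
        if not self.pending and not self.done:
            self.done = True
            self.on_done()


tracker = ConnectTracker({"events", "typing"}, lambda: print("fully connected"))
tracker.stream_caught_up("events")
tracker.stream_caught_up("typing")  # prints "fully connected"
```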
@@ -17,7 +17,7 @@
import logging
import random
from typing import Any, List
from typing import Any, Dict, List

from six import itervalues

@@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func

from .protocol import ServerReplicationStreamProtocol
from .streams import STREAMS_MAP
from .streams import STREAMS_MAP, Stream
from .streams.federation import FederationStream

stream_updates_counter = Counter(

@@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory):
    """

    def __init__(self, hs):
        self.streamer = ReplicationStreamer(hs)
        self.streamer = hs.get_replication_streamer()
        self.clock = hs.get_clock()
        self.server_name = hs.config.server_name

@@ -99,22 +99,6 @@ class ReplicationStreamer(object):
        self.streams_by_name = {stream.NAME: stream for stream in self.streams}

        LaterGauge(
            "synapse_replication_tcp_resource_connections_per_stream",
            "",
            ["stream_name"],
            lambda: {
                (stream_name,): len(
                    [
                        conn
                        for conn in self.connections
                        if stream_name in conn.replication_streams
                    ]
                )
                for stream_name in self.streams_by_name
            },
        )

        self.federation_sender = None
        if not hs.config.send_federation:
            self.federation_sender = hs.get_federation_sender()

@@ -133,6 +117,11 @@ class ReplicationStreamer(object):
        for conn in self.connections:
            conn.send_error("server shutting down")

    def get_streams(self) -> Dict[str, Stream]:
        """Get a map from stream name to stream instance.
        """
        return self.streams_by_name

    def on_notifier_poke(self):
        """Checks if there is actually any new data and sends it to the
        connections if there are.

@@ -190,7 +179,8 @@ class ReplicationStreamer(object):
                            stream.current_token(),
                        )
                        try:
                            updates, current_token = await stream.get_updates()
                            updates, current_token, limited = await stream.get_updates()
                            self.pending_updates |= limited
                        except Exception:
                            logger.info("Failed to handle stream %s", stream.NAME)
                            raise

@@ -226,8 +216,7 @@ class ReplicationStreamer(object):
            self.pending_updates = False
            self.is_looping = False

    @measure_func("repl.get_stream_updates")
    async def get_stream_updates(self, stream_name, token):
    def get_stream_token(self, stream_name):
        """Get the current token for a given stream. This is called when a
        client first subscribes to a stream.
        """

@@ -235,7 +224,7 @@ class ReplicationStreamer(object):
        if not stream:
            raise Exception("unknown stream %s", stream_name)

        return await stream.get_updates_since(token)
        return stream.current_token()

    @measure_func("repl.federation_ack")
    def federation_ack(self, token):

@@ -246,14 +235,19 @@ class ReplicationStreamer(object):
        self.federation_sender.federation_ack(token)

    @measure_func("repl.on_user_sync")
    async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
    async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
        """A client has started/stopped syncing on a worker.
        """
        user_sync_counter.inc()
        await self.presence_handler.update_external_syncs_row(
            conn_id, user_id, is_syncing, last_sync_ms
            instance_id, user_id, is_syncing, last_sync_ms
        )

    async def on_clear_user_syncs(self, instance_id):
        """A replication client wants us to drop all their UserSync data.
        """
        await self.presence_handler.update_external_syncs_clear(instance_id)

    @measure_func("repl.on_remove_pusher")
    async def on_remove_pusher(self, app_id, push_key, user_id):
        """A client has asked us to remove a pusher

@@ -316,14 +310,6 @@ class ReplicationStreamer(object):
        except ValueError:
            pass

        # We need to tell the presence handler that the connection has been
        # lost so that it can handle any ongoing syncs on that connection.
        run_as_background_process(
            "update_external_syncs_clear",
            self.presence_handler.update_external_syncs_clear,
            connection.conn_id,
        )


def _batch_updates(updates):
    """Takes a list of updates of form [(token, row)] and sets the token to
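`_batch_updates` is cut off here by the diff context, so the following is only a guess at a function satisfying that docstring: it would blank the token on every entry that shares a token with its successor, producing the None-token batches described earlier. Treat it as a sketch, not Synapse's actual body:

```python
def batch_updates(updates):
    """Sketch: given [(token, row)] where consecutive entries may share a
    token, replace the token with None on every entry except the last one
    of each run, matching the RDATA batching convention described above."""
    if not updates:
        return []

    batched = []
    for (token, row), (next_token, _) in zip(updates, updates[1:]):
        batched.append((None if token == next_token else token, row))
    batched.append(updates[-1])  # the final entry always keeps its token
    return batched


assert batch_updates([(1, "a"), (2, "b"), (2, "c")]) == [(1, "a"), (None, "b"), (2, "c")]
```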
@@ -24,6 +24,9 @@ Each stream is defined by the following information:
    current_token: The function that returns the current token for the stream
    update_function: The function that returns a list of updates between two tokens
"""

from typing import Dict, Type

from synapse.replication.tcp.streams._base import (
    AccountDataStream,
    BackfillStream,

@@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import (
    PushersStream,
    PushRulesStream,
    ReceiptsStream,
    Stream,
    TagAccountDataStream,
    ToDeviceStream,
    TypingStream,

@@ -63,10 +67,12 @@ STREAMS_MAP = {
        GroupServerStream,
        UserSignatureStream,
    )
}
}  # type: Dict[str, Type[Stream]]


__all__ = [
    "STREAMS_MAP",
    "Stream",
    "BackfillStream",
    "PresenceStream",
    "TypingStream",
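For reference, a registry typed as `Dict[str, Type[Stream]]` like `STREAMS_MAP` is consumed by instantiating each class keyed by its `NAME`, as the protocol `__init__` earlier in this diff does. A toy version with stand-in classes:

```python
from typing import Dict, Type


class Stream:
    NAME = "base"

    def __init__(self, hs):
        self.hs = hs


class EventsStream(Stream):
    NAME = "events"


class TypingStream(Stream):
    NAME = "typing"


# Registry keyed by wire name, exactly like STREAMS_MAP.
STREAMS = {s.NAME: s for s in (EventsStream, TypingStream)}  # type: Dict[str, Type[Stream]]

# One instance per registered stream class:
instances = {name: cls(hs=None) for name, cls in STREAMS.items()}
assert set(instances) == {"events", "typing"}
```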
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import logging
from collections import namedtuple
from typing import Any, List, Optional, Tuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple

import attr

from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict

logger = logging.getLogger(__name__)

@@ -29,6 +29,15 @@ logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000


# Some type aliases to make things a bit easier.

# A stream position token
Token = int

# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
StreamRow = Tuple[Token, tuple]


class Stream(object):
    """Base class for the streams.

@@ -56,6 +65,7 @@ class Stream(object):
        return cls.ROW_TYPE(*row)

    def __init__(self, hs):

        # The token from which we last asked for updates
        self.last_token = self.current_token()

@@ -65,61 +75,46 @@ class Stream(object):
        """
        self.last_token = self.current_token()

    async def get_updates(self):
    async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
        """Gets all updates since the last time this function was called (or
        since the stream was constructed if it hadn't been called before).

        Returns:
            Deferred[Tuple[List[Tuple[int, Any]], int]]:
                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
                list of ``(token, row)`` entries. ``row`` will be json-serialised and
                sent over the replication stream.
            A triplet `(updates, new_last_token, limited)`, where `updates` is
            a list of `(token, row)` entries, `new_last_token` is the new
            position in stream, and `limited` is whether there are more updates
            to fetch.
        """
        updates, current_token = await self.get_updates_since(self.last_token)
        current_token = self.current_token()
        updates, current_token, limited = await self.get_updates_since(
            self.last_token, current_token
        )
        self.last_token = current_token

        return updates, current_token
        return updates, current_token, limited

    async def get_updates_since(
        self, from_token: int
    ) -> Tuple[List[Tuple[int, JsonDict]], int]:
        self, from_token: Token, upto_token: Token, limit: int = 100
    ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
        """Like get_updates except allows specifying from when we should
        stream updates

        Returns:
            Resolves to a pair `(updates, new_last_token)`, where `updates` is
            a list of `(token, row)` entries and `new_last_token` is the new
            position in stream.
            A triplet `(updates, new_last_token, limited)`, where `updates` is
            a list of `(token, row)` entries, `new_last_token` is the new
            position in stream, and `limited` is whether there are more updates
            to fetch.
        """

        if from_token in ("NOW", "now"):
            return [], self.current_token()

        current_token = self.current_token()

        from_token = int(from_token)

        if from_token == current_token:
            return [], current_token
        if from_token == upto_token:
            return [], upto_token, False

        rows = await self.update_function(
            from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
        updates, upto_token, limited = await self.update_function(
            from_token, upto_token, limit=limit,
        )

        # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
        rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)

        updates = [(row[0], row[1:]) for row in rows]

        # check we didn't get more rows than the limit.
        # doing it like this allows the update_function to be a generator.
        if len(updates) >= MAX_EVENTS_BEHIND:
            raise Exception("stream %s has fallen behind" % (self.NAME))

        # The update function didn't hit the limit, so we must have got all
        # the updates to `current_token`, and can return that as our new
        # stream position.
        return updates, current_token
        return updates, upto_token, limited

    def current_token(self):
        """Gets the current token of the underlying streams. Should be provided

@@ -141,6 +136,48 @@ class Stream(object):
        raise NotImplementedError()


def db_query_to_update_function(
    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
    """Wraps a db query function which returns a list of rows to make it
    suitable for use as an `update_function` for the Stream class
    """

    async def update_function(from_token, upto_token, limit):
        rows = await query_function(from_token, upto_token, limit)
        updates = [(row[0], row[1:]) for row in rows]
        limited = False
        if len(updates) == limit:
            upto_token = rows[-1][0]
            limited = True

        return updates, upto_token, limited

    return update_function


def make_http_update_function(
    hs, stream_name: str
) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
    """Makes a suitable function for use as an `update_function` that queries
    the master process for updates.
    """

    client = ReplicationGetStreamUpdates.make_client(hs)

    async def update_function(
        from_token: int, upto_token: int, limit: int
    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        return await client(
            stream_name=stream_name,
            from_token=from_token,
            upto_token=upto_token,
            limit=limit,
        )

    return update_function
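A quick demonstration of the `limited` contract that `db_query_to_update_function` (defined in the hunk above) imposes: when the query returns exactly `limit` rows, the token is clamped to the last returned row and `limited` is set, telling the caller to page again. The fake query is hypothetical; the import assumes a synapse checkout on the path (otherwise paste the function from the hunk above):

```python
import asyncio

from synapse.replication.tcp.streams._base import db_query_to_update_function


async def fake_query(from_token, upto_token, limit):
    # Pretend the database holds one row per token.
    rows = [(t, "row-%d" % t) for t in range(from_token + 1, upto_token + 1)]
    return rows[:limit]


update_function = db_query_to_update_function(fake_query)

# 10 rows available but limit 5: we get 5 updates, an upto_token clamped to
# the last returned row, and limited=True telling the caller to page again.
updates, new_token, limited = asyncio.run(update_function(0, 10, 5))
assert len(updates) == 5 and new_token == 5 and limited is True
```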
class BackfillStream(Stream):
    """We fetched some old events and either we had never seen that event before
    or it went from being an outlier to not.

@@ -164,7 +201,7 @@ class BackfillStream(Stream):
    def __init__(self, hs):
        store = hs.get_datastore()
        self.current_token = store.get_current_backfill_token  # type: ignore
        self.update_function = store.get_all_new_backfill_event_rows  # type: ignore
        self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows)  # type: ignore

        super(BackfillStream, self).__init__(hs)

@@ -190,8 +227,15 @@ class PresenceStream(Stream):
        store = hs.get_datastore()
        presence_handler = hs.get_presence_handler()

        self._is_worker = hs.config.worker_app is not None

        self.current_token = store.get_current_presence_token  # type: ignore
        self.update_function = presence_handler.get_all_presence_updates  # type: ignore

        if hs.config.worker_app is None:
            self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates)  # type: ignore
        else:
            # Query master process
            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore

        super(PresenceStream, self).__init__(hs)

@@ -208,7 +252,12 @@ class TypingStream(Stream):
        typing_handler = hs.get_typing_handler()

        self.current_token = typing_handler.get_current_token  # type: ignore
        self.update_function = typing_handler.get_all_typing_updates  # type: ignore

        if hs.config.worker_app is None:
            self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates)  # type: ignore
        else:
            # Query master process
            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore

        super(TypingStream, self).__init__(hs)

@@ -232,7 +281,7 @@ class ReceiptsStream(Stream):
        store = hs.get_datastore()

        self.current_token = store.get_max_receipt_stream_id  # type: ignore
        self.update_function = store.get_all_updated_receipts  # type: ignore
        self.update_function = db_query_to_update_function(store.get_all_updated_receipts)  # type: ignore

        super(ReceiptsStream, self).__init__(hs)

@@ -256,7 +305,13 @@ class PushRulesStream(Stream):
    async def update_function(self, from_token, to_token, limit):
        rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
        return [(row[0], row[2]) for row in rows]

        limited = False
        if len(rows) == limit:
            to_token = rows[-1][0]
            limited = True

        return [(row[0], (row[2],)) for row in rows], to_token, limited


class PushersStream(Stream):

@@ -275,7 +330,7 @@ class PushersStream(Stream):
        store = hs.get_datastore()

        self.current_token = store.get_pushers_stream_token  # type: ignore
        self.update_function = store.get_all_updated_pushers_rows  # type: ignore
        self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows)  # type: ignore

        super(PushersStream, self).__init__(hs)

@@ -307,7 +362,7 @@ class CachesStream(Stream):
        store = hs.get_datastore()

        self.current_token = store.get_cache_stream_token  # type: ignore
        self.update_function = store.get_all_updated_caches  # type: ignore
        self.update_function = db_query_to_update_function(store.get_all_updated_caches)  # type: ignore

        super(CachesStream, self).__init__(hs)

@@ -333,7 +388,7 @@ class PublicRoomsStream(Stream):
        store = hs.get_datastore()

        self.current_token = store.get_current_public_room_stream_id  # type: ignore
        self.update_function = store.get_all_new_public_rooms  # type: ignore
        self.update_function = db_query_to_update_function(store.get_all_new_public_rooms)  # type: ignore

        super(PublicRoomsStream, self).__init__(hs)

@@ -354,7 +409,7 @@ class DeviceListsStream(Stream):
        store = hs.get_datastore()

        self.current_token = store.get_device_stream_token  # type: ignore
        self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore
        self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes)  # type: ignore

        super(DeviceListsStream, self).__init__(hs)

@@ -372,7 +427,7 @@ class ToDeviceStream(Stream):
        store = hs.get_datastore()

        self.current_token = store.get_to_device_stream_token  # type: ignore
        self.update_function = store.get_all_new_device_messages  # type: ignore
        self.update_function = db_query_to_update_function(store.get_all_new_device_messages)  # type: ignore

        super(ToDeviceStream, self).__init__(hs)

@@ -392,7 +447,7 @@ class TagAccountDataStream(Stream):
        store = hs.get_datastore()

        self.current_token = store.get_max_account_data_stream_id  # type: ignore
        self.update_function = store.get_all_updated_tags  # type: ignore
        self.update_function = db_query_to_update_function(store.get_all_updated_tags)  # type: ignore

        super(TagAccountDataStream, self).__init__(hs)

@@ -412,10 +467,11 @@ class AccountDataStream(Stream):
        self.store = hs.get_datastore()

        self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore

        super(AccountDataStream, self).__init__(hs)

    async def update_function(self, from_token, to_token, limit):
    async def _update_function(self, from_token, to_token, limit):
        global_results, room_results = await self.store.get_all_updated_account_data(
            from_token, from_token, to_token, limit
        )

@@ -442,7 +498,7 @@ class GroupServerStream(Stream):
        store = hs.get_datastore()

        self.current_token = store.get_group_stream_token  # type: ignore
        self.update_function = store.get_all_groups_changes  # type: ignore
        self.update_function = db_query_to_update_function(store.get_all_groups_changes)  # type: ignore

        super(GroupServerStream, self).__init__(hs)

@@ -460,6 +516,6 @@ class UserSignatureStream(Stream):
        store = hs.get_datastore()

        self.current_token = store.get_device_stream_token  # type: ignore
        self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore
        self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes)  # type: ignore

        super(UserSignatureStream, self).__init__(hs)

@@ -19,7 +19,7 @@ from typing import Tuple, Type
import attr

from ._base import Stream
from ._base import Stream, db_query_to_update_function


"""Handling of the 'events' replication stream

@@ -117,10 +117,11 @@ class EventsStream(Stream):
    def __init__(self, hs):
        self._store = hs.get_datastore()
        self.current_token = self._store.get_current_events_token  # type: ignore
        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore

        super(EventsStream, self).__init__(hs)

    async def update_function(self, from_token, current_token, limit=None):
    async def _update_function(self, from_token, current_token, limit=None):
        event_rows = await self._store.get_all_new_forward_event_rows(
            from_token, current_token, limit
        )

@@ -15,7 +15,9 @@
# limitations under the License.
from collections import namedtuple

from ._base import Stream
from twisted.internet import defer

from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function


class FederationStream(Stream):

@@ -33,11 +35,18 @@ class FederationStream(Stream):
    NAME = "federation"
    ROW_TYPE = FederationStreamRow
    _QUERY_MASTER = True

    def __init__(self, hs):
        federation_sender = hs.get_federation_sender()

        self.current_token = federation_sender.get_current_token  # type: ignore
        self.update_function = federation_sender.get_replication_rows  # type: ignore
        # Not all synapse instances will have a federation sender instance,
        # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
        # so we stub the stream out when that is the case.
        if hs.config.worker_app is None or hs.should_send_federation():
            federation_sender = hs.get_federation_sender()
            self.current_token = federation_sender.get_current_token  # type: ignore
            self.update_function = db_query_to_update_function(federation_sender.get_replication_rows)  # type: ignore
        else:
            self.current_token = lambda: 0  # type: ignore
            self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, False))  # type: ignore

        super(FederationStream, self).__init__(hs)

@@ -0,0 +1,14 @@
<html>
<head>
    <title>Authentication</title>
</head>
<body>
    <div>
        <p>
            A client is trying to {{ description | e }}. To confirm this action,
            <a href="{{ redirect_url | e }}">re-authenticate with single sign-on</a>.
            If you did not expect this, your account may be compromised!
        </p>
    </div>
</body>
</html>

@@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import (
    keys,
    notifications,
    openid,
    password_policy,
    read_marker,
    receipts,
    register,

@@ -118,6 +119,7 @@ class ClientRestResource(JsonResource):
        capabilities.register_servlets(hs, client_resource)
        account_validity.register_servlets(hs, client_resource)
        relations.register_servlets(hs, client_resource)
        password_policy.register_servlets(hs, client_resource)

        # moving to /_synapse/admin
        synapse.rest.admin.register_servlets_for_client_rest_resource(

@@ -29,7 +29,11 @@ from synapse.rest.admin._base import (
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet
from synapse.rest.admin.rooms import (
    JoinRoomAliasServlet,
    ListRoomRestServlet,
    ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import (
    AccountValidityRenewServlet,

@@ -189,6 +193,7 @@ def register_servlets(hs, http_server):
    """
    register_servlets_for_client_rest_resource(hs, http_server)
    ListRoomRestServlet(hs).register(http_server)
    JoinRoomAliasServlet(hs).register(http_server)
    PurgeRoomServlet(hs).register(http_server)
    SendServerNoticeServlet(hs).register(http_server)
    VersionServlet(hs).register(http_server)

@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Optional

from synapse.api.constants import Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
    RestServlet,
    assert_params_in_dict,

@@ -29,7 +30,7 @@ from synapse.rest.admin._base import (
    historical_admin_path_patterns,
)
from synapse.storage.data_stores.main.room import RoomSortOrder
from synapse.types import create_requester
from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.util.async_helpers import maybe_awaitable

logger = logging.getLogger(__name__)

@@ -237,3 +238,75 @@ class ListRoomRestServlet(RestServlet):
            response["prev_batch"] = 0

        return 200, response


class JoinRoomAliasServlet(RestServlet):

    PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")

    def __init__(self, hs):
        self.hs = hs
        self.auth = hs.get_auth()
        self.room_member_handler = hs.get_room_member_handler()
        self.admin_handler = hs.get_handlers().admin_handler
        self.state_handler = hs.get_state_handler()

    async def on_POST(self, request, room_identifier):
        requester = await self.auth.get_user_by_req(request)
        await assert_user_is_admin(self.auth, requester.user)

        content = parse_json_object_from_request(request)

        assert_params_in_dict(content, ["user_id"])
        target_user = UserID.from_string(content["user_id"])

        if not self.hs.is_mine(target_user):
            raise SynapseError(400, "This endpoint can only be used with local users")

        if not await self.admin_handler.get_user(target_user):
            raise NotFoundError("User not found")

        if RoomID.is_valid(room_identifier):
            room_id = room_identifier
            try:
                remote_room_hosts = [
                    x.decode("ascii") for x in request.args[b"server_name"]
                ]  # type: Optional[List[str]]
            except Exception:
                remote_room_hosts = None
        elif RoomAlias.is_valid(room_identifier):
            handler = self.room_member_handler
            room_alias = RoomAlias.from_string(room_identifier)
            room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
            room_id = room_id.to_string()
        else:
            raise SynapseError(
                400, "%s was not legal room ID or room alias" % (room_identifier,)
            )

        fake_requester = create_requester(target_user)

        # send invite if room has "JoinRules.INVITE"
        room_state = await self.state_handler.get_current_state(room_id)
        join_rules_event = room_state.get((EventTypes.JoinRules, ""))
        if join_rules_event:
            if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
                await self.room_member_handler.update_membership(
                    requester=requester,
                    target=fake_requester.user,
                    room_id=room_id,
                    action="invite",
                    remote_room_hosts=remote_room_hosts,
                    ratelimit=False,
                )

        await self.room_member_handler.update_membership(
            requester=fake_requester,
            target=fake_requester.user,
            room_id=room_id,
            action="join",
            remote_room_hosts=remote_room_hosts,
            ratelimit=False,
        )

        return 200, {"room_id": room_id}
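Assuming the usual `/_synapse/admin/v1` prefix that `admin_patterns` expands to, the new `JoinRoomAliasServlet` could be exercised roughly as below. Hostname, access token, room ID, and user are all placeholders; note that the servlet only reads `server_name` for room IDs, not for aliases:

```python
import requests

# All values here are hypothetical placeholders.
resp = requests.post(
    # "!abc123:example.com" URL-encoded ("!" becomes "%21")
    "https://synapse.example.com/_synapse/admin/v1/join/%21abc123:example.com",
    params={"server_name": "example.com"},  # candidate remote hosts for the join
    headers={"Authorization": "Bearer <admin_access_token>"},
    json={"user_id": "@alice:example.com"},
)
print(resp.json())  # expected shape: {"room_id": "!abc123:example.com"}
```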
@@ -14,11 +14,6 @@
# limitations under the License.

import logging
import xml.etree.ElementTree as ET

from six.moves import urllib

from twisted.web.client import PartialDownloadError

from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter

@@ -28,9 +23,10 @@ from synapse.http.servlet import (
    parse_json_object_from_request,
    parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn

logger = logging.getLogger(__name__)

@@ -72,14 +68,6 @@ def login_id_thirdparty_from_phone(identifier):
    return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}


def build_service_param(cas_service_url, client_redirect_url):
    return "%s%s?redirectUrl=%s" % (
        cas_service_url,
        "/_matrix/client/r0/login/cas/ticket",
        urllib.parse.quote(client_redirect_url, safe=""),
    )


class LoginRestServlet(RestServlet):
    PATTERNS = client_patterns("/login$", v1=True)
    CAS_TYPE = "m.login.cas"

@@ -409,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet):
    PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)

    def on_GET(self, request):
    def on_GET(self, request: SynapseRequest):
        args = request.args
        if b"redirectUrl" not in args:
            return 400, "Redirect URL not specified for SSO auth"

@@ -418,15 +406,15 @@ class BaseSSORedirectServlet(RestServlet):
        request.redirect(sso_url)
        finish_request(request)

    def get_sso_url(self, client_redirect_url):
    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
        """Get the URL to redirect to, to perform SSO auth

        Args:
            client_redirect_url (bytes): the URL that we should redirect the
            client_redirect_url: the URL that we should redirect the
                client to when everything is done

        Returns:
            bytes: URL to redirect to
            URL to redirect to
        """
        # to be implemented by subclasses
        raise NotImplementedError()

@@ -434,16 +422,10 @@ class BaseSSORedirectServlet(RestServlet):
class CasRedirectServlet(BaseSSORedirectServlet):
    def __init__(self, hs):
        super(CasRedirectServlet, self).__init__()
        self.cas_server_url = hs.config.cas_server_url
        self.cas_service_url = hs.config.cas_service_url
        self._cas_handler = hs.get_cas_handler()

    def get_sso_url(self, client_redirect_url):
        args = urllib.parse.urlencode(
            {"service": build_service_param(self.cas_service_url, client_redirect_url)}
        )

        return "%s/login?%s" % (self.cas_server_url, args)
    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
        return self._cas_handler.handle_redirect_request(client_redirect_url)


class CasTicketServlet(RestServlet):

@@ -451,81 +433,15 @@ class CasTicketServlet(RestServlet):
    def __init__(self, hs):
        super(CasTicketServlet, self).__init__()
        self.cas_server_url = hs.config.cas_server_url
        self.cas_service_url = hs.config.cas_service_url
        self.cas_displayname_attribute = hs.config.cas_displayname_attribute
        self.cas_required_attributes = hs.config.cas_required_attributes
        self._sso_auth_handler = SSOAuthHandler(hs)
        self._http_client = hs.get_proxied_http_client()
        self._cas_handler = hs.get_cas_handler()

    async def on_GET(self, request):
    async def on_GET(self, request: SynapseRequest) -> None:
        client_redirect_url = parse_string(request, "redirectUrl", required=True)
        uri = self.cas_server_url + "/proxyValidate"
        args = {
            "ticket": parse_string(request, "ticket", required=True),
            "service": build_service_param(self.cas_service_url, client_redirect_url),
        }
        try:
            body = await self._http_client.get_raw(uri, args)
        except PartialDownloadError as pde:
            # Twisted raises this error if the connection is closed,
            # even if that's being used old-http style to signal end-of-data
            body = pde.response
        result = await self.handle_cas_response(request, body, client_redirect_url)
        return result

    def handle_cas_response(self, request, cas_response_body, client_redirect_url):
        user, attributes = self.parse_cas_response(cas_response_body)
        displayname = attributes.pop(self.cas_displayname_attribute, None)

        for required_attribute, required_value in self.cas_required_attributes.items():
            # If required attribute was not in CAS Response - Forbidden
            if required_attribute not in attributes:
                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

            # Also need to check value
            if required_value is not None:
                actual_value = attributes[required_attribute]
                # If required attribute value does not match expected - Forbidden
                if required_value != actual_value:
                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

        return self._sso_auth_handler.on_successful_auth(
            user, request, client_redirect_url, displayname
        ticket = parse_string(request, "ticket", required=True)
        await self._cas_handler.handle_ticket_request(
            request, client_redirect_url, ticket
        )

    def parse_cas_response(self, cas_response_body):
        user = None
        attributes = {}
        try:
            root = ET.fromstring(cas_response_body)
            if not root.tag.endswith("serviceResponse"):
                raise Exception("root of CAS response is not serviceResponse")
            success = root[0].tag.endswith("authenticationSuccess")
            for child in root[0]:
                if child.tag.endswith("user"):
                    user = child.text
                if child.tag.endswith("attributes"):
                    for attribute in child:
                        # ElementTree library expands the namespace in
                        # attribute tags to the full URL of the namespace.
                        # We don't care about namespace here and it will always
                        # be encased in curly braces, so we remove them.
                        tag = attribute.tag
                        if "}" in tag:
                            tag = tag.split("}")[1]
                        attributes[tag] = attribute.text
            if user is None:
                raise Exception("CAS response does not contain user")
        except Exception:
            logger.exception("Error parsing CAS response")
            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
        if not success:
            raise LoginError(
                401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
            )
        return user, attributes


class SAMLRedirectServlet(BaseSSORedirectServlet):
    PATTERNS = client_patterns("/login/sso/redirect", v1=True)

@@ -533,65 +449,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
    def __init__(self, hs):
        self._saml_handler = hs.get_saml_handler()

    def get_sso_url(self, client_redirect_url):
    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
        return self._saml_handler.handle_redirect_request(client_redirect_url)


class SSOAuthHandler(object):
    """
    Utility class for Resources and Servlets which handle the response from a SSO
    service

    Args:
        hs (synapse.server.HomeServer)
    """

    def __init__(self, hs):
        self._hostname = hs.hostname
        self._auth_handler = hs.get_auth_handler()
        self._registration_handler = hs.get_registration_handler()
        self._macaroon_gen = hs.get_macaroon_generator()

        # cast to tuple for use with str.startswith
        self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)

    async def on_successful_auth(
        self, username, request, client_redirect_url, user_display_name=None
    ):
        """Called once the user has successfully authenticated with the SSO.

        Registers the user if necessary, and then returns a redirect (with
        a login token) to the client.

        Args:
            username (unicode|bytes): the remote user id. We'll map this onto
                something sane for a MXID localpart.

            request (SynapseRequest): the incoming request from the browser. We'll
                respond to it with a redirect.

            client_redirect_url (unicode): the redirect_url the client gave us when
                it first started the process.

            user_display_name (unicode|None): if set, and we have to register a new user,
                we will set their displayname to this.

        Returns:
            Deferred[none]: Completes once we have handled the request.
        """
        localpart = map_username_to_mxid_localpart(username)
        user_id = UserID(localpart, self._hostname).to_string()
        registered_user_id = await self._auth_handler.check_user_exists(user_id)
        if not registered_user_id:
            registered_user_id = await self._registration_handler.register_user(
                localpart=localpart, default_display_name=user_display_name
            )

        self._auth_handler.complete_sso_login(
            registered_user_id, request, client_redirect_url
        )


def register_servlets(hs, http_server):
    LoginRestServlet(hs).register(http_server)
    if hs.config.cas_enabled:

@@ -234,13 +234,21 @@ class PasswordRestServlet(RestServlet):
        if self.auth.has_access_token(request):
            requester = await self.auth.get_user_by_req(request)
            params = await self.auth_handler.validate_user_via_ui_auth(
                requester, body, self.hs.get_ip_from_request(request)
                requester,
                request,
                body,
                self.hs.get_ip_from_request(request),
                "modify your account password",
            )
            user_id = requester.user.to_string()
        else:
            requester = None
            result, params, _ = await self.auth_handler.check_auth(
                [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
                [[LoginType.EMAIL_IDENTITY]],
                request,
                body,
                self.hs.get_ip_from_request(request),
                "modify your account password",
            )

            if LoginType.EMAIL_IDENTITY in result:

@@ -308,7 +316,11 @@ class DeactivateAccountRestServlet(RestServlet):
            return 200, {}

        await self.auth_handler.validate_user_via_ui_auth(
            requester, body, self.hs.get_ip_from_request(request)
            requester,
            request,
            body,
            self.hs.get_ip_from_request(request),
            "deactivate your account",
        )
        result = await self._deactivate_account_handler.deactivate_account(
            requester.user.to_string(), erase, id_server=body.get("id_server")

@@ -602,6 +614,11 @@ class ThreepidRestServlet(RestServlet):
        return 200, {"threepids": threepids}

    async def on_POST(self, request):
        if not self.hs.config.enable_3pid_changes:
            raise SynapseError(
                400, "3PID changes are disabled on this server", Codes.FORBIDDEN
            )

        requester = await self.auth.get_user_by_req(request)
        user_id = requester.user.to_string()
        body = parse_json_object_from_request(request)

@@ -646,6 +663,11 @@ class ThreepidAddRestServlet(RestServlet):

    @interactive_auth_handler
    async def on_POST(self, request):
        if not self.hs.config.enable_3pid_changes:
            raise SynapseError(
                400, "3PID changes are disabled on this server", Codes.FORBIDDEN
            )

        requester = await self.auth.get_user_by_req(request)
        user_id = requester.user.to_string()
        body = parse_json_object_from_request(request)

@@ -656,7 +678,11 @@ class ThreepidAddRestServlet(RestServlet):
        assert_valid_client_secret(client_secret)

        await self.auth_handler.validate_user_via_ui_auth(
            requester, body, self.hs.get_ip_from_request(request)
            requester,
            request,
            body,
            self.hs.get_ip_from_request(request),
            "add a third-party identifier to your account",
        )

        validation_session = await self.identity_handler.validate_threepid_session(

@@ -741,10 +767,16 @@ class ThreepidDeleteRestServlet(RestServlet):

    def __init__(self, hs):
        super(ThreepidDeleteRestServlet, self).__init__()
        self.hs = hs
        self.auth = hs.get_auth()
        self.auth_handler = hs.get_auth_handler()

    async def on_POST(self, request):
        if not self.hs.config.enable_3pid_changes:
            raise SynapseError(
                400, "3PID changes are disabled on this server", Codes.FORBIDDEN
            )

        body = parse_json_object_from_request(request)
        assert_params_in_dict(body, ["medium", "address"])

@@ -18,6 +18,7 @@ import logging
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.handlers.auth import SUCCESS_TEMPLATE
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_string

@@ -89,30 +90,6 @@ TERMS_TEMPLATE = """
</html>
"""

SUCCESS_TEMPLATE = """
<html>
<head>
<title>Success!</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script>
if (window.onAuthDone) {
    window.onAuthDone();
} else if (window.opener && window.opener.postMessage) {
    window.opener.postMessage("authDone", "*");
}
</script>
</head>
<body>
<div>
<p>Thank you</p>
<p>You may now close this window and return to the application</p>
</div>
</body>
</html>
"""


class AuthRestServlet(RestServlet):
    """

@@ -130,6 +107,11 @@ class AuthRestServlet(RestServlet):
        self.auth_handler = hs.get_auth_handler()
        self.registration_handler = hs.get_registration_handler()

        # SSO configuration.
        self._saml_enabled = hs.config.saml2_enabled
        if self._saml_enabled:
            self._saml_handler = hs.get_saml_handler()

    def on_GET(self, request, stagetype):
        session = parse_string(request, "session")
        if not session:

@@ -150,6 +132,15 @@ class AuthRestServlet(RestServlet):
                "myurl": "%s/r0/auth/%s/fallback/web"
                % (CLIENT_API_PREFIX, LoginType.TERMS),
            }

        elif stagetype == LoginType.SSO and self._saml_enabled:
            # Display a confirmation page which prompts the user to
            # re-authenticate with their SSO provider.
            client_redirect_url = ""
            sso_redirect_url = self._saml_handler.handle_redirect_request(
                client_redirect_url, session
            )
            html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
        else:
            raise SynapseError(404, "Unknown auth stage type")

@@ -210,6 +201,9 @@ class AuthRestServlet(RestServlet):
                "myurl": "%s/r0/auth/%s/fallback/web"
                % (CLIENT_API_PREFIX, LoginType.TERMS),
            }
        elif stagetype == LoginType.SSO:
            # The SSO fallback workflow should not post here.
            raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
        else:
            raise SynapseError(404, "Unknown auth stage type")

@@ -81,7 +81,11 @@ class DeleteDevicesRestServlet(RestServlet):
        assert_params_in_dict(body, ["devices"])

        await self.auth_handler.validate_user_via_ui_auth(
            requester, body, self.hs.get_ip_from_request(request)
            requester,
            request,
            body,
            self.hs.get_ip_from_request(request),
            "remove device(s) from your account",
        )

        await self.device_handler.delete_devices(

@@ -127,7 +131,11 @@ class DeviceRestServlet(RestServlet):
            raise

        await self.auth_handler.validate_user_via_ui_auth(
            requester, body, self.hs.get_ip_from_request(request)
            requester,
            request,
            body,
            self.hs.get_ip_from_request(request),
            "remove a device from your account",
        )

        await self.device_handler.delete_device(requester.user.to_string(), device_id)

@@ -263,7 +263,11 @@ class SigningKeyUploadServlet(RestServlet):
        body = parse_json_object_from_request(request)

        await self.auth_handler.validate_user_via_ui_auth(
            requester, body, self.hs.get_ip_from_request(request)
            requester,
            request,
            body,
            self.hs.get_ip_from_request(request),
            "add a device signing key to your account",
        )

        result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)

@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from synapse.http.servlet import RestServlet

from ._base import client_patterns

logger = logging.getLogger(__name__)


class PasswordPolicyServlet(RestServlet):
    PATTERNS = client_patterns("/password_policy$")

    def __init__(self, hs):
        """
        Args:
            hs (synapse.server.HomeServer): server
        """
        super(PasswordPolicyServlet, self).__init__()

        self.policy = hs.config.password_policy
        self.enabled = hs.config.password_policy_enabled

    def on_GET(self, request):
        if not self.enabled or not self.policy:
            return (200, {})

        policy = {}

        for param in [
            "minimum_length",
            "require_digit",
            "require_symbol",
            "require_lowercase",
            "require_uppercase",
        ]:
            if param in self.policy:
                policy["m.%s" % param] = self.policy[param]

        return (200, policy)


def register_servlets(hs, http_server):
    PasswordPolicyServlet(hs).register(http_server)
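To make the `m.`-prefixed mapping in `PasswordPolicyServlet.on_GET` concrete, here is the transformation applied to a hypothetical `password_policy` config (only the whitelisted params in the loop above are copied through):

```python
# Hypothetical homeserver.yaml policy values:
policy = {"minimum_length": 12, "require_digit": True, "require_symbol": True}

response = {"m.%s" % param: policy[param] for param in policy}
assert response == {
    "m.minimum_length": 12,
    "m.require_digit": True,
    "m.require_symbol": True,
}
```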
@ -373,6 +373,7 @@ class RegisterRestServlet(RestServlet):
|
|||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self.ratelimiter = hs.get_registration_ratelimiter()
|
||||
self.password_policy_handler = hs.get_password_policy_handler()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self._registration_flows = _calculate_registration_flows(
|
||||
|
@ -420,6 +421,7 @@ class RegisterRestServlet(RestServlet):
|
|||
or len(body["password"]) > 512
|
||||
):
|
||||
raise SynapseError(400, "Invalid password")
|
||||
self.password_policy_handler.validate_password(body["password"])
|
||||
|
||||
desired_username = None
|
||||
if "username" in body:
|
||||
|
@ -499,7 +501,11 @@ class RegisterRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
auth_result, params, session_id = await self.auth_handler.check_auth(
|
||||
self._registration_flows, body, self.hs.get_ip_from_request(request)
|
||||
self._registration_flows,
|
||||
request,
|
||||
body,
|
||||
self.hs.get_ip_from_request(request),
|
||||
"register a new account",
|
||||
)
|
||||
|
||||
# Check that we're not trying to register a denied 3pid.
|
||||
|
|
|
@ -188,7 +188,7 @@ class RoomKeysServlet(RestServlet):
|
|||
"""
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
version = parse_string(request, "version")
|
||||
version = parse_string(request, "version", required=True)
|
||||
|
||||
room_keys = await self.e2e_room_keys_handler.get_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
|
|
|
@ -56,6 +56,7 @@ from synapse.handlers.account_validity import AccountValidityHandler
|
|||
from synapse.handlers.acme import AcmeHandler
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
||||
from synapse.handlers.cas_handler import CasHandler
|
||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
|
||||
from synapse.handlers.devicemessage import DeviceMessageHandler
|
||||
|
@ -66,6 +67,7 @@ from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerH
|
|||
from synapse.handlers.initial_sync import InitialSyncHandler
|
||||
from synapse.handlers.message import EventCreationHandler, MessageHandler
|
||||
from synapse.handlers.pagination import PaginationHandler
|
||||
from synapse.handlers.password_policy import PasswordPolicyHandler
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
|
||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||
|
@ -85,6 +87,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
|||
from synapse.notifier import Notifier
|
||||
from synapse.push.action_generator import ActionGenerator
|
||||
from synapse.push.pusherpool import PusherPool
|
||||
from synapse.replication.tcp.resource import ReplicationStreamer
|
||||
from synapse.rest.media.v1.media_repository import (
|
||||
MediaRepository,
|
||||
MediaRepositoryResource,
|
||||
|
@ -100,6 +103,7 @@ from synapse.storage import DataStores, Storage
|
|||
from synapse.streams.events import EventSources
|
||||
from synapse.util import Clock
|
||||
from synapse.util.distributor import Distributor
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -196,9 +200,12 @@ class HomeServer(object):
|
|||
"sendmail",
|
||||
"registration_handler",
|
||||
"account_validity_handler",
|
||||
"cas_handler",
|
||||
"saml_handler",
|
||||
"event_client_serializer",
|
||||
"password_policy_handler",
|
||||
"storage",
|
||||
"replication_streamer",
|
||||
]
|
||||
|
||||
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
|
||||
|
@ -224,6 +231,8 @@ class HomeServer(object):
|
|||
self._listening_services = []
|
||||
self.start_time = None
|
||||
|
||||
self.instance_id = random_string(5)
|
||||
|
||||
self.clock = Clock(reactor)
|
||||
self.distributor = Distributor()
|
||||
self.ratelimiter = Ratelimiter()
|
||||
|
@@ -236,6 +245,14 @@ class HomeServer(object):
         for depname in kwargs:
             setattr(self, depname, kwargs[depname])

+    def get_instance_id(self):
+        """A unique ID for this synapse process instance.
+
+        This is used to distinguish running instances in worker-based
+        deployments.
+        """
+        return self.instance_id
+
     def setup(self):
         logger.info("Setting up.")
         self.start_time = int(self.get_clock().time())
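
The new `instance_id` tags each running process with a short random string so that worker deployments can tell instances apart. A sketch of such a helper, assuming a lowercase-ASCII alphabet (the real implementation is `synapse.util.stringutils.random_string`, whose exact alphabet may differ):

    # Sketch of a random_string-style helper for the instance ID; the real
    # helper lives in synapse.util.stringutils, and the lowercase alphabet
    # here is an assumption for illustration.
    import random
    import string


    def random_string_sketch(length: int) -> str:
        return "".join(random.choice(string.ascii_lowercase) for _ in range(length))


    instance_id = random_string_sketch(5)  # e.g. "kqzfw"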
@@ -525,6 +542,9 @@ class HomeServer(object):
     def build_account_validity_handler(self):
         return AccountValidityHandler(self)

+    def build_cas_handler(self):
+        return CasHandler(self)
+
     def build_saml_handler(self):
         from synapse.handlers.saml_handler import SamlHandler

@@ -533,9 +553,15 @@ class HomeServer(object):
     def build_event_client_serializer(self):
         return EventClientSerializer(self)

+    def build_password_policy_handler(self):
+        return PasswordPolicyHandler(self)
+
     def build_storage(self) -> Storage:
         return Storage(self, self.datastores)

+    def build_replication_streamer(self) -> ReplicationStreamer:
+        return ReplicationStreamer(self)
+
     def remove_pusher(self, app_id, push_key, user_id):
         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

@@ -557,24 +583,22 @@ def _make_dependency_method(depname):
         try:
             builder = getattr(hs, "build_%s" % (depname))
         except AttributeError:
-            builder = None
+            raise NotImplementedError(
+                "%s has no %s nor a builder for it" % (type(hs).__name__, depname)
+            )

-        if builder:
-            # Prevent cyclic dependencies from deadlocking
-            if depname in hs._building:
-                raise ValueError("Cyclic dependency while building %s" % (depname,))
-            hs._building[depname] = 1
+        # Prevent cyclic dependencies from deadlocking
+        if depname in hs._building:
+            raise ValueError("Cyclic dependency while building %s" % (depname,))
+
+        hs._building[depname] = 1
+        try:
             dep = builder()
             setattr(hs, depname, dep)
+        finally:
             del hs._building[depname]

-            return dep
-
-        raise NotImplementedError(
-            "%s has no %s nor a builder for it" % (type(hs).__name__, depname)
-        )
+        return dep

     setattr(HomeServer, "get_%s" % (depname), _get)
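
The refactor above flattens `_make_dependency_method`: a missing builder now raises immediately, and the build itself runs inside `try`/`finally` so the cycle-detection marker in `hs._building` is removed even when a builder throws. A self-contained sketch of the same lazy-build-with-cycle-guard pattern, under made-up names (this is not Synapse's actual class):

    # Illustrative container with lazily-built, cached dependencies and
    # cycle detection, mirroring the pattern above; Container and
    # build_clock are made-up names, not Synapse's API.
    class Container:
        def __init__(self):
            self._building = {}

        def build_clock(self):
            return object()  # stand-in for constructing a real dependency

        def get(self, depname):
            # Serve from the instance dict if already built.
            if depname in self.__dict__:
                return self.__dict__[depname]

            try:
                builder = getattr(self, "build_%s" % (depname,))
            except AttributeError:
                raise NotImplementedError("no builder for %s" % (depname,))

            # A builder that re-enters get() for the same name would
            # recurse forever; fail fast with a clear error instead.
            if depname in self._building:
                raise ValueError("Cyclic dependency while building %s" % (depname,))

            self._building[depname] = 1
            try:
                dep = builder()
                self.__dict__[depname] = dep
            finally:
                # Runs even if the builder raised, so a failed build can
                # be retried later instead of misreported as a cycle.
                del self._building[depname]

            return dep


    c = Container()
    assert c.get("clock") is c.get("clock")  # second call hits the cache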
@@ -114,3 +114,5 @@ class HomeServer(object):
         pass
     def is_mine_id(self, domain_id: str) -> bool:
         pass
+    def get_instance_id(self) -> str:
+        pass
@@ -9,7 +9,7 @@
 <body onload="matrixLogin.onLoad()">
     <center>
         <br/>
-        <h1>Log in with one of the following methods</h1>
+        <h1 id="title"></h1>

         <span id="feedback" style="color: #f00"></span>

@@ -1,37 +1,41 @@
 window.matrixLogin = {
     endpoint: location.origin + "/_matrix/client/r0/login",
     serverAcceptsPassword: false,
     serverAcceptsCas: false,
     serverAcceptsSso: false,
 };

+var title_pre_auth = "Log in with one of the following methods";
+var title_post_auth = "Logging in...";
+
 var submitPassword = function(user, pwd) {
     console.log("Logging in with password...");
+    set_title(title_post_auth);
     var data = {
         type: "m.login.password",
         user: user,
         password: pwd,
     };
     $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
-        show_login();
         matrixLogin.onLogin(response);
     }).error(errorFunc);
 };

 var submitToken = function(loginToken) {
     console.log("Logging in with login token...");
+    set_title(title_post_auth);
     var data = {
         type: "m.login.token",
         token: loginToken
     };
     $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
-        show_login();
         matrixLogin.onLogin(response);
     }).error(errorFunc);
 };

 var errorFunc = function(err) {
-    show_login();
+    // We want to show the error to the user rather than redirecting immediately to the
+    // SSO portal (if SSO is the only login option), so we inhibit the redirect.
+    show_login(true);

     if (err.responseJSON && err.responseJSON.error) {
         setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")");
@@ -45,26 +49,33 @@ var setFeedbackString = function(text) {
     $("#feedback").text(text);
 };

-var show_login = function() {
-    $("#loading").hide();
-
+var show_login = function(inhibit_redirect) {
     var this_page = window.location.origin + window.location.pathname;
     $("#sso_redirect_url").val(this_page);

+    // If inhibit_redirect is false, and SSO is the only supported login method, we can
+    // redirect straight to the SSO page
+    if (matrixLogin.serverAcceptsSso) {
+        if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) {
+            $("#sso_form").submit();
+            return;
+        }
+
+        // Otherwise, show the SSO form
+        $("#sso_form").show();
+    }
+
     if (matrixLogin.serverAcceptsPassword) {
         $("#password_flow").show();
     }

     if (matrixLogin.serverAcceptsSso) {
         $("#sso_flow").show();
     } else if (matrixLogin.serverAcceptsCas) {
         $("#sso_form").attr("action", "/_matrix/client/r0/login/cas/redirect");
         $("#sso_flow").show();
     }

-    if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas && !matrixLogin.serverAcceptsSso) {
+    if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsSso) {
         $("#no_login_types").show();
     }

+    set_title(title_pre_auth);
+
+    $("#loading").hide();
 };

 var show_spinner = function() {
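
The redirect rule encoded above: jump straight to the SSO portal only when SSO is supported, no other login flow is offered, and the caller did not ask to stay on the page (as `errorFunc` does, so the user can read the error first). A small truth-table sketch of that decision, written in Python for consistency with the other examples here; the flag names simply mirror the JS:

    # Sketch of the redirect decision in the new show_login(), transcribed
    # into Python for illustration; flag names mirror the JS above.
    def should_auto_redirect(server_accepts_sso: bool,
                             server_accepts_password: bool,
                             inhibit_redirect: bool) -> bool:
        # Auto-redirect only when SSO is supported, it is the sole login
        # method, and the caller has not asked to stay on this page.
        return (server_accepts_sso
                and not inhibit_redirect
                and not server_accepts_password)


    assert should_auto_redirect(True, False, False) is True   # SSO-only: jump to portal
    assert should_auto_redirect(True, False, True) is False   # error shown: stay here
    assert should_auto_redirect(True, True, False) is False   # password also offered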
@@ -74,17 +85,15 @@ var show_spinner = function() {
     $("#loading").show();
 };

+var set_title = function(title) {
+    $("#title").text(title);
+};
+
 var fetch_info = function(cb) {
     $.get(matrixLogin.endpoint, function(response) {
-        var serverAcceptsPassword = false;
-        var serverAcceptsCas = false;
         for (var i=0; i<response.flows.length; i++) {
             var flow = response.flows[i];
             if ("m.login.cas" === flow.type) {
                 matrixLogin.serverAcceptsCas = true;
                 console.log("Server accepts CAS");
             }
             if ("m.login.sso" === flow.type) {
                 matrixLogin.serverAcceptsSso = true;
                 console.log("Server accepts SSO");
@@ -102,7 +111,7 @@ var fetch_info = function(cb) {
 matrixLogin.onLoad = function() {
     fetch_info(function() {
         if (!try_token()) {
-            show_login();
+            show_login(false);
         }
     });
 };
@@ -90,8 +90,10 @@ class BackgroundUpdater(object):
         self._clock = hs.get_clock()
         self.db = database

+        # if a background update is currently running, its name.
+        self._current_background_update = None  # type: Optional[str]
+
         self._background_update_performance = {}
-        self._background_update_queue = []
         self._background_update_handlers = {}
         self._all_done = False
@@ -111,7 +113,7 @@ class BackgroundUpdater(object):
         except Exception:
             logger.exception("Error doing update")
         else:
-            if result is None:
+            if result:
                 logger.info(
                     "No more background updates to do."
                     " Unscheduling background update task."
@@ -119,26 +121,25 @@ class BackgroundUpdater(object):
                 self._all_done = True
                 return None

-    @defer.inlineCallbacks
-    def has_completed_background_updates(self):
+    async def has_completed_background_updates(self) -> bool:
         """Check if all the background updates have completed

         Returns:
-            Deferred[bool]: True if all background updates have completed
+            True if all background updates have completed
         """
         # if we've previously determined that there is nothing left to do, that
         # is easy
         if self._all_done:
             return True

-        # obviously, if we have things in our queue, we're not done.
-        if self._background_update_queue:
+        # obviously, if we are currently processing an update, we're not done.
+        if self._current_background_update:
             return False

         # otherwise, check if there are updates to be run. This is important,
         # as we may be running on a worker which doesn't perform the bg updates
         # itself, but still wants to wait for them to happen.
-        updates = yield self.db.simple_select_onecol(
+        updates = await self.db.simple_select_onecol(
             "background_updates",
             keyvalues=None,
             retcol="1",
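
The conversion above from `@defer.inlineCallbacks`/`yield` to `async`/`await` is mechanical, since Twisted `Deferred`s can be awaited directly from native coroutines. A toy sketch of the pattern, with made-up functions, driven via `defer.ensureDeferred` (which is how such coroutines are typically run under Twisted):

    # Toy illustration of the inlineCallbacks -> async/await conversion;
    # Deferred results can be awaited directly in native coroutines.
    from twisted.internet import defer


    @defer.inlineCallbacks
    def old_style():
        result = yield defer.succeed(41)  # yield a Deferred, get its result
        return result + 1


    async def new_style():
        result = await defer.succeed(41)  # await the same Deferred
        return result + 1


    d_old = old_style()                       # Deferred firing with 42
    d_new = defer.ensureDeferred(new_style()) # Deferred firing with 42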
@@ -153,11 +154,10 @@ class BackgroundUpdater(object):
     async def has_completed_background_update(self, update_name) -> bool:
         """Check if the given background update has finished running.
         """

         if self._all_done:
             return True

-        if update_name in self._background_update_queue:
+        if update_name == self._current_background_update:
             return False

         update_exists = await self.db.simple_select_one_onecol(
@@ -170,9 +170,7 @@ class BackgroundUpdater(object):

         return not update_exists

-    async def do_next_background_update(
-        self, desired_duration_ms: float
-    ) -> Optional[int]:
+    async def do_next_background_update(self, desired_duration_ms: float) -> bool:
         """Does some amount of work on the next queued background update

         Returns once some amount of work is done.
@@ -181,33 +179,51 @@ class BackgroundUpdater(object):
             desired_duration_ms(float): How long we want to spend
                 updating.
         Returns:
-            None if there is no more work to do, otherwise an int
+            True if we have finished running all the background updates, otherwise False
         """
-        if not self._background_update_queue:
-            updates = await self.db.simple_select_list(
-                "background_updates",
-                keyvalues=None,
-                retcols=("update_name", "depends_on"),
+
+        def get_background_updates_txn(txn):
+            txn.execute(
+                """
+                SELECT update_name, depends_on FROM background_updates
+                ORDER BY ordering, update_name
+                """
             )
-            in_flight = {update["update_name"] for update in updates}
-            for update in updates:
-                if update["depends_on"] not in in_flight:
-                    self._background_update_queue.append(update["update_name"])
+            return self.db.cursor_to_dict(txn)

-        if not self._background_update_queue:
-            # no work left to do
-            return None
+        if not self._current_background_update:
+            all_pending_updates = await self.db.runInteraction(
+                "background_updates", get_background_updates_txn,
+            )
+            if not all_pending_updates:
+                # no work left to do
+                return True

-        # pop from the front, and add back to the back
-        update_name = self._background_update_queue.pop(0)
-        self._background_update_queue.append(update_name)
+            # find the first update which isn't dependent on another one in the queue.
+            pending = {update["update_name"] for update in all_pending_updates}
+            for upd in all_pending_updates:
+                depends_on = upd["depends_on"]
+                if not depends_on or depends_on not in pending:
+                    break
+                logger.info(
+                    "Not starting on bg update %s until %s is done",
+                    upd["update_name"],
+                    depends_on,
+                )
+            else:
+                # if we get to the end of that for loop, there is a problem
+                raise Exception(
+                    "Unable to find a background update which doesn't depend on "
+                    "another: dependency cycle?"
+                )

-        res = await self._do_background_update(update_name, desired_duration_ms)
-        return res
+            self._current_background_update = upd["update_name"]

-    async def _do_background_update(
-        self, update_name: str, desired_duration_ms: float
-    ) -> int:
+        await self._do_background_update(desired_duration_ms)
+        return False

+    async def _do_background_update(self, desired_duration_ms: float) -> int:
+        update_name = self._current_background_update
         logger.info("Starting update batch on background update '%s'", update_name)

         update_handler = self._background_update_handlers[update_name]
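
The replacement scheduling logic keeps no local queue: each pass re-reads the `background_updates` table and starts the first update whose `depends_on` is not itself still pending, using Python's `for`/`else` to turn an exhausted loop into a dependency-cycle error. A self-contained sketch of that selection over mocked rows standing in for the table:

    # Self-contained sketch of the dependency-aware selection above, over
    # mocked rows standing in for the background_updates table.
    pending_rows = [
        {"update_name": "b", "depends_on": "a"},   # blocked: "a" still pending
        {"update_name": "a", "depends_on": None},  # ready to run
    ]

    pending = {row["update_name"] for row in pending_rows}
    for upd in pending_rows:
        depends_on = upd["depends_on"]
        if not depends_on or depends_on not in pending:
            break  # first update whose dependency has already completed
    else:
        # No break fired: every pending update waits on another pending
        # update, which can only mean a dependency cycle.
        raise Exception("dependency cycle in background updates")

    assert upd["update_name"] == "a"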
@@ -400,27 +416,6 @@ class BackgroundUpdater(object):

         self.register_background_update_handler(update_name, updater)

-    def start_background_update(self, update_name, progress):
-        """Starts a background update running.
-
-        Args:
-            update_name: The update to set running.
-            progress: The initial state of the progress of the update.
-
-        Returns:
-            A deferred that completes once the task has been added to the
-            queue.
-        """
-        # Clear the background update queue so that we will pick up the new
-        # task on the next iteration of do_background_update.
-        self._background_update_queue = []
-        progress_json = json.dumps(progress)
-
-        return self.db.simple_insert(
-            "background_updates",
-            {"update_name": update_name, "progress_json": progress_json},
-        )
-
     def _end_background_update(self, update_name):
         """Removes a completed background update task from the queue.

@@ -429,9 +424,12 @@ class BackgroundUpdater(object):
         Returns:
             A deferred that completes once the task is removed.
         """
-        self._background_update_queue = [
-            name for name in self._background_update_queue if name != update_name
-        ]
+        if update_name != self._current_background_update:
+            raise Exception(
+                "Cannot end background update %s which isn't currently running"
+                % update_name
+            )
+        self._current_background_update = None
         return self.db.simple_delete_one(
             "background_updates", keyvalues={"update_name": update_name}
         )