Merge branch 'develop' into babolivier/device_list_retry

pull/7453/head
Brendan Abolivier 2020-05-18 17:18:24 +02:00
commit 3a94f4a2b5
83 changed files with 1507 additions and 881 deletions

.github/ISSUE_TEMPLATE.md

@@ -0,0 +1,5 @@
**If you are looking for support** please ask in **#synapse:matrix.org**
(using a matrix.org account if necessary). We do not use GitHub issues for
support.
**If you want to report a security issue** please see https://matrix.org/security-disclosure-policy/


@@ -4,11 +4,13 @@ about: Create a report to help us improve
---

**THIS IS NOT A SUPPORT CHANNEL!**
**IF YOU HAVE SUPPORT QUESTIONS ABOUT RUNNING OR CONFIGURING YOUR OWN HOME SERVER**,
please ask in **#synapse:matrix.org** (using a matrix.org account if necessary)

<!--
If you want to report a security issue, please see https://matrix.org/security-disclosure-policy/

This is a bug report template. By following the instructions below and
filling out the sections with your information, you will help us to get all


@@ -1,3 +1,12 @@
Synapse 1.13.0rc3 (2020-05-18)
==============================
Bugfixes
--------
- Hash passwords as early as possible during registration. ([\#7523](https://github.com/matrix-org/synapse/issues/7523))
Synapse 1.13.0rc2 (2020-05-14)
==============================


@@ -1,19 +1,20 @@
# Contributing code to Synapse

Everyone is welcome to contribute code to [matrix.org
projects](https://github.com/matrix-org), provided that they are willing to
license their contributions under the same license as the project itself. We
follow a simple 'inbound=outbound' model for contributions: the act of
submitting an 'inbound' contribution means that the contributor agrees to
license the code under the same terms as the project's overall 'outbound'
license - in our case, this is almost always Apache Software License v2 (see
[LICENSE](LICENSE)).

## How to contribute

The preferred and easiest way to contribute changes is to fork the relevant
project on github, and then [create a pull request](
https://help.github.com/articles/using-pull-requests/) to ask us to pull your
changes into our repo.

**The single biggest thing you need to know is: please base your changes on
the develop branch - *not* master.**
@@ -28,35 +29,31 @@ use github's pull request workflow to review the contribution, and either ask
you to make any refinements needed or merge it and make them ourselves. The
changes will then land on master when we next do a release.

Some other things you will need to know when contributing to Synapse:

* Please follow the [code style requirements](#code-style).

* Please include a [changelog entry](#changelog) with each PR.

* Please [sign off](#sign-off) your contribution.

* Please keep an eye on the pull request for feedback from the [continuous
  integration system](#continuous-integration-and-testing) and try to fix any
  errors that come up.

* If you need to [update your PR](#updating-your-pull-request), just add new
  commits to your branch rather than rebasing.
## Code style

Synapse's code style is documented [here](docs/code_style.md). Please follow
it, including the conventions for the [sample configuration
file](docs/code_style.md#configuration-file-format).

Many of the conventions are enforced by scripts which are run as part of the
[continuous integration system](#continuous-integration-and-testing). To help
check if you have followed the code style, you can run `scripts-dev/lint.sh`
locally. You'll need python 3.6 or later, and to install a number of tools:

```
# Install the dependencies
@@ -67,9 +64,11 @@ pip install -U black flake8 flake8-comprehensions isort
```

**Note that the script does not just test/check, but also reformats code, so you
may wish to ensure any new code is committed first**.

By default, this script checks all files and can take some time; if you alter
only certain files, you might wish to specify paths as arguments to reduce the
run-time:
```
./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
@@ -82,7 +81,6 @@ Please ensure your changes match the cosmetic style of the existing project,
and **never** mix cosmetic and functional changes in the same commit, as it
makes it horribly hard to review otherwise.

## Changelog
All changes, even minor ones, need a corresponding changelog / newsfragment All changes, even minor ones, need a corresponding changelog / newsfragment
@@ -98,24 +96,55 @@ in the format of `PRnumber.type`. The type can be one of the following:
* `removal` (also used for deprecations)
* `misc` (for internal-only changes)
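The `PRnumber.type` naming convention can be sketched as a quick local check. This is a hypothetical helper, not part of the repo; the full set of types is assumed from Synapse's towncrier configuration, as only some of them are visible in this excerpt:

```python
import re

# Assumed full set of changelog types; only `removal` and `misc` appear
# in the excerpt above, the rest are taken from towncrier's config.
VALID_TYPES = {"feature", "bugfix", "docker", "doc", "removal", "misc"}


def is_valid_newsfragment(filename: str) -> bool:
    """Return True if `filename` matches the `PRnumber.type` convention."""
    match = re.fullmatch(r"(\d+)\.([a-z]+)", filename)
    return match is not None and match.group(2) in VALID_TYPES
```

For example, `is_valid_newsfragment("1234.bugfix")` passes, while a stray `README.md` in `changelog.d` would not.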
This file will become part of our [changelog](
https://github.com/matrix-org/synapse/blob/master/CHANGES.md) at the next
release, so the content of the file should be a short description of your
change in the same style as the rest of the changelog. The file can contain Markdown
formatting, and should end with a full stop (.) or an exclamation mark (!) for
consistency.

Adding credits to the changelog is encouraged, we value your
contributions and would like to have you shouted out in the release notes!

For example, a fix in PR #1234 would have its changelog entry in
`changelog.d/1234.bugfix`, and contain content like:

> The security levels of Florbs are now validated when received
> via the `/federation/florb` endpoint. Contributed by Jane Matrix.

If there are multiple pull requests involved in a single bugfix/feature/etc,
then the content for each `changelog.d` file should be the same. Towncrier will
merge the matching files together into a single changelog entry when we come to
release.
### How do I know what to call the changelog file before I create the PR?
Obviously, you don't know if you should call your newsfile
`1234.bugfix` or `5678.bugfix` until you create the PR, which leads to a
chicken-and-egg problem.
There are two options for solving this:
1. Open the PR without a changelog file, see what number you got, and *then*
add the changelog file to your branch (see [Updating your pull
request](#updating-your-pull-request)), or:
1. Look at the [list of all
issues/PRs](https://github.com/matrix-org/synapse/issues?q=), add one to the
highest number you see, and quickly open the PR before somebody else claims
your number.
[This
script](https://github.com/richvdh/scripts/blob/master/next_github_number.sh)
might be helpful if you find yourself doing this a lot.
Sorry, we know it's a bit fiddly, but it's *really* helpful for us when we come
to put together a release!
### Debian changelog
Changes which affect the debian packaging files (in `debian`) are an
exception to the rule that all changes require a `changelog.d` file.

In this case, you will need to add an entry to the debian changelog for the
next release. For this, run the following command:
@@ -200,19 +229,45 @@ Git allows you to add this signoff automatically when using the `-s`
flag to `git commit`, which uses the name and email set in your
`user.name` and `user.email` git configs.
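As a sketch, configuring that identity once and then signing off a commit might look like this (the name, email, and commit message are placeholders):

```shell
# One-off setup: the name and email used in the Signed-off-by line
git config user.name "Jane Matrix"
git config user.email "jane@example.com"

# -s appends "Signed-off-by: Jane Matrix <jane@example.com>" to the message
git commit -s -m "Fix florb validation over federation"
```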
## Continuous integration and testing

[Buildkite](https://buildkite.com/matrix-dot-org/synapse) will automatically
run a series of checks and tests against any PR which is opened against the
project; if your change breaks the build, this will be shown in GitHub, with
links to the build results. If your build fails, please try to fix the errors
and update your branch.

To run unit tests in a local development environment, you can use:

- ``tox -e py35`` (requires tox to be installed by ``pip install tox``)
  for SQLite-backed Synapse on Python 3.5.
- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6.
- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6
  (requires a running local PostgreSQL with access to create databases).
- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5
  (requires Docker). Entirely self-contained, recommended if you don't want to
  set up PostgreSQL yourself.
Docker images are available for running the integration tests (SyTest) locally,
see the [documentation in the SyTest repo](
https://github.com/matrix-org/sytest/blob/develop/docker/README.md) for more
information.
## Updating your pull request
If you decide to make changes to your pull request - perhaps to address issues
raised in a review, or to fix problems highlighted by [continuous
integration](#continuous-integration-and-testing) - just add new commits to your
branch, and push to GitHub. The pull request will automatically be updated.
Please **avoid** rebasing your branch, especially once the PR has been
reviewed: doing so makes it very difficult for a reviewer to see what has
changed since a previous review.
## Notes for maintainers on merging PRs etc
There are some notes for those with commit access to the project on how we
manage git [here](docs/dev/git.md).
## Conclusion


@@ -1,3 +1,11 @@
================
Synapse |shield|
================
.. |shield| image:: https://img.shields.io/matrix/synapse:matrix.org?label=support&logo=matrix
:alt: (get support on #synapse:matrix.org)
:target: https://matrix.to/#/#synapse:matrix.org
.. contents::

Introduction
@@ -77,6 +85,17 @@ Thanks for using Matrix!
[1] End-to-end encryption is currently in beta: `blog post <https://matrix.org/blog/2016/11/21/matrixs-olm-end-to-end-encryption-security-assessment-released-and-implemented-cross-platform-on-riot-at-last>`_.
Support
=======
For support installing or managing Synapse, please join |room|_ (from a matrix.org
account if necessary) and ask questions there. We do not use GitHub issues for
support requests, only for bug reports and feature requests.
.. |room| replace:: ``#synapse:matrix.org``
.. _room: https://matrix.to/#/#synapse:matrix.org
Synapse Installation
====================

changelog.d/7381.bugfix

@@ -0,0 +1 @@
Add an experimental room version which strictly adheres to the canonical JSON specification.

changelog.d/7384.bugfix

@@ -0,0 +1 @@
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.

changelog.d/7457.feature

@@ -0,0 +1 @@
Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs).

changelog.d/7463.doc

@@ -0,0 +1 @@
Add additional reverse proxy example for Caddy v2. Contributed by Jeff Peeler.

changelog.d/7465.bugfix

@@ -0,0 +1 @@
Prevent rooms with 0 members or with invalid version strings from breaking group queries.

changelog.d/7491.misc

@@ -0,0 +1 @@
Move event stream handling out of slave store.

changelog.d/7505.misc

@@ -0,0 +1 @@
Add type hints to `synapse.event_auth`.

changelog.d/7506.feature

@@ -0,0 +1 @@
Implement room version 6 per [MSC2240](https://github.com/matrix-org/matrix-doc/pull/2240).

changelog.d/7507.misc

@@ -0,0 +1 @@
Convert the room member handler to async/await.

changelog.d/7508.bugfix

@@ -0,0 +1 @@
Ignore incoming presence events from other homeservers if presence is disabled locally.

changelog.d/7511.bugfix

@@ -0,0 +1 @@
Fix a long-standing bug that broke the update remote profile background process.

changelog.d/7513.misc

@@ -0,0 +1 @@
Add type hints to room member handler.

changelog.d/7514.doc

@@ -0,0 +1 @@
Improve the formatting of `reverse_proxy.md`.

changelog.d/7515.misc

@@ -0,0 +1 @@
Allow `ReplicationRestResource` to be added to workers.

changelog.d/7516.misc

@@ -0,0 +1 @@
Add a worker store for search insertion, required for moving event persistence off master.

changelog.d/7518.misc

@@ -0,0 +1 @@
Fix typing annotations in `tests.replication`.

changelog.d/7519.misc

@@ -0,0 +1 @@
Remove some redundant Python 2 support code.

docs/dev/git.md

@@ -0,0 +1,148 @@
Some notes on how we use git
============================
On keeping the commit history clean
-----------------------------------
In an ideal world, our git commit history would be a linear progression of
commits each of which contains a single change building on what came
before. Here, by way of an arbitrary example, is the top of `git log --graph
b2dba0607`:
<img src="git/clean.png" alt="clean git graph" width="500px">
Note how the commit comment explains clearly what is changing and why. Also
note the *absence* of merge commits, as well as the absence of commits called
things like (to pick a few culprits):
[“pep8”](https://github.com/matrix-org/synapse/commit/84691da6c), [“fix broken
test”](https://github.com/matrix-org/synapse/commit/474810d9d),
[“oops”](https://github.com/matrix-org/synapse/commit/c9d72e457),
[“typo”](https://github.com/matrix-org/synapse/commit/836358823), or [“Who's
the president?”](https://github.com/matrix-org/synapse/commit/707374d5d).
There are a number of reasons why keeping a clean commit history is a good
thing:
* From time to time, after a change lands, it turns out to be necessary to
revert it, or to backport it to a release branch. Those operations are
*much* easier when the change is contained in a single commit.
* Similarly, it's much easier to answer questions like “is the fix for
`/publicRooms` on the release branch?” if that change consists of a single
commit.
* Likewise: “what has changed on this branch in the last week?” is much
clearer without merges and “pep8” commits everywhere.
* Sometimes we need to figure out where a bug got introduced, or some
behaviour changed. One way of doing that is with `git bisect`: pick an
arbitrary commit between the known good point and the known bad point, and
see how the code behaves. However, that strategy fails if the commit you
chose is the middle of someone's epic branch in which they broke the world
before putting it back together again.
One counterargument is that it is sometimes useful to see how a PR evolved as
it went through review cycles. This is true, but that information is always
available via the GitHub UI (or via the little-known [refs/pull
namespace](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/checking-out-pull-requests-locally)).
Of course, in reality, things are more complicated than that. We have release
branches as well as `develop` and `master`, and we deliberately merge changes
between them. Bugs often slip through and have to be fixed later. That's all
fine: this is not a cast-iron rule which must be obeyed, but an ideal to aim
towards.
Merges, squashes, rebases: wtf?
-------------------------------
Ok, so that's what we'd like to achieve. How do we achieve it?
The TL;DR is: when you come to merge a pull request, you *probably* want to
“squash and merge”:
![squash and merge](git/squash.png).
(This applies whether you are merging your own PR, or that of another
contributor.)
“Squash and merge”<sup id="a1">[1](#f1)</sup> takes all of the changes in the
PR, and bundles them into a single commit. GitHub gives you the opportunity to
edit the commit message before you confirm, and normally you should do so,
because the default will be useless (again: `* woops typo` is not a useful
thing to keep in the historical record).
The main problem with this approach comes when you have a series of pull
requests which build on top of one another: as soon as you squash-merge the
first PR, you'll end up with a stack of conflicts to resolve in all of the
others. In general, it's best to avoid this situation in the first place by
trying not to have multiple related PRs in flight at the same time. Still,
sometimes that's not possible and doing a regular merge is the lesser evil.
Another occasion in which a regular merge makes more sense is a PR where you've
deliberately created a series of commits each of which makes sense in its own
right. For example: [a PR which gradually propagates a refactoring operation
through the codebase](https://github.com/matrix-org/synapse/pull/6837), or [a
PR which is the culmination of several other
PRs](https://github.com/matrix-org/synapse/pull/5987). In this case the ability
to figure out when a particular change/bug was introduced could be very useful.
Ultimately: **this is not a hard-and-fast rule**. If in doubt, ask yourself “does
each of the commits I am about to merge make sense in its own right?”, but
remember that we're just doing our best to balance “keeping the commit history
clean” with other factors.
Git branching model
-------------------
A [lot](https://nvie.com/posts/a-successful-git-branching-model/)
[of](http://scottchacon.com/2011/08/31/github-flow.html)
[words](https://www.endoflineblog.com/gitflow-considered-harmful) have been
written in the past about git branching models (no really, [a
lot](https://martinfowler.com/articles/branching-patterns.html)). I tend to
think the whole thing is overblown. Fundamentally, it's not that
complicated. Here's how we do it.
Let's start with a picture:
![branching model](git/branches.jpg)
It looks complicated, but it's really not. There's one basic rule: *anyone* is
free to merge from *any* more-stable branch to *any* less-stable branch at
*any* time<sup id="a2">[2](#f2)</sup>. (The principle behind this is that if a
change is good enough for the more-stable branch, then it's also good enough to
put in a less-stable branch.)
Meanwhile, merging (or squashing, as per the above) from a less-stable to a
more-stable branch is a deliberate action in which you want to publish a change
or a set of changes to (some subset of) the world: for example, this happens
when a PR is landed, or as part of our release process.
So, what counts as a more- or less-stable branch? A little reflection will show
that our active branches are ordered thus, from more-stable to less-stable:
* `master` (tracks our last release).
* `release-vX.Y.Z` (the branch where we prepare the next release)<sup
id="a3">[3](#f3)</sup>.
* PR branches which are targeting the release.
* `develop` (our "mainline" branch containing our bleeding-edge).
* regular PR branches.
The corollary is: if you have a bugfix that needs to land in both
`release-vX.Y.Z` *and* `develop`, then you should base your PR on
`release-vX.Y.Z`, get it merged there, and then merge from `release-vX.Y.Z` to
`develop`. (If a fix lands in `develop` and we later need it in a
release-branch, we can of course cherry-pick it, but landing it in the release
branch first helps reduce the chance of annoying conflicts.)
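Under those rules, landing a fix in both branches might look like this sketch (the branch names are illustrative):

```shell
# Base the bugfix PR on the release branch...
git checkout -b my-bugfix release-v1.2.3

# ...and once it has been squash-merged into release-v1.2.3,
# merge the release branch forward into develop:
git checkout develop
git merge release-v1.2.3
```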
---
<b id="f1">[1]</b>: “Squash and merge” is GitHub's term for this
operation. Given that there is no merge involved, I'm not convinced it's the
most intuitive name. [^](#a1)
<b id="f2">[2]</b>: Well, anyone with commit access.[^](#a2)
<b id="f3">[3]</b>: Very, very occasionally (I think this has happened once in
the history of Synapse), we've had two releases in flight at once. Obviously,
`release-v1.2.3` is more-stable than `release-v1.3.0`. [^](#a3)

docs/dev/git/branches.jpg (binary file, 70 KiB)

docs/dev/git/clean.png (binary file, 108 KiB)

docs/dev/git/squash.png (binary file, 29 KiB)


@@ -9,7 +9,7 @@ of doing so is that it means that you can expose the default https port
(443) to Matrix clients without needing to run Synapse with root
privileges.

**NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
the requested URI in any way (for example, by decoding `%xx` escapes).
Beware that Apache *will* canonicalise URIs unless you specify
`nocanon`.
@@ -18,7 +18,7 @@ When setting up a reverse proxy, remember that Matrix clients and other
Matrix servers do not necessarily need to connect to your server via the
same server name or port. Indeed, clients will use port 443 by default,
whereas servers default to port 8448. Where these are different, we
refer to the 'client port' and the 'federation port'. See [the Matrix
specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names)
for more details of the algorithm used for federation connections, and
[delegate.md](<delegate.md>) for instructions on setting up delegation.
@@ -28,12 +28,13 @@ Let's assume that we expect clients to connect to our server at
`https://example.com:8448`. The following sections detail the configuration of
the reverse proxy and the homeserver.
## Reverse-proxy configuration examples

**NOTE**: You only need one of these.

### nginx

```
server {
    listen 443 ssl;
    listen [::]:443 ssl;
@@ -58,12 +59,14 @@ the reverse proxy and the homeserver.
        proxy_set_header X-Forwarded-For $remote_addr;
    }
}
```

**NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will
canonicalise/normalise the URI.
### Caddy 1

```
matrix.example.com {
    proxy /_matrix http://localhost:8008 {
        transparent
@@ -75,9 +78,23 @@ canonicalise/normalise the URI.
        transparent
    }
}
```
### Caddy 2
```
matrix.example.com {
reverse_proxy /_matrix/* http://localhost:8008
}
example.com:8448 {
reverse_proxy http://localhost:8008
}
```
### Apache
```
<VirtualHost *:443>
    SSLEngine on
    ServerName matrix.example.com;
@@ -95,11 +112,13 @@ canonicalise/normalise the URI.
    ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
    ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
</VirtualHost>
```
**NOTE**: ensure the `nocanon` options are included.

### HAProxy
```
frontend https
    bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
@@ -115,6 +134,7 @@ canonicalise/normalise the URI.
backend matrix
    server matrix 127.0.0.1:8008
```
## Homeserver Configuration


@@ -3,8 +3,6 @@ import json
import sys
import time

import psycopg2
import yaml
from canonicaljson import encode_canonical_json
@@ -12,10 +10,7 @@ from signedjson.key import read_signing_keys
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64

db_binary_type = memoryview
def select_v1_keys(connection):
@@ -72,7 +67,7 @@ def rows_v2(server, json):
    valid_until = json["valid_until_ts"]
    key_json = encode_canonical_json(json)
    for key_id in json["verify_keys"]:
        yield (server, key_id, "-", valid_until, valid_until, db_binary_type(key_json))


def main():


@@ -36,7 +36,7 @@ try:
except ImportError:
    pass

__version__ = "1.13.0rc3"

if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
    # We import here so that we don't have to install a bunch of deps when


@@ -59,7 +59,11 @@ class RoomVersion(object):
    # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
    special_case_aliases_auth = attr.ib(type=bool)
    # Strictly enforce canonicaljson, do not allow:
    # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
    # * Floats
    # * NaN, Infinity, -Infinity
    strict_canonicaljson = attr.ib(type=bool)
    # bool: MSC2209: Check 'notifications' key while verifying
    # m.room.power_levels auth rules.
    limit_notifications_power_levels = attr.ib(type=bool)
@@ -73,6 +77,7 @@ class RoomVersions(object):
        StateResolutionVersions.V1,
        enforce_key_validity=False,
        special_case_aliases_auth=True,
        strict_canonicaljson=False,
        limit_notifications_power_levels=False,
    )
    V2 = RoomVersion(
@@ -82,6 +87,7 @@ class RoomVersions(object):
        StateResolutionVersions.V2,
        enforce_key_validity=False,
        special_case_aliases_auth=True,
        strict_canonicaljson=False,
        limit_notifications_power_levels=False,
    )
    V3 = RoomVersion(
@ -91,6 +97,7 @@ class RoomVersions(object):
StateResolutionVersions.V2, StateResolutionVersions.V2,
enforce_key_validity=False, enforce_key_validity=False,
special_case_aliases_auth=True, special_case_aliases_auth=True,
strict_canonicaljson=False,
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
) )
V4 = RoomVersion( V4 = RoomVersion(
@ -100,6 +107,7 @@ class RoomVersions(object):
StateResolutionVersions.V2, StateResolutionVersions.V2,
enforce_key_validity=False, enforce_key_validity=False,
special_case_aliases_auth=True, special_case_aliases_auth=True,
strict_canonicaljson=False,
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
) )
V5 = RoomVersion( V5 = RoomVersion(
@ -109,24 +117,17 @@ class RoomVersions(object):
StateResolutionVersions.V2, StateResolutionVersions.V2,
enforce_key_validity=True, enforce_key_validity=True,
special_case_aliases_auth=True, special_case_aliases_auth=True,
strict_canonicaljson=False,
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
) )
MSC2432_DEV = RoomVersion( V6 = RoomVersion(
"org.matrix.msc2432", "6",
RoomDisposition.UNSTABLE, RoomDisposition.STABLE,
EventFormatVersions.V3, EventFormatVersions.V3,
StateResolutionVersions.V2, StateResolutionVersions.V2,
enforce_key_validity=True, enforce_key_validity=True,
special_case_aliases_auth=False, special_case_aliases_auth=False,
limit_notifications_power_levels=False, strict_canonicaljson=True,
)
MSC2209_DEV = RoomVersion(
"org.matrix.msc2209",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=True,
limit_notifications_power_levels=True, limit_notifications_power_levels=True,
) )
@ -139,7 +140,6 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V3, RoomVersions.V3,
RoomVersions.V4, RoomVersions.V4,
RoomVersions.V5, RoomVersions.V5,
RoomVersions.MSC2432_DEV, RoomVersions.V6,
RoomVersions.MSC2209_DEV,
) )
} # type: Dict[str, RoomVersion] } # type: Dict[str, RoomVersion]
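The new `strict_canonicaljson` flag is just per-version data consulted at validation time. A minimal sketch of that lookup, with an abbreviated stand-in registry (names and contents here are illustrative, not the full Synapse registry):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RoomVersion:
    identifier: str
    strict_canonicaljson: bool


# Abbreviated stand-in for KNOWN_ROOM_VERSIONS: v6 is the first
# room version to enforce strict canonical JSON.
KNOWN_ROOM_VERSIONS = {
    "5": RoomVersion("5", strict_canonicaljson=False),
    "6": RoomVersion("6", strict_canonicaljson=True),
}


def needs_strict_json(version_id: str) -> bool:
    # Unknown versions get no strict enforcement, matching a plain
    # dict lookup against the registry.
    version = KNOWN_ROOM_VERSIONS.get(version_id)
    return version is not None and version.strict_canonicaljson
```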


@@ -47,6 +47,7 @@ from synapse.http.site import SynapseSite
 from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -122,6 +123,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
     MonthlyActiveUsersWorkerStore,
 )
 from synapse.storage.data_stores.main.presence import UserPresenceState
+from synapse.storage.data_stores.main.search import SearchWorkerStore
 from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
 from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
 from synapse.types import ReadReceipt
@@ -451,6 +453,7 @@ class GenericWorkerSlavedStore(
     SlavedFilteringStore,
     MonthlyActiveUsersWorkerStore,
     MediaRepositoryStore,
+    SearchWorkerStore,
     BaseSlavedStore,
 ):
     def __init__(self, database, db_conn, hs):
@@ -568,6 +571,9 @@ class GenericWorkerServer(HomeServer):
                 if name in ["keys", "federation"]:
                     resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)

+                if name == "replication":
+                    resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
+
         root_resource = create_resource_tree(resources, NoResource())

         _base.listen_tcp(

@@ -270,7 +270,7 @@ class ApplicationService(object):
     def is_exclusive_room(self, room_id):
         return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)

-    def get_exlusive_user_regexes(self):
+    def get_exclusive_user_regexes(self):
         """Get the list of regexes used to determine if a user is exclusively
         registered by the AS
         """

@@ -15,7 +15,7 @@
 # limitations under the License.

 import logging
-from typing import Set, Tuple
+from typing import List, Optional, Set, Tuple

 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -29,18 +29,19 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersion,
 )
-from synapse.types import UserID, get_domain_from_id
+from synapse.events import EventBase
+from synapse.types import StateMap, UserID, get_domain_from_id

 logger = logging.getLogger(__name__)


 def check(
     room_version_obj: RoomVersion,
-    event,
-    auth_events,
-    do_sig_check=True,
-    do_size_check=True,
-):
+    event: EventBase,
+    auth_events: StateMap[EventBase],
+    do_sig_check: bool = True,
+    do_size_check: bool = True,
+) -> None:
     """ Checks if this event is correctly authed.

     Args:
@@ -189,7 +190,7 @@ def check(
     logger.debug("Allowing! %s", event)


-def _check_size_limits(event):
+def _check_size_limits(event: EventBase) -> None:
     def too_big(field):
         raise EventSizeError("%s too large" % (field,))
@@ -207,13 +208,18 @@ def _check_size_limits(event):
         too_big("event")


-def _can_federate(event, auth_events):
+def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
     creation_event = auth_events.get((EventTypes.Create, ""))
+    # There should always be a creation event, but if not don't federate.
+    if not creation_event:
+        return False

     return creation_event.content.get("m.federate", True) is True


-def _is_membership_change_allowed(event, auth_events):
+def _is_membership_change_allowed(
+    event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
     membership = event.content["membership"]

     # Check if this is the room creator joining:
@@ -339,21 +345,25 @@ def _is_membership_change_allowed(event, auth_events):
     raise AuthError(500, "Unknown membership %s" % membership)


-def _check_event_sender_in_room(event, auth_events):
+def _check_event_sender_in_room(
+    event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
     key = (EventTypes.Member, event.user_id)
     member_event = auth_events.get(key)

-    return _check_joined_room(member_event, event.user_id, event.room_id)
+    _check_joined_room(member_event, event.user_id, event.room_id)


-def _check_joined_room(member, user_id, room_id):
+def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) -> None:
     if not member or member.membership != Membership.JOIN:
         raise AuthError(
             403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
         )


-def get_send_level(etype, state_key, power_levels_event):
+def get_send_level(
+    etype: str, state_key: Optional[str], power_levels_event: Optional[EventBase]
+) -> int:
     """Get the power level required to send an event of a given type

     The federation spec [1] refers to this as "Required Power Level".
@@ -361,13 +371,13 @@ def get_send_level(etype, state_key, power_levels_event):
     https://matrix.org/docs/spec/server_server/unstable.html#definitions

     Args:
-        etype (str): type of event
-        state_key (str|None): state_key of state event, or None if it is not
+        etype: type of event
+        state_key: state_key of state event, or None if it is not
             a state event.
-        power_levels_event (synapse.events.EventBase|None): power levels event
+        power_levels_event: power levels event
             in force at this point in the room
     Returns:
-        int: power level required to send this event.
+        power level required to send this event.
     """

     if power_levels_event:
@@ -388,7 +398,7 @@ def get_send_level(etype, state_key, power_levels_event):
     return int(send_level)


-def _can_send_event(event, auth_events):
+def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
     power_levels_event = _get_power_level_event(auth_events)

     send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
@@ -410,7 +420,9 @@ def _can_send_event(event, auth_events):
     return True


-def check_redaction(room_version_obj: RoomVersion, event, auth_events):
+def check_redaction(
+    room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+) -> bool:
     """Check whether the event sender is allowed to redact the target event.

     Returns:
@@ -442,7 +454,9 @@ def check_redaction(room_version_obj: RoomVersion, event, auth_events):
     raise AuthError(403, "You don't have permission to redact events")


-def _check_power_levels(room_version_obj, event, auth_events):
+def _check_power_levels(
+    room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+) -> None:
     user_list = event.content.get("users", {})
     # Validate users
     for k, v in user_list.items():
@@ -473,7 +487,7 @@ def _check_power_levels(room_version_obj, event, auth_events):
         ("redact", None),
         ("kick", None),
         ("invite", None),
-    ]
+    ]  # type: List[Tuple[str, Optional[str]]]

     old_list = current_state.content.get("users", {})
     for user in set(list(old_list) + list(user_list)):
@@ -503,12 +517,12 @@ def _check_power_levels(room_version_obj, event, auth_events):
             new_loc = new_loc.get(dir, {})

         if level_to_check in old_loc:
-            old_level = int(old_loc[level_to_check])
+            old_level = int(old_loc[level_to_check])  # type: Optional[int]
         else:
             old_level = None

         if level_to_check in new_loc:
-            new_level = int(new_loc[level_to_check])
+            new_level = int(new_loc[level_to_check])  # type: Optional[int]
         else:
             new_level = None
@@ -534,21 +548,21 @@ def _check_power_levels(room_version_obj, event, auth_events):
     )


-def _get_power_level_event(auth_events):
+def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
     return auth_events.get((EventTypes.PowerLevels, ""))


-def get_user_power_level(user_id, auth_events):
+def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
     """Get a user's power level

     Args:
-        user_id (str): user's id to look up in power_levels
-        auth_events (dict[(str, str), synapse.events.EventBase]):
+        user_id: user's id to look up in power_levels
+        auth_events:
             state in force at this point in the room (or rather, a subset of
             it including at least the create event and power levels event.

     Returns:
-        int: the user's power level in this room.
+        the user's power level in this room.
     """
     power_level_event = _get_power_level_event(auth_events)
     if power_level_event:
@@ -574,7 +588,7 @@ def get_user_power_level(user_id, auth_events):
         return 0


-def _get_named_level(auth_events, name, default):
+def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
     power_level_event = _get_power_level_event(auth_events)

     if not power_level_event:
@@ -587,7 +601,7 @@ def _get_named_level(auth_events, name, default):
     return default


-def _verify_third_party_invite(event, auth_events):
+def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]):
     """
     Validates that the invite event is authorized by a previous third-party invite.
@@ -662,7 +676,7 @@ def get_public_keys(invite_event):
     return public_keys


-def auth_types_for_event(event) -> Set[Tuple[str, str]]:
+def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
     """Given an event, return a list of (EventType, StateKey) that may be
     needed to auth the event. The returned list may be a superset of what
     would actually be required depending on the full state of the room.
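The `get_send_level` rules annotated above (state events default to `state_default` 50, other events to `events_default` 0, with per-type overrides in `events`) can be sketched standalone over a plain power-levels content dict — a simplification, since the real function takes an `EventBase`:

```python
from typing import Optional


def get_send_level(etype: str, state_key: Optional[str], power_levels_content) -> int:
    """Power level required to send an event, mirroring event_auth's rules."""
    if power_levels_content:
        # State events fall back to "state_default" (50), other events to
        # "events_default" (0); a per-type entry in "events" overrides both.
        if state_key is not None:
            send_level = power_levels_content.get("state_default", 50)
        else:
            send_level = power_levels_content.get("events_default", 0)
        send_level = power_levels_content.get("events", {}).get(etype, send_level)
    else:
        # No power levels event in force: state events need 50, others 0.
        send_level = 50 if state_key is not None else 0
    return int(send_level)
```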


@@ -14,7 +14,7 @@
 # limitations under the License.
 import collections
 import re
-from typing import Mapping, Union
+from typing import Any, Mapping, Union

 from six import string_types
@@ -23,6 +23,7 @@ from frozendict import frozendict
 from twisted.internet import defer

 from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.util.async_helpers import yieldable_gather_results
@@ -449,3 +450,35 @@ def copy_power_levels_contents(
             raise TypeError("Invalid power_levels value for %s: %r" % (k, v))

     return power_levels
+
+
+def validate_canonicaljson(value: Any):
+    """
+    Ensure that the JSON object is valid according to the rules of canonical JSON.
+
+    See the appendix section 3.1: Canonical JSON.
+
+    This rejects JSON that has:
+    * An integer outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
+    * Floats
+    * NaN, Infinity, -Infinity
+    """
+    if isinstance(value, int):
+        if value <= -(2 ** 53) or 2 ** 53 <= value:
+            raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON)
+    elif isinstance(value, float):
+        # Note that Infinity, -Infinity, and NaN are also considered floats.
+        raise SynapseError(400, "Bad JSON value: float", Codes.BAD_JSON)
+    elif isinstance(value, (dict, frozendict)):
+        for v in value.values():
+            validate_canonicaljson(v)
+    elif isinstance(value, (list, tuple)):
+        for i in value:
+            validate_canonicaljson(i)
+    elif not isinstance(value, (bool, str)) and value is not None:
+        # Other potential JSON values (bool, None, str) are safe.
+        raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON)
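The new `validate_canonicaljson` helper can be exercised standalone; this sketch swaps `SynapseError`/`frozendict` for a plain `ValueError` and `dict` so it runs without Synapse installed, but keeps the same acceptance rules:

```python
class BadJson(ValueError):
    """Stand-in for SynapseError(400, ..., Codes.BAD_JSON)."""


def validate_canonicaljson(value):
    """Reject values that canonical JSON forbids: integers outside
    [-2**53 + 1, 2**53 - 1], floats (incl. NaN/Infinity), unknown types."""
    if isinstance(value, bool) or value is None or isinstance(value, str):
        return  # bool, None, and str are always safe
    if isinstance(value, int):
        if value <= -(2 ** 53) or 2 ** 53 <= value:
            raise BadJson("JSON integer out of range")
    elif isinstance(value, float):
        # NaN, Infinity, and -Infinity are floats too, so this covers them.
        raise BadJson("Bad JSON value: float")
    elif isinstance(value, dict):
        for v in value.values():
            validate_canonicaljson(v)
    elif isinstance(value, (list, tuple)):
        for i in value:
            validate_canonicaljson(i)
    else:
        raise BadJson("Unknown JSON value")
```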


@@ -18,6 +18,7 @@ from six import integer_types, string_types
 from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import EventFormatVersions
+from synapse.events.utils import validate_canonicaljson
 from synapse.types import EventID, RoomID, UserID
@@ -55,6 +56,12 @@ class EventValidator(object):
             if not isinstance(getattr(event, s), string_types):
                 raise SynapseError(400, "'%s' not a string type" % (s,))

+        # Depending on the room version, ensure the data is spec compliant JSON.
+        if event.room_version.strict_canonicaljson:
+            # Note that only the client controlled portion of the event is
+            # checked, since we trust the portions of the event we created.
+            validate_canonicaljson(event.content)
+
         if event.type == EventTypes.Aliases:
             if "aliases" in event.content:
                 for alias in event.content["aliases"]:


@@ -29,7 +29,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.crypto.event_signing import check_event_content_hash
 from synapse.crypto.keyring import Keyring
 from synapse.events import EventBase, make_event_from_dict
-from synapse.events.utils import prune_event
+from synapse.events.utils import prune_event, validate_canonicaljson
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
     PreserveLoggingContext,
@@ -302,6 +302,10 @@ def event_from_pdu_json(
     elif depth > MAX_DEPTH:
         raise SynapseError(400, "Depth too large", Codes.BAD_JSON)

+    # Validate that the JSON conforms to the specification.
+    if room_version.strict_canonicaljson:
+        validate_canonicaljson(pdu_json)
+
     event = make_event_from_dict(pdu_json, room_version)
     event.internal_metadata.outlier = outlier


@@ -80,7 +80,9 @@ class AuthHandler(BaseHandler):
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
-        self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled
+        self._sso_enabled = (
+            hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
+        )

         # we keep this as a list despite the O(N^2) implication so that we can
         # keep PASSWORD first and avoid confusing clients which pick the first


@@ -311,7 +311,7 @@ class OidcHandler:
         ``ClientAuth`` to authenticate with the client with its ID and secret.

         Args:
-            code: The autorization code we got from the callback.
+            code: The authorization code we got from the callback.

         Returns:
             A dict containing various tokens.
@@ -497,11 +497,14 @@ class OidcHandler:
         return UserInfo(claims)

     async def handle_redirect_request(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> None:
+        self,
+        request: SynapseRequest,
+        client_redirect_url: bytes,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
         """Handle an incoming request to /login/sso/redirect

-        It redirects the browser to the authorization endpoint with a few
+        It returns a redirect to the authorization endpoint with a few
         parameters:

           - ``client_id``: the client ID set in ``oidc_config.client_id``
@@ -511,24 +514,32 @@ class OidcHandler:
           - ``state``: a random string
           - ``nonce``: a random string

-        In addition to redirecting the client, we are setting a cookie with
+        In addition generating a redirect URL, we are setting a cookie with
         a signed macaroon token containing the state, the nonce and the
         client_redirect_url params. Those are then checked when the client
         comes back from the provider.

         Args:
             request: the incoming request from the browser.
                 We'll respond to it with a redirect and a cookie.
             client_redirect_url: the URL that we should redirect the client to
                 when everything is done
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
+
+        Returns:
+            The redirect URL to the authorization endpoint.
         """

         state = generate_token()
         nonce = generate_token()

         cookie = self._generate_oidc_session_token(
-            state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
+            state=state,
+            nonce=nonce,
+            client_redirect_url=client_redirect_url.decode(),
+            ui_auth_session_id=ui_auth_session_id,
         )
         request.addCookie(
             SESSION_COOKIE_NAME,
@@ -541,7 +552,7 @@ class OidcHandler:
         metadata = await self.load_metadata()
         authorization_endpoint = metadata.get("authorization_endpoint")
-        uri = prepare_grant_uri(
+        return prepare_grant_uri(
             authorization_endpoint,
             client_id=self._client_auth.client_id,
             response_type="code",
@@ -550,8 +561,6 @@ class OidcHandler:
             state=state,
             nonce=nonce,
         )
-        request.redirect(uri)
-        finish_request(request)

     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
@@ -625,7 +634,11 @@ class OidcHandler:

         # Deserialize the session token and verify it.
         try:
-            nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
+            (
+                nonce,
+                client_redirect_url,
+                ui_auth_session_id,
+            ) = self._verify_oidc_session_token(session, state)
         except MacaroonDeserializationException as e:
             logger.exception("Invalid session")
             self._render_error(request, "invalid_session", str(e))
@@ -678,6 +691,11 @@ class OidcHandler:
             return

         # and finally complete the login
-        await self._auth_handler.complete_sso_login(
-            user_id, request, client_redirect_url
-        )
+        if ui_auth_session_id:
+            await self._auth_handler.complete_sso_ui_auth(
+                user_id, ui_auth_session_id, request
+            )
+        else:
+            await self._auth_handler.complete_sso_login(
+                user_id, request, client_redirect_url
+            )
@@ -687,6 +705,7 @@ class OidcHandler:
         state: str,
         nonce: str,
         client_redirect_url: str,
+        ui_auth_session_id: Optional[str],
         duration_in_ms: int = (60 * 60 * 1000),
     ) -> str:
         """Generates a signed token storing data about an OIDC session.
@@ -702,6 +721,8 @@ class OidcHandler:
             nonce: The ``nonce`` parameter passed to the OIDC provider.
             client_redirect_url: The URL the client gave when it initiated the
                 flow.
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
             duration_in_ms: An optional duration for the token in milliseconds.
                 Defaults to an hour.
@@ -718,12 +739,19 @@ class OidcHandler:
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (client_redirect_url,)
         )
+        if ui_auth_session_id:
+            macaroon.add_first_party_caveat(
+                "ui_auth_session_id = %s" % (ui_auth_session_id,)
+            )
         now = self._clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))

         return macaroon.serialize()

-    def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
+    def _verify_oidc_session_token(
+        self, session: str, state: str
+    ) -> Tuple[str, str, Optional[str]]:
         """Verifies and extract an OIDC session token.

         This verifies that a given session token was issued by this homeserver
@@ -734,7 +762,7 @@ class OidcHandler:
             state: The state the OIDC provider gave back

         Returns:
-            The nonce and the client_redirect_url for this session
+            The nonce, client_redirect_url, and ui_auth_session_id for this session
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
@@ -744,17 +772,27 @@ class OidcHandler:
         v.satisfy_exact("state = %s" % (state,))
         v.satisfy_general(lambda c: c.startswith("nonce = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
+        # to always satisfy this.
+        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
         v.satisfy_general(self._verify_expiry)

         v.verify(macaroon, self._macaroon_secret_key)

-        # Extract the `nonce` and `client_redirect_url` from the token
+        # Extract the `nonce`, `client_redirect_url`, and maybe the
+        # `ui_auth_session_id` from the token.
         nonce = self._get_value_from_macaroon(macaroon, "nonce")
         client_redirect_url = self._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
+        try:
+            ui_auth_session_id = self._get_value_from_macaroon(
+                macaroon, "ui_auth_session_id"
+            )  # type: Optional[str]
+        except ValueError:
+            ui_auth_session_id = None

-        return nonce, client_redirect_url
+        return nonce, client_redirect_url, ui_auth_session_id

     def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
         """Extracts a caveat value from a macaroon token.
@@ -773,7 +811,7 @@ class OidcHandler:
         for caveat in macaroon.caveats:
             if caveat.caveat_id.startswith(prefix):
                 return caveat.caveat_id[len(prefix) :]
-        raise Exception("No %s caveat in macaroon" % (key,))
+        raise ValueError("No %s caveat in macaroon" % (key,))

     def _verify_expiry(self, caveat: str) -> bool:
         prefix = "time < "
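The caveat handling above — required `nonce` and `client_redirect_url` values plus an optional `ui_auth_session_id` recovered via the new `ValueError` — can be sketched over a plain list of `"key = value"` caveat strings, leaving pymacaroons itself out of the picture (helper names here are illustrative):

```python
from typing import List, Optional, Tuple


def get_caveat_value(caveat_ids: List[str], key: str) -> str:
    """Find the value of a "key = value" caveat, raising ValueError when
    absent (matching the change from a bare Exception in the diff)."""
    prefix = key + " = "
    for caveat_id in caveat_ids:
        if caveat_id.startswith(prefix):
            return caveat_id[len(prefix):]
    raise ValueError("No %s caveat in macaroon" % (key,))


def extract_session(caveat_ids: List[str]) -> Tuple[str, str, Optional[str]]:
    """Pull the session fields out of the verified caveats."""
    nonce = get_caveat_value(caveat_ids, "nonce")
    client_redirect_url = get_caveat_value(caveat_ids, "client_redirect_url")
    try:
        ui_auth_session_id = get_caveat_value(
            caveat_ids, "ui_auth_session_id"
        )  # type: Optional[str]
    except ValueError:
        # Plain login rather than UI auth: the caveat is simply absent.
        ui_auth_session_id = None
    return nonce, client_redirect_url, ui_auth_session_id
```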


@@ -204,6 +204,7 @@ class PresenceHandler(BasePresenceHandler):
         self.notifier = hs.get_notifier()
         self.federation = hs.get_federation_sender()
         self.state = hs.get_state_handler()
+        self._presence_enabled = hs.config.use_presence

         federation_registry = hs.get_federation_registry()
@@ -676,13 +677,14 @@ class PresenceHandler(BasePresenceHandler):
     async def incoming_presence(self, origin, content):
         """Called when we receive a `m.presence` EDU from a remote server.
         """
+        if not self._presence_enabled:
+            return
+
         now = self.clock.time_msec()
         updates = []
         for push in content.get("push", []):
             # A "push" contains a list of presence that we are probably interested
             # in.
-            # TODO: Actually check if we're interested, rather than blindly
-            #   accepting presence updates.
             user_id = push.get("user_id", None)
             if not user_id:
                 logger.info(


@@ -132,7 +132,7 @@ class RegistrationHandler(BaseHandler):
     def register_user(
         self,
         localpart=None,
-        password=None,
+        password_hash=None,
         guest_access_token=None,
         make_guest=False,
         admin=False,
@@ -147,7 +147,7 @@ class RegistrationHandler(BaseHandler):
         Args:
             localpart: The local part of the user ID to register. If None,
                 one will be generated.
-            password (unicode): The password to assign to this user so they can
+            password_hash (str|None): The hashed password to assign to this user so they can
                 login again. This can be None which means they cannot login again
                 via a password (e.g. the user is an application service user).
             user_type (str|None): type of user. One of the values from
@@ -164,11 +164,6 @@ class RegistrationHandler(BaseHandler):
         yield self.check_registration_ratelimit(address)
         yield self.auth.check_auth_blocking(threepid=threepid)

-        password_hash = None
-        if password:
-            password_hash = yield defer.ensureDeferred(
-                self._auth_handler.hash(password)
-            )

         if localpart is not None:
             yield self.check_username(localpart, guest_access_token=guest_access_token)
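The hunks above move password hashing out of `register_user` and into its callers, so the plaintext password is hashed as early as possible during registration (the rc3 bugfix, #7523) and `register_user` only ever sees a hash. A rough sketch of the caller-side pattern — using stdlib `pbkdf2_hmac` as a stand-in for Synapse's real bcrypt-based `AuthHandler.hash`, with made-up function names:

```python
import hashlib
import os


def hash_password(password: str) -> str:
    # Hypothetical stand-in for AuthHandler.hash(); Synapse actually uses bcrypt.
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, 100_000)
    return salt.hex() + "$" + digest.hex()


def register_user(localpart=None, password_hash=None):
    # After the change, the registration path receives only the hash,
    # never the plaintext password.
    return {"localpart": localpart, "password_hash": password_hash}


user = register_user(localpart="alice", password_hash=hash_password("hunter2"))
print(user["localpart"])
```

The caller hashes once up front, so the plaintext does not travel through the registration code paths at all.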

View File

@@ -17,15 +17,16 @@
 import abc
 import logging
+from typing import Dict, Iterable, List, Optional, Tuple, Union

 from six.moves import http_client

-from twisted.internet import defer
-
 from synapse import types
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.types import Collection, RoomID, UserID
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room, user_left_room
@@ -76,84 +77,84 @@ class RoomMemberHandler(object):
         self.base_handler = BaseHandler(hs)

     @abc.abstractmethod
-    def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> Optional[dict]:
         """Try and join a room that this server is not in

         Args:
-            requester (Requester)
-            remote_room_hosts (list[str]): List of servers that can be used
-                to join via.
-            room_id (str): Room that we are trying to join
-            user (UserID): User who is trying to join
-            content (dict): A dict that should be used as the content of the
-                join event.
-
-        Returns:
-            Deferred
+            requester
+            remote_room_hosts: List of servers that can be used to join via.
+            room_id: Room that we are trying to join
+            user: User who is trying to join
+            content: A dict that should be used as the content of the join event.
         """
         raise NotImplementedError()

     @abc.abstractmethod
-    def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+    async def _remote_reject_invite(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Attempt to reject an invite for a room this server is not in. If we
         fail to do so we locally mark the invite as rejected.

         Args:
-            requester (Requester)
-            remote_room_hosts (list[str]): List of servers to use to try and
-                reject invite
-            room_id (str)
-            target (UserID): The user rejecting the invite
-            content (dict): The content for the rejection event
+            requester
+            remote_room_hosts: List of servers to use to try and reject invite
+            room_id
+            target: The user rejecting the invite
+            content: The content for the rejection event

         Returns:
-            Deferred[dict]: A dictionary to be returned to the client, may
+            A dictionary to be returned to the client, may
                 include event_id etc, or nothing if we locally rejected
         """
         raise NotImplementedError()

     @abc.abstractmethod
-    def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has joined the
         room.

         Args:
-            target (UserID)
-            room_id (str)
-
-        Returns:
-            Deferred|None
+            target
+            room_id
         """
         raise NotImplementedError()

     @abc.abstractmethod
-    def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has left the
         room.

         Args:
-            target (UserID)
-            room_id (str)
-
-        Returns:
-            Deferred|None
+            target
+            room_id
         """
         raise NotImplementedError()

     async def _local_membership_update(
         self,
-        requester,
-        target,
-        room_id,
-        membership,
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        membership: str,
         prev_event_ids: Collection[str],
-        txn_id=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        txn_id: Optional[str] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> EventBase:
         user_id = target.to_string()

         if content is None:
@@ -214,20 +215,18 @@ class RoomMemberHandler(object):
         return event

-    @defer.inlineCallbacks
-    def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id):
+    async def copy_room_tags_and_direct_to_room(
+        self, old_room_id, new_room_id, user_id
+    ) -> None:
         """Copies the tags and direct room state from one room to another.

         Args:
-            old_room_id (str)
-            new_room_id (str)
-            user_id (str)
-
-        Returns:
-            Deferred[None]
+            old_room_id: The room ID of the old room.
+            new_room_id: The room ID of the new room.
+            user_id: The user's ID.
         """
         # Retrieve user account data for predecessor room
-        user_account_data, _ = yield self.store.get_account_data_for_user(user_id)
+        user_account_data, _ = await self.store.get_account_data_for_user(user_id)

         # Copy direct message state if applicable
         direct_rooms = user_account_data.get("m.direct", {})
@@ -240,31 +239,31 @@ class RoomMemberHandler(object):
                     direct_rooms[key].append(new_room_id)

                     # Save back to user's m.direct account data
-                    yield self.store.add_account_data_for_user(
+                    await self.store.add_account_data_for_user(
                         user_id, "m.direct", direct_rooms
                     )
                     break

         # Copy room tags if applicable
-        room_tags = yield self.store.get_tags_for_room(user_id, old_room_id)
+        room_tags = await self.store.get_tags_for_room(user_id, old_room_id)

         # Copy each room tag to the new room
         for tag, tag_content in room_tags.items():
-            yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+            await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)

     async def update_membership(
         self,
-        requester,
-        target,
-        room_id,
-        action,
-        txn_id=None,
-        remote_room_hosts=None,
-        third_party_signed=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        action: str,
+        txn_id: Optional[str] = None,
+        remote_room_hosts: Optional[List[str]] = None,
+        third_party_signed: Optional[dict] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> Union[EventBase, Optional[dict]]:
         key = (room_id,)

         with (await self.member_linearizer.queue(key)):
@@ -285,17 +284,17 @@ class RoomMemberHandler(object):
     async def _update_membership(
         self,
-        requester,
-        target,
-        room_id,
-        action,
-        txn_id=None,
-        remote_room_hosts=None,
-        third_party_signed=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        action: str,
+        txn_id: Optional[str] = None,
+        remote_room_hosts: Optional[List[str]] = None,
+        third_party_signed: Optional[dict] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> Union[EventBase, Optional[dict]]:
         content_specified = bool(content)
         if content is None:
             content = {}
@@ -469,12 +468,11 @@ class RoomMemberHandler(object):
             else:
                 # send the rejection to the inviter's HS.
                 remote_room_hosts = remote_room_hosts + [inviter.domain]
-                res = await self._remote_reject_invite(
+                return await self._remote_reject_invite(
                     requester, remote_room_hosts, room_id, target, content,
                 )
-                return res

-        res = await self._local_membership_update(
+        return await self._local_membership_update(
             requester=requester,
             target=target,
             room_id=room_id,
@@ -485,10 +483,10 @@ class RoomMemberHandler(object):
             content=content,
             require_consent=require_consent,
         )
-        return res

-    @defer.inlineCallbacks
-    def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+    async def transfer_room_state_on_room_upgrade(
+        self, old_room_id: str, room_id: str
+    ) -> None:
         """Upon our server becoming aware of an upgraded room, either by upgrading a room
         ourselves or joining one, we can transfer over information from the previous room.
@@ -496,50 +494,44 @@ class RoomMemberHandler(object):
         well as migrating the room directory state.

         Args:
-            old_room_id (str): The ID of the old room
-
-            room_id (str): The ID of the new room
-
-        Returns:
-            Deferred
+            old_room_id: The ID of the old room
+            room_id: The ID of the new room
         """
         logger.info("Transferring room state from %s to %s", old_room_id, room_id)

         # Find all local users that were in the old room and copy over each user's state
-        users = yield self.store.get_users_in_room(old_room_id)
-        yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
+        users = await self.store.get_users_in_room(old_room_id)
+        await self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)

         # Add new room to the room directory if the old room was there
         # Remove old room from the room directory
-        old_room = yield self.store.get_room(old_room_id)
+        old_room = await self.store.get_room(old_room_id)
         if old_room and old_room["is_public"]:
-            yield self.store.set_room_is_public(old_room_id, False)
-            yield self.store.set_room_is_public(room_id, True)
+            await self.store.set_room_is_public(old_room_id, False)
+            await self.store.set_room_is_public(room_id, True)

         # Transfer alias mappings in the room directory
-        yield self.store.update_aliases_for_room(old_room_id, room_id)
+        await self.store.update_aliases_for_room(old_room_id, room_id)

         # Check if any groups we own contain the predecessor room
-        local_group_ids = yield self.store.get_local_groups_for_room(old_room_id)
+        local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
         for group_id in local_group_ids:
             # Add new the new room to those groups
-            yield self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
+            await self.store.add_room_to_group(group_id, room_id, old_room["is_public"])

             # Remove the old room from those groups
-            yield self.store.remove_room_from_group(group_id, old_room_id)
+            await self.store.remove_room_from_group(group_id, old_room_id)

-    @defer.inlineCallbacks
-    def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+    async def copy_user_state_on_room_upgrade(
+        self, old_room_id: str, new_room_id: str, user_ids: Iterable[str]
+    ) -> None:
         """Copy user-specific information when they join a new room when that new room is the
         result of a room upgrade

         Args:
-            old_room_id (str): The ID of upgraded room
-            new_room_id (str): The ID of the new room
-            user_ids (Iterable[str]): User IDs to copy state for
-
-        Returns:
-            Deferred
+            old_room_id: The ID of upgraded room
+            new_room_id: The ID of the new room
+            user_ids: User IDs to copy state for
         """

         logger.debug(
@@ -552,11 +544,11 @@ class RoomMemberHandler(object):
         for user_id in user_ids:
             try:
                 # It is an upgraded room. Copy over old tags
-                yield self.copy_room_tags_and_direct_to_room(
+                await self.copy_room_tags_and_direct_to_room(
                     old_room_id, new_room_id, user_id
                 )
                 # Copy over push rules
-                yield self.store.copy_push_rules_from_room_to_room_for_user(
+                await self.store.copy_push_rules_from_room_to_room_for_user(
                     old_room_id, new_room_id, user_id
                 )
             except Exception:
@@ -569,17 +561,23 @@ class RoomMemberHandler(object):
                 )
                 continue

-    async def send_membership_event(self, requester, event, context, ratelimit=True):
+    async def send_membership_event(
+        self,
+        requester: Requester,
+        event: EventBase,
+        context: EventContext,
+        ratelimit: bool = True,
+    ):
         """
         Change the membership status of a user in a room.

         Args:
-            requester (Requester): The local user who requested the membership
+            requester: The local user who requested the membership
                 event. If None, certain checks, like whether this homeserver can
                 act as the sender, will be skipped.
-            event (SynapseEvent): The membership event.
+            event: The membership event.
             context: The context of the event.
-            ratelimit (bool): Whether to rate limit this request.
+            ratelimit: Whether to rate limit this request.
         Raises:
             SynapseError if there was a problem changing the membership.
         """
@@ -639,8 +637,9 @@ class RoomMemberHandler(object):
             if prev_member_event.membership == Membership.JOIN:
                 await self._user_left_room(target_user, room_id)

-    @defer.inlineCallbacks
-    def _can_guest_join(self, current_state_ids):
+    async def _can_guest_join(
+        self, current_state_ids: Dict[Tuple[str, str], str]
+    ) -> bool:
         """
         Returns whether a guest can join a room based on its current state.
         """
@@ -648,7 +647,7 @@ class RoomMemberHandler(object):
         if not guest_access_id:
             return False

-        guest_access = yield self.store.get_event(guest_access_id)
+        guest_access = await self.store.get_event(guest_access_id)

         return (
             guest_access
@@ -657,13 +656,14 @@ class RoomMemberHandler(object):
             and guest_access.content["guest_access"] == "can_join"
         )

-    @defer.inlineCallbacks
-    def lookup_room_alias(self, room_alias):
+    async def lookup_room_alias(
+        self, room_alias: RoomAlias
+    ) -> Tuple[RoomID, List[str]]:
         """
         Get the room ID associated with a room alias.

         Args:
-            room_alias (RoomAlias): The alias to look up.
+            room_alias: The alias to look up.
         Returns:
             A tuple of:
                 The room ID as a RoomID object.
@@ -672,7 +672,7 @@ class RoomMemberHandler(object):
             SynapseError if room alias could not be found.
         """
         directory_handler = self.directory_handler
-        mapping = yield directory_handler.get_association(room_alias)
+        mapping = await directory_handler.get_association(room_alias)

         if not mapping:
             raise SynapseError(404, "No such room alias")
@@ -687,25 +687,25 @@ class RoomMemberHandler(object):
         return RoomID.from_string(room_id), servers

-    @defer.inlineCallbacks
-    def _get_inviter(self, user_id, room_id):
-        invite = yield self.store.get_invite_for_local_user_in_room(
+    async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]:
+        invite = await self.store.get_invite_for_local_user_in_room(
             user_id=user_id, room_id=room_id
         )
         if invite:
             return UserID.from_string(invite.sender)
+        return None

     async def do_3pid_invite(
         self,
-        room_id,
-        inviter,
-        medium,
-        address,
-        id_server,
-        requester,
-        txn_id,
-        id_access_token=None,
-    ):
+        room_id: str,
+        inviter: UserID,
+        medium: str,
+        address: str,
+        id_server: str,
+        requester: Requester,
+        txn_id: Optional[str],
+        id_access_token: Optional[str] = None,
+    ) -> None:
         if self.config.block_non_admin_invites:
             is_requester_admin = await self.auth.is_server_admin(requester.user)
             if not is_requester_admin:
@@ -754,15 +754,15 @@ class RoomMemberHandler(object):
     async def _make_and_store_3pid_invite(
         self,
-        requester,
-        id_server,
-        medium,
-        address,
-        room_id,
-        user,
-        txn_id,
-        id_access_token=None,
-    ):
+        requester: Requester,
+        id_server: str,
+        medium: str,
+        address: str,
+        room_id: str,
+        user: UserID,
+        txn_id: Optional[str],
+        id_access_token: Optional[str] = None,
+    ) -> None:
         room_state = await self.state_handler.get_current_state(room_id)

         inviter_display_name = ""
@@ -836,8 +836,9 @@ class RoomMemberHandler(object):
             txn_id=txn_id,
         )

-    @defer.inlineCallbacks
-    def _is_host_in_room(self, current_state_ids):
+    async def _is_host_in_room(
+        self, current_state_ids: Dict[Tuple[str, str], str]
+    ) -> bool:
         # Have we just created the room, and is this about to be the very
         # first member event?
         create_event_id = current_state_ids.get(("m.room.create", ""))
@@ -850,7 +851,7 @@ class RoomMemberHandler(object):
                 continue

             event_id = current_state_ids[(etype, state_key)]
-            event = yield self.store.get_event(event_id, allow_none=True)
+            event = await self.store.get_event(event_id, allow_none=True)
             if not event:
                 continue
@@ -859,11 +860,10 @@ class RoomMemberHandler(object):

         return False

-    @defer.inlineCallbacks
-    def _is_server_notice_room(self, room_id):
+    async def _is_server_notice_room(self, room_id: str) -> bool:
         if self._server_notices_mxid is None:
             return False
-        user_ids = yield self.store.get_users_in_room(room_id)
+        user_ids = await self.store.get_users_in_room(room_id)
         return self._server_notices_mxid in user_ids
@@ -875,13 +875,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         self.distributor.declare("user_joined_room")
         self.distributor.declare("user_left_room")

-    async def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
+    async def _is_remote_room_too_complex(
+        self, room_id: str, remote_room_hosts: List[str]
+    ) -> Optional[bool]:
         """
         Check if complexity of a remote room is too great.

         Args:
-            room_id (str)
-            remote_room_hosts (list[str])
+            room_id
+            remote_room_hosts

         Returns: bool of whether the complexity is too great, or None
             if unable to be fetched
@@ -895,22 +897,26 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             return complexity["v1"] > max_complexity
         return None

-    @defer.inlineCallbacks
-    def _is_local_room_too_complex(self, room_id):
+    async def _is_local_room_too_complex(self, room_id: str) -> bool:
         """
         Check if the complexity of a local room is too great.

         Args:
-            room_id (str)
-
-        Returns: bool
+            room_id: The room ID to check for complexity.
         """
         max_complexity = self.hs.config.limit_remote_rooms.complexity
-        complexity = yield self.store.get_room_complexity(room_id)
+        complexity = await self.store.get_room_complexity(room_id)

         return complexity["v1"] > max_complexity

-    async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> None:
         """Implements RoomMemberHandler._remote_join
         """
         # filter ourselves out of remote_room_hosts: do_invite_join ignores it
@@ -969,19 +975,21 @@ class RoomMemberMasterHandler(RoomMemberHandler):
                 errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
             )

-    @defer.inlineCallbacks
-    def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+    async def _remote_reject_invite(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Implements RoomMemberHandler._remote_reject_invite
         """
         fed_handler = self.federation_handler
         try:
-            ret = yield defer.ensureDeferred(
-                fed_handler.do_remotely_reject_invite(
-                    remote_room_hosts, room_id, target.to_string(), content=content,
-                )
+            ret = await fed_handler.do_remotely_reject_invite(
+                remote_room_hosts, room_id, target.to_string(), content=content,
             )
             return ret
         except Exception as e:
             # if we were unable to reject the exception, just mark
@@ -992,24 +1000,23 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             #
             logger.warning("Failed to reject invite: %s", e)

-            yield self.store.locally_reject_invite(target.to_string(), room_id)
+            await self.store.locally_reject_invite(target.to_string(), room_id)
             return {}

-    def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_joined_room
         """
-        return defer.succeed(user_joined_room(self.distributor, target, room_id))
+        user_joined_room(self.distributor, target, room_id)

-    def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
-        return defer.succeed(user_left_room(self.distributor, target, room_id))
+        user_left_room(self.distributor, target, room_id)

-    @defer.inlineCallbacks
-    def forget(self, user, room_id):
+    async def forget(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()

-        member = yield self.state_handler.get_current_state(
+        member = await self.state_handler.get_current_state(
             room_id=room_id, event_type=EventTypes.Member, state_key=user_id
         )
         membership = member.membership if member else None
@@ -1021,4 +1028,4 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             raise SynapseError(400, "User %s in room %s" % (user_id, room_id))

         if membership:
-            yield self.store.forget(user_id, room_id)
+            await self.store.forget(user_id, room_id)
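The room-member hunks all apply the same mechanical conversion: drop `@defer.inlineCallbacks`, mark the method `async def`, replace each `yield` with `await`, add type annotations, and make returns explicit. A minimal sketch of the pattern, using `asyncio` instead of Twisted purely for illustration; `FakeStore` and its method are hypothetical stand-ins for Synapse's storage layer:

```python
import asyncio
from typing import Optional


class FakeStore:
    # Hypothetical stand-in for Synapse's storage layer.
    async def get_invite_for_local_user_in_room(
        self, user_id: str, room_id: str
    ) -> Optional[str]:
        return "@inviter:example.org" if room_id == "!room:example.org" else None


class Handler:
    def __init__(self) -> None:
        self.store = FakeStore()

    # After the conversion: a plain coroutine with annotations, no
    # @defer.inlineCallbacks decorator, and an explicit `return None`
    # to satisfy the Optional[...] return type.
    async def _get_inviter(self, user_id: str, room_id: str) -> Optional[str]:
        invite = await self.store.get_invite_for_local_user_in_room(
            user_id=user_id, room_id=room_id
        )
        if invite:
            return invite
        return None


print(asyncio.run(Handler()._get_inviter("@u:example.org", "!room:example.org")))
```

Callers change symmetrically: anything that previously received a `Deferred` from the decorated generator now awaits the coroutine directly, which is why `defer.ensureDeferred` wrappers disappear in the same hunks.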

View File

@@ -14,8 +14,7 @@
 # limitations under the License.

 import logging
+from typing import List, Optional

-from twisted.internet import defer
-
 from synapse.api.errors import SynapseError
 from synapse.handlers.room_member import RoomMemberHandler
@@ -24,6 +23,7 @@ from synapse.replication.http.membership import (
     ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
     ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
 )
+from synapse.types import Requester, UserID

 logger = logging.getLogger(__name__)
@@ -36,14 +36,20 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         self._remote_reject_client = ReplRejectInvite.make_client(hs)
         self._notify_change_client = ReplJoinedLeft.make_client(hs)

-    @defer.inlineCallbacks
-    def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> Optional[dict]:
         """Implements RoomMemberHandler._remote_join
         """
         if len(remote_room_hosts) == 0:
             raise SynapseError(404, "No known servers")

-        ret = yield self._remote_join_client(
+        ret = await self._remote_join_client(
             requester=requester,
             remote_room_hosts=remote_room_hosts,
             room_id=room_id,
@@ -51,16 +57,21 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
             content=content,
         )

-        yield self._user_joined_room(user, room_id)
+        await self._user_joined_room(user, room_id)

         return ret

-    def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+    async def _remote_reject_invite(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Implements RoomMemberHandler._remote_reject_invite
         """
-        return self._remote_reject_client(
+        return await self._remote_reject_client(
             requester=requester,
             remote_room_hosts=remote_room_hosts,
             room_id=room_id,
@@ -68,16 +79,16 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
             content=content,
         )

-    def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_joined_room
         """
-        return self._notify_change_client(
+        await self._notify_change_client(
             user_id=target.to_string(), room_id=room_id, change="joined"
         )

-    def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
-        return self._notify_change_client(
+        await self._notify_change_client(
             user_id=target.to_string(), room_id=room_id, change="left"
         )

View File

@@ -19,7 +19,7 @@ import random
 import sys
 from io import BytesIO

-from six import PY3, raise_from, string_types
+from six import raise_from, string_types
 from six.moves import urllib

 import attr
@@ -70,11 +70,7 @@ incoming_responses_counter = Counter(
 MAX_LONG_RETRIES = 10
 MAX_SHORT_RETRIES = 3

-if PY3:
-    MAXINT = sys.maxsize
-else:
-    MAXINT = sys.maxint
-
+MAXINT = sys.maxsize

 _next_id = 1

View File

@@ -20,8 +20,6 @@ import time
 from functools import wraps
 from inspect import getcallargs

-from six import PY3
-
 _TIME_FUNC_ID = 0
@@ -30,12 +28,8 @@ def _log_debug_as_f(f, msg, msg_args):
     logger = logging.getLogger(name)

     if logger.isEnabledFor(logging.DEBUG):
-        if PY3:
-            lineno = f.__code__.co_firstlineno
-            pathname = f.__code__.co_filename
-        else:
-            lineno = f.func_code.co_firstlineno
-            pathname = f.func_code.co_filename
+        lineno = f.__code__.co_firstlineno
+        pathname = f.__code__.co_filename

         record = logging.LogRecord(
             name=name,

View File

@@ -15,8 +15,6 @@
 # limitations under the License.
 import logging

-import six
-
 from prometheus_client import Counter

 from twisted.internet import defer
@@ -28,9 +26,6 @@ from synapse.push import PusherConfigException

 from . import push_rule_evaluator, push_tools

-if six.PY3:
-    long = int
-
 logger = logging.getLogger(__name__)

 http_push_processed_counter = Counter(
@@ -318,7 +313,7 @@ class HttpPusher(object):
                 {
                     "app_id": self.app_id,
                     "pushkey": self.pushkey,
-                    "pushkey_ts": long(self.pushkey_ts / 1000),
+                    "pushkey_ts": int(self.pushkey_ts / 1000),
                     "data": self.data_minus_url,
                 }
             ],
@@ -347,7 +342,7 @@ class HttpPusher(object):
                 {
                     "app_id": self.app_id,
                     "pushkey": self.pushkey,
-                    "pushkey_ts": long(self.pushkey_ts / 1000),
+                    "pushkey_ts": int(self.pushkey_ts / 1000),
                     "data": self.data_minus_url,
                     "tweaks": tweaks,
                 }
@@ -409,7 +404,7 @@ class HttpPusher(object):
                 {
                     "app_id": self.app_id,
                     "pushkey": self.pushkey,
-                    "pushkey_ts": long(self.pushkey_ts / 1000),
+                    "pushkey_ts": int(self.pushkey_ts / 1000),
                     "data": self.data_minus_url,
                 }
             ],
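The pusher hunks above are part of dropping Python 2 support: the `long = int` alias guarded by `six.PY3` becomes unnecessary because Python 3's `int` is arbitrary precision, so `long(self.pushkey_ts / 1000)` can simply become `int(...)`. A quick illustration (the timestamp value is made up):

```python
# Python 3's int covers the old Python 2 long range natively.
pushkey_ts = 1589801904123  # hypothetical millisecond timestamp

# Millisecond -> second truncation, exactly what the pusher payload needs.
seconds = int(pushkey_ts / 1000)
print(seconds)  # 1589801904

# Arbitrary precision: no overflow even far beyond 2**63,
# which is where Python 2's plain int would have rolled over to long.
big = 2 ** 200
print(big > 2 ** 63)  # True
```

The same reasoning drives the `MAXINT = sys.maxsize` and `__code__`/`func_code` cleanups elsewhere in this merge: each `if PY3:` branch collapses to its Python 3 arm.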

View File

@@ -34,8 +34,11 @@ class ReplicationRestResource(JsonResource):
     def register_servlets(self, hs):
         send_event.register_servlets(hs, self)
-        membership.register_servlets(hs, self)
         federation.register_servlets(hs, self)
-        login.register_servlets(hs, self)
-        register.register_servlets(hs, self)
-        devices.register_servlets(hs, self)
+
+        # The following can't currently be instantiated on workers.
+        if hs.config.worker.worker_app is None:
+            membership.register_servlets(hs, self)
+            login.register_servlets(hs, self)
+            register.register_servlets(hs, self)
+            devices.register_servlets(hs, self)


@@ -16,8 +16,6 @@
 import logging
 from typing import Optional

-import six
-
 from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
@@ -26,13 +24,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
 logger = logging.getLogger(__name__)


-def __func__(inp):
-    if six.PY3:
-        return inp
-    else:
-        return inp.__func__
-
-
 class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
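The removed `__func__` helper existed because looking a method up on a class returned an unbound-method wrapper on Python 2; on Python 3 it is just a plain function, so it can be re-assigned to another class directly. A minimal sketch (class names are illustrative, not Synapse's):

```python
# On Python 3, a method looked up on the class is a plain function, so it
# can be borrowed by another class without unwrapping anything.
class DataStoreExample:
    def greet(self):
        return "hello from %s" % type(self).__name__


class SlavedStoreExample:
    # Direct assignment works on Python 3; on Python 2 this needed
    # DataStoreExample.greet.__func__ to strip the unbound-method wrapper.
    greet = DataStoreExample.greet


print(SlavedStoreExample().greet())
```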


@@ -15,11 +15,6 @@
 # limitations under the License.
 import logging

-from synapse.api.constants import EventTypes
-from synapse.replication.tcp.streams.events import (
-    EventsStreamCurrentStateRow,
-    EventsStreamEventRow,
-)
 from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
 from synapse.storage.data_stores.main.event_push_actions import (
     EventPushActionsWorkerStore,
@@ -35,7 +30,6 @@ from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache

 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker

 logger = logging.getLogger(__name__)
@@ -62,11 +56,6 @@ class SlavedEventStore(
     BaseSlavedStore,
 ):
     def __init__(self, database: Database, db_conn, hs):
-        self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
-        self._backfill_id_gen = SlavedIdTracker(
-            db_conn, "events", "stream_ordering", step=-1
-        )
-
         super(SlavedEventStore, self).__init__(database, db_conn, hs)

         events_max = self._stream_id_gen.get_current_token()
@@ -92,81 +81,3 @@ class SlavedEventStore(

     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "events":
-            self._stream_id_gen.advance(token)
-            for row in rows:
-                self._process_event_stream_row(token, row)
-        elif stream_name == "backfill":
-            self._backfill_id_gen.advance(-token)
-            for row in rows:
-                self.invalidate_caches_for_event(
-                    -token,
-                    row.event_id,
-                    row.room_id,
-                    row.type,
-                    row.state_key,
-                    row.redacts,
-                    row.relates_to,
-                    backfilled=True,
-                )
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
-
-    def _process_event_stream_row(self, token, row):
-        data = row.data
-        if row.type == EventsStreamEventRow.TypeId:
-            self.invalidate_caches_for_event(
-                token,
-                data.event_id,
-                data.room_id,
-                data.type,
-                data.state_key,
-                data.redacts,
-                data.relates_to,
-                backfilled=False,
-            )
-        elif row.type == EventsStreamCurrentStateRow.TypeId:
-            self._curr_state_delta_stream_cache.entity_has_changed(
-                row.data.room_id, token
-            )
-
-            if data.type == EventTypes.Member:
-                self.get_rooms_for_user_with_stream_ordering.invalidate(
-                    (data.state_key,)
-                )
-        else:
-            raise Exception("Unknown events stream row type %s" % (row.type,))
-
-    def invalidate_caches_for_event(
-        self,
-        stream_ordering,
-        event_id,
-        room_id,
-        etype,
-        state_key,
-        redacts,
-        relates_to,
-        backfilled,
-    ):
-        self._invalidate_get_event_cache(event_id)
-
-        self.get_latest_event_ids_in_room.invalidate((room_id,))
-
-        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
-
-        if not backfilled:
-            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
-
-        if redacts:
-            self._invalidate_get_event_cache(redacts)
-
-        if etype == EventTypes.Member:
-            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
-            self.get_invited_rooms_for_local_user.invalidate((state_key,))
-
-        if relates_to:
-            self.get_relations_for_event.invalidate_many((relates_to,))
-            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
-            self.get_applicable_edit.invalidate((relates_to,))


@@ -18,7 +18,7 @@ from synapse.storage.data_stores.main.presence import PresenceStore
 from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache

-from ._base import BaseSlavedStore, __func__
+from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
@@ -27,14 +27,14 @@ class SlavedPresenceStore(BaseSlavedStore):
         super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
         self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")

-        self._presence_on_startup = self._get_active_presence(db_conn)
+        self._presence_on_startup = self._get_active_presence(db_conn)  # type: ignore

         self.presence_stream_cache = StreamChangeCache(
             "PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
         )

-    _get_active_presence = __func__(DataStore._get_active_presence)
-    take_presence_startup_info = __func__(DataStore.take_presence_startup_info)
+    _get_active_presence = DataStore._get_active_presence
+    take_presence_startup_info = DataStore.take_presence_startup_info
     _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
     get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]


@@ -15,19 +15,11 @@
 # limitations under the License.

 from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
-from synapse.storage.database import Database

-from ._slaved_id_tracker import SlavedIdTracker
 from .events import SlavedEventStore


 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def __init__(self, database: Database, db_conn, hs):
-        self._push_rules_stream_id_gen = SlavedIdTracker(
-            db_conn, "push_rules_stream", "stream_id"
-        )
-        super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
-
     def get_push_rules_stream_token(self):
         return (
             self._push_rules_stream_id_gen.get_current_token(),


@@ -14,14 +14,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import heapq
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)

 import attr

 from synapse.replication.http.streams import ReplicationGetStreamUpdates

+if TYPE_CHECKING:
+    import synapse.server
+
 logger = logging.getLogger(__name__)

 # the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
 # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
 # just a row from a database query, though this is dependent on the stream in question.
 #
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)

 # The type returned by the update_function of a stream, as well as get_updates(),
 # get_updates_since, etc.
@@ -533,32 +546,63 @@ class AccountDataStream(Stream):
     """

     AccountDataStreamRow = namedtuple(
-        "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
+        "AccountDataStream",
+        ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
     )

     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow

-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(self.store.get_max_account_data_stream_id),
-            db_query_to_update_function(self._update_function),
+            self._update_function,
         )

-    async def _update_function(self, from_token, to_token, limit):
-        global_results, room_results = await self.store.get_all_updated_account_data(
-            from_token, from_token, to_token, limit
+    async def _update_function(
+        self, instance_name: str, from_token: int, to_token: int, limit: int
+    ) -> StreamUpdateResult:
+        limited = False
+        global_results = await self.store.get_updated_global_account_data(
+            from_token, to_token, limit
         )

-        results = list(room_results)
-        results.extend(
-            (stream_id, user_id, None, account_data_type)
+        # if the global results hit the limit, we'll need to limit the room results to
+        # the same stream token.
+        if len(global_results) >= limit:
+            to_token = global_results[-1][0]
+            limited = True
+
+        room_results = await self.store.get_updated_room_account_data(
+            from_token, to_token, limit
+        )
+
+        # likewise, if the room results hit the limit, limit the global results to
+        # the same stream token.
+        if len(room_results) >= limit:
+            to_token = room_results[-1][0]
+            limited = True
+
+        # convert the global results to the right format, and limit them to the to_token
+        # at the same time
+        global_rows = (
+            (stream_id, (user_id, None, account_data_type))
             for stream_id, user_id, account_data_type in global_results
+            if stream_id <= to_token
         )

-        return results
+        # we know that the room_results are already limited to `to_token` so no need
+        # for a check on `stream_id` here.
+        room_rows = (
+            (stream_id, (user_id, room_id, account_data_type))
+            for stream_id, user_id, room_id, account_data_type in room_results
+        )
+
+        # we need to return a sorted list, so merge them together.
+        updates = list(heapq.merge(room_rows, global_rows))
+        return updates, to_token, limited


 class GroupServerStream(Stream):
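The new `_update_function` interleaves two independently-sorted result sets; since each `(stream_id, row)` tuple sorts on `stream_id` first, `heapq.merge` yields a single stream-ordered list without re-sorting. A simplified, standalone sketch of the clamp-and-merge logic (the sample rows are invented):

```python
import heapq

# Invented sample rows: (stream_id, row_data), each list already sorted by
# stream_id, mirroring the global and per-room account_data queries.
global_results = [(1, ("alice", None, "m.push_rules")), (4, ("bob", None, "m.direct"))]
room_results = [(2, ("alice", "!r:hs", "m.tag")), (3, ("bob", "!r:hs", "m.tag"))]

limit = 10
to_token = 5
limited = False

# If either query hit its limit, clamp to_token so both sides cover the
# same stream range, and flag the batch as limited.
if len(global_results) >= limit:
    to_token = global_results[-1][0]
    limited = True
if len(room_results) >= limit:
    to_token = room_results[-1][0]
    limited = True

# heapq.merge lazily interleaves the two sorted iterables by stream_id.
updates = list(heapq.merge(room_results, global_results))
print([stream_id for stream_id, _ in updates])  # -> [1, 2, 3, 4]
```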


@@ -243,11 +243,11 @@ class UserRestServletV2(RestServlet):
         else:  # create user
             password = body.get("password")
-            if password is not None and (
-                not isinstance(body["password"], text_type)
-                or len(body["password"]) > 512
-            ):
-                raise SynapseError(400, "Invalid password")
+            password_hash = None
+            if password is not None:
+                if not isinstance(password, text_type) or len(password) > 512:
+                    raise SynapseError(400, "Invalid password")
+                password_hash = await self.auth_handler.hash(password)

             admin = body.get("admin", None)
             user_type = body.get("user_type", None)
@@ -259,7 +259,7 @@ class UserRestServletV2(RestServlet):
                 user_id = await self.registration_handler.register_user(
                     localpart=target_user.localpart,
-                    password=password,
+                    password_hash=password_hash,
                     admin=bool(admin),
                     default_display_name=displayname,
                     user_type=user_type,
@@ -298,7 +298,7 @@ class UserRegisterServlet(RestServlet):
     NONCE_TIMEOUT = 60

     def __init__(self, hs):
-        self.handlers = hs.get_handlers()
+        self.auth_handler = hs.get_auth_handler()
         self.reactor = hs.get_reactor()
         self.nonces = {}
         self.hs = hs
@@ -362,16 +362,16 @@ class UserRegisterServlet(RestServlet):
                 400, "password must be specified", errcode=Codes.BAD_JSON
             )
         else:
-            if (
-                not isinstance(body["password"], text_type)
-                or len(body["password"]) > 512
-            ):
+            password = body["password"]
+            if not isinstance(password, text_type) or len(password) > 512:
                 raise SynapseError(400, "Invalid password")

-            password = body["password"].encode("utf-8")
-            if b"\x00" in password:
+            password_bytes = password.encode("utf-8")
+            if b"\x00" in password_bytes:
                 raise SynapseError(400, "Invalid password")

+            password_hash = await self.auth_handler.hash(password)
+
         admin = body.get("admin", None)
         user_type = body.get("user_type", None)
@@ -388,7 +388,7 @@ class UserRegisterServlet(RestServlet):
         want_mac_builder.update(b"\x00")
         want_mac_builder.update(username)
         want_mac_builder.update(b"\x00")
-        want_mac_builder.update(password)
+        want_mac_builder.update(password_bytes)
         want_mac_builder.update(b"\x00")
         want_mac_builder.update(b"admin" if admin else b"notadmin")
         if user_type:
@@ -407,7 +407,7 @@ class UserRegisterServlet(RestServlet):
             user_id = await register.registration_handler.register_user(
                 localpart=body["username"].lower(),
-                password=body["password"],
+                password_hash=password_hash,
                 admin=bool(admin),
                 user_type=user_type,
             )
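The `want_mac_builder` change only swaps the raw `password` for the pre-encoded `password_bytes`; the MAC itself is still an HMAC-SHA1 over NUL-separated fields keyed with the registration shared secret. A client-side sketch of building that MAC (the secret, nonce, and credentials below are invented; the real nonce comes from the shared-secret registration endpoint):

```python
import hashlib
import hmac

# Invented values for illustration only.
shared_secret = b"registration-shared-secret"
nonce = b"abcdef0123456789"
username = b"alice"
password_bytes = "s3kr1t-pa55".encode("utf-8")
admin = False

# Mirror the want_mac_builder updates from the hunk above:
# NUL-separated fields, HMAC-SHA1 keyed with the shared secret.
mac = hmac.new(shared_secret, digestmod=hashlib.sha1)
mac.update(nonce)
mac.update(b"\x00")
mac.update(username)
mac.update(b"\x00")
mac.update(password_bytes)
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
want_mac = mac.hexdigest()

print(len(want_mac))  # -> 40 (SHA-1 produces 40 hex characters)
```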


@@ -401,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet):

     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)

-    def on_GET(self, request: SynapseRequest):
+    async def on_GET(self, request: SynapseRequest):
         args = request.args
         if b"redirectUrl" not in args:
             return 400, "Redirect URL not specified for SSO auth"
         client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = self.get_sso_url(client_redirect_url)
+        sso_url = await self.get_sso_url(request, client_redirect_url)
         request.redirect(sso_url)
         finish_request(request)

-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         """Get the URL to redirect to, to perform SSO auth

         Args:
+            request: The client request to redirect.
             client_redirect_url: the URL that we should redirect the
                 client to when everything is done
@@ -428,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._cas_handler = hs.get_cas_handler()

-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._cas_handler.get_redirect_url(
             {"redirectUrl": client_redirect_url}
         ).encode("ascii")
@@ -465,11 +470,13 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._saml_handler = hs.get_saml_handler()

-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._saml_handler.handle_redirect_request(client_redirect_url)


-class OIDCRedirectServlet(RestServlet):
+class OIDCRedirectServlet(BaseSSORedirectServlet):
     """Implementation for /login/sso/redirect for the OIDC login flow."""

     PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -477,12 +484,12 @@ class OIDCRedirectServlet(RestServlet):
     def __init__(self, hs):
         self._oidc_handler = hs.get_oidc_handler()

-    async def on_GET(self, request):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
+        return await self._oidc_handler.handle_redirect_request(
+            request, client_redirect_url
+        )


 def register_servlets(hs, http_server):


@@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()

         # SSO configuration.
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
         self._cas_enabled = hs.config.cas_enabled
         if self._cas_enabled:
             self._cas_handler = hs.get_cas_handler()
             self._cas_server_url = hs.config.cas_server_url
             self._cas_service_url = hs.config.cas_service_url
+        self._saml_enabled = hs.config.saml2_enabled
+        if self._saml_enabled:
+            self._saml_handler = hs.get_saml_handler()
+        self._oidc_enabled = hs.config.oidc_enabled
+        if self._oidc_enabled:
+            self._oidc_handler = hs.get_oidc_handler()
+            self._cas_server_url = hs.config.cas_server_url
+            self._cas_service_url = hs.config.cas_service_url

     async def on_GET(self, request, stagetype):
         session = parse_string(request, "session")
@@ -172,11 +177,17 @@ class AuthRestServlet(RestServlet):
             )
         elif self._saml_enabled:
-            client_redirect_url = ""
+            client_redirect_url = b""
             sso_redirect_url = self._saml_handler.handle_redirect_request(
                 client_redirect_url, session
             )
+        elif self._oidc_enabled:
+            client_redirect_url = b""
+            sso_redirect_url = await self._oidc_handler.handle_redirect_request(
+                request, client_redirect_url, session
+            )
         else:
             raise SynapseError(400, "Homeserver not configured for SSO.")


@@ -426,12 +426,16 @@ class RegisterRestServlet(RestServlet):
         # we do basic sanity checks here because the auth layer will store these
         # in sessions. Pull out the username/password provided to us.
         if "password" in body:
-            if (
-                not isinstance(body["password"], string_types)
-                or len(body["password"]) > 512
-            ):
+            password = body.pop("password")
+            if not isinstance(password, string_types) or len(password) > 512:
                 raise SynapseError(400, "Invalid password")
-            self.password_policy_handler.validate_password(body["password"])
+            self.password_policy_handler.validate_password(password)
+
+            # If the password is valid, hash it and store it back on the request.
+            # This ensures the hashed password is handled everywhere.
+            if "password_hash" in body:
+                raise SynapseError(400, "Unexpected property: password_hash")
+            body["password_hash"] = await self.auth_handler.hash(password)

         desired_username = None
         if "username" in body:
@@ -484,7 +488,7 @@ class RegisterRestServlet(RestServlet):

         guest_access_token = body.get("guest_access_token", None)

-        if "initial_device_display_name" in body and "password" not in body:
+        if "initial_device_display_name" in body and "password_hash" not in body:
             # ignore 'initial_device_display_name' if sent without
             # a password to work around a client bug where it sent
             # the 'initial_device_display_name' param alone, wiping out
@@ -546,11 +550,11 @@ class RegisterRestServlet(RestServlet):
             registered = False
         else:
             # NB: This may be from the auth handler and NOT from the POST
-            assert_params_in_dict(params, ["password"])
+            assert_params_in_dict(params, ["password_hash"])

             desired_username = params.get("username", None)
             guest_access_token = params.get("guest_access_token", None)
-            new_password = params.get("password", None)
+            new_password_hash = params.get("password_hash", None)

             if desired_username is not None:
                 desired_username = desired_username.lower()
@@ -583,7 +587,7 @@ class RegisterRestServlet(RestServlet):

                 registered_user_id = await self.registration_handler.register_user(
                     localpart=desired_username,
-                    password=new_password,
+                    password_hash=new_password_hash,
                     guest_access_token=guest_access_token,
                     threepid=threepid,
                     address=client_addr,
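The reworked flow validates the plaintext once, hashes it immediately, and only the hash travels through the rest of registration (the release note's "Hash passwords as early as possible during registration"). A standalone sketch of that pattern, using `pbkdf2_hmac` purely as a stand-in for Synapse's real bcrypt-based `auth_handler.hash`:

```python
import hashlib
import os


def hash_password(password: str) -> str:
    """Stand-in password hasher (Synapse itself uses bcrypt)."""
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100_000)
    return salt.hex() + "$" + digest.hex()


def handle_register(body: dict) -> dict:
    # Pull the plaintext out, validate it, and replace it with a hash so
    # nothing downstream ever sees the raw password.
    if "password" in body:
        password = body.pop("password")
        if not isinstance(password, str) or len(password) > 512:
            raise ValueError("Invalid password")
        if "password_hash" in body:
            raise ValueError("Unexpected property: password_hash")
        body["password_hash"] = hash_password(password)
    return body


params = handle_register({"username": "alice", "password": "s3kr1t"})
print(sorted(params))  # -> ['password_hash', 'username']
```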


@ -17,7 +17,6 @@
import logging import logging
import os import os
from six import PY3
from six.moves import urllib from six.moves import urllib
from twisted.internet import defer from twisted.internet import defer
@ -324,7 +323,6 @@ def get_filename_from_headers(headers):
upload_name_utf8 = upload_name_utf8[7:] upload_name_utf8 = upload_name_utf8[7:]
# We have a filename*= section. This MUST be ASCII, and any UTF-8 # We have a filename*= section. This MUST be ASCII, and any UTF-8
# bytes are %-quoted. # bytes are %-quoted.
if PY3:
try: try:
# Once it is decoded, we can then unquote the %-encoded # Once it is decoded, we can then unquote the %-encoded
# parts strictly into a unicode string. # parts strictly into a unicode string.
@ -334,13 +332,6 @@ def get_filename_from_headers(headers):
except UnicodeDecodeError: except UnicodeDecodeError:
# Incorrect UTF-8. # Incorrect UTF-8.
pass pass
else:
# On Python 2, we first unquote the %-encoded parts and then
# decode it strictly using UTF-8.
try:
upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8")
except UnicodeDecodeError:
pass
# If there isn't check for an ascii name. # If there isn't check for an ascii name.
if not upload_name: if not upload_name:
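With the Python 2 branch gone, the surviving path decodes the `filename*=` value as ASCII and then strictly unquotes the %-encoded UTF-8. A minimal sketch of that decode step (the header value is invented):

```python
import urllib.parse

# Invented header value: an RFC 5987-style filename*=, where the UTF-8
# bytes of the real name are %-quoted inside an ASCII-only value.
upload_name_utf8 = b"caf%C3%A9.png"

try:
    # Decode the raw bytes as ASCII first, then strictly unquote the
    # %-encoded parts into a unicode string, as the Python-3-only path does.
    upload_name = urllib.parse.unquote(
        upload_name_utf8.decode("ascii"), errors="strict"
    )
except UnicodeDecodeError:
    upload_name = None  # incorrect UTF-8: fall back to the plain-ascii name

print(upload_name)  # -> café.png
```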


@@ -19,9 +19,6 @@ import random
 from abc import ABCMeta
 from typing import Any, Optional

-from six import PY2
-from six.moves import builtins
-
 from canonicaljson import json

 from synapse.storage.database import LoggingTransaction  # noqa: F401
@@ -103,11 +100,6 @@ def db_to_json(db_content):
     if isinstance(db_content, memoryview):
         db_content = db_content.tobytes()

-    # psycopg2 on Python 2 returns buffer objects, which we need to cast to
-    # bytes to decode
-    if PY2 and isinstance(db_content, builtins.buffer):
-        db_content = bytes(db_content)
-
     # Decode it to a Unicode string before feeding it to json.loads, so we
     # consistenty get a Unicode-containing object out.
     if isinstance(db_content, (bytes, bytearray)):
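With the Python 2 `buffer` branch removed, `db_to_json` only needs to handle `memoryview` and `bytes` before decoding. A simplified standalone sketch of the remaining logic:

```python
import json


def db_to_json(db_content):
    # psycopg2 can hand back memoryview objects; flatten them to bytes.
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()

    # Decode to a unicode string before json.loads, so we consistently
    # get a unicode-containing object out.
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode("utf8")

    return json.loads(db_content)


print(db_to_json(memoryview(b'{"display_name": "Alice"}')))
```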


@@ -24,7 +24,6 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
-    ChainedIdGenerator,
     IdGenerator,
     MultiWriterIdGenerator,
     StreamIdGenerator,
@@ -125,19 +124,6 @@ class DataStore(
         self._clock = hs.get_clock()
         self.database_engine = database.engine

-        self._stream_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            extra_tables=[("local_invites", "stream_id")],
-        )
-        self._backfill_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            step=-1,
-            extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
-        )
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
@@ -164,9 +150,6 @@ class DataStore(
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-        self._push_rules_stream_id_gen = ChainedIdGenerator(
-            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-        )
         self._pushers_id_gen = StreamIdGenerator(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )


@@ -16,6 +16,7 @@

 import abc
 import logging
+from typing import List, Tuple

 from canonicaljson import json
@@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )

-    def get_all_updated_account_data(
-        self, last_global_id, last_room_id, current_id, limit
-    ):
-        """Get all the client account_data that has changed on the server
-        Args:
-            last_global_id(int): The position to fetch from for top level data
-            last_room_id(int): The position to fetch from for per room data
-            current_id(int): The position to fetch up to.
-        Returns:
-            A deferred pair of lists of tuples of stream_id int, user_id string,
-            room_id string, and type string.
-        """
-        if last_room_id == current_id and last_global_id == current_id:
-            return defer.succeed(([], []))
+    async def get_updated_global_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream

-        def get_updated_account_data_txn(txn):
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
+        Returns:
+            A list of tuples of stream_id int, user_id string,
+            and type string.
+        """
+        if last_id == current_id:
+            return []
+
+        def get_updated_global_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_global_id, current_id, limit))
-            global_results = txn.fetchall()
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()

+        return await self.db.runInteraction(
+            "get_updated_global_account_data", get_updated_global_account_data_txn
+        )
+
+    async def get_updated_room_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
+
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
+        Returns:
+            A list of tuples of stream_id int, user_id string,
+            room_id string and type string.
+        """
+        if last_id == current_id:
+            return []
+
+        def get_updated_room_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_room_id, current_id, limit))
-            room_results = txn.fetchall()
-            return global_results, room_results
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()

-        return self.db.runInteraction(
-            "get_all_updated_account_data_txn", get_updated_account_data_txn
+        return await self.db.runInteraction(
+            "get_updated_room_account_data", get_updated_room_account_data_txn
         )

     def get_updated_account_data_for_user(self, user_id, stream_id):
View File

@ -30,12 +30,12 @@ logger = logging.getLogger(__name__)
def _make_exclusive_regex(services_cache):
    # We precompile a regex constructed from all the regexes that the AS's
    # have registered for exclusive users.
    exclusive_user_regexes = [
        regex.pattern
        for service in services_cache
        for regex in service.get_exclusive_user_regexes()
    ]
    if exclusive_user_regexes:
        exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
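The union-of-alternatives trick above can be shown in isolation; the helper and patterns here are illustrative, not Synapse's appservice types:

```python
import re

def make_exclusive_regex(patterns):
    # Join the per-service patterns into one alternation so a single
    # match() call decides whether a user id is exclusively claimed.
    if not patterns:
        return None
    return re.compile("|".join("(" + p + ")" for p in patterns))

combined = make_exclusive_regex([r"@irc_.*:example\.org", r"@gitter_.*:example\.org"])
assert combined.match("@irc_alice:example.org") is not None
assert combined.match("@bob:example.org") is None
```

Precompiling one combined pattern avoids looping over every service's regex list on each lookup.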
@ -16,8 +16,13 @@
import itertools
import logging
from typing import Any, Iterable, Optional, Tuple

from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams.events import (
    EventsStreamCurrentStateRow,
    EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
@ -66,7 +71,22 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)

def process_replication_rows(self, stream_name, instance_name, token, rows):
    if stream_name == "events":
for row in rows:
self._process_event_stream_row(token, row)
elif stream_name == "backfill":
for row in rows:
self._invalidate_caches_for_event(
-token,
row.event_id,
row.room_id,
row.type,
row.state_key,
row.redacts,
row.relates_to,
backfilled=True,
)
elif stream_name == "caches":
if self._cache_id_gen:
    self._cache_id_gen.advance(instance_name, token)
@ -85,6 +105,84 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
def _process_event_stream_row(self, token, row):
data = row.data
if row.type == EventsStreamEventRow.TypeId:
self._invalidate_caches_for_event(
token,
data.event_id,
data.room_id,
data.type,
data.state_key,
data.redacts,
data.relates_to,
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
self._curr_state_delta_stream_cache.entity_has_changed(
row.data.room_id, token
)
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
(data.state_key,)
)
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
def _invalidate_caches_for_event(
self,
stream_ordering,
event_id,
room_id,
etype,
state_key,
redacts,
relates_to,
backfilled,
):
self._invalidate_get_event_cache(event_id)
self.get_latest_event_ids_in_room.invalidate((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
if redacts:
self._invalidate_get_event_cache(redacts)
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
self.get_invited_rooms_for_local_user.invalidate((state_key,))
if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,))
self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
self.get_applicable_edit.invalidate((relates_to,))
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
This should only be used to invalidate caches where slaves won't
otherwise know from other replication streams that the cache should
be invalidated.
"""
cache_func = getattr(self, cache_name, None)
if not cache_func:
return
cache_func.invalidate(keys)
await self.db.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
cache_func.__name__,
keys,
)
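The `getattr` lookup above lets callers invalidate a cache by name without holding a reference to it, silently skipping caches that don't exist on this store. A standalone sketch of that dispatch (the cache and store classes here are illustrative stand-ins, not Synapse types):

```python
class FakeCache:
    """Illustrative stand-in for a cached function's cache object."""

    def __init__(self):
        self.invalidated = []

    def invalidate(self, keys):
        self.invalidated.append(keys)


class Store:
    def __init__(self):
        self.get_user_by_id = FakeCache()

    def invalidate_cache_and_stream(self, cache_name, keys):
        # Resolve the cache by attribute name; unknown names are ignored,
        # mirroring the guard in the method above.
        cache_func = getattr(self, cache_name, None)
        if not cache_func:
            return
        cache_func.invalidate(keys)


store = Store()
store.invalidate_cache_and_stream("get_user_by_id", ("@alice:test",))
store.invalidate_cache_and_stream("no_such_cache", ("x",))
assert store.get_user_by_id.invalidated == [("@alice:test",)]
```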
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
    """Invalidates the cache and adds it to the cache stream so slaves
    will know to invalidate their caches.
@ -37,8 +37,10 @@ from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
@ -74,6 +76,31 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
    super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker_app is None:
# We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
extra_tables=[("local_invites", "stream_id")],
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
else:
# Another process is in charge of persisting events and generating
# stream IDs: rely on the replication streams to let us know which
# IDs we can process.
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
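The split above means only the master process hands out stream orderings; workers merely track the tokens they see over replication. A minimal standalone sketch of the forward/backfill generator idea (an illustrative class, not the real `StreamIdGenerator`, which persists state in the database):

```python
import threading

class SimpleStreamIdGenerator:
    """Illustrative sketch of a stream-ordering generator: hands out
    monotonic ids under a lock; step=-1 counts downwards, as the real
    backfill generator does."""

    def __init__(self, current=0, step=1):
        self._lock = threading.Lock()
        self._current = current
        self._step = step

    def get_next(self):
        with self._lock:
            self._current += self._step
            return self._current


gen = SimpleStreamIdGenerator()
assert [gen.get_next() for _ in range(3)] == [1, 2, 3]

backfill = SimpleStreamIdGenerator(step=-1)
assert backfill.get_next() == -1
```

Forward events grow upwards from the current maximum while backfilled events grow downwards, so the two ranges never collide.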
self._get_event_cache = Cache(
    "*getEvent*",
    keylen=3,

@ -85,6 +112,14 @@ class EventsWorkerStore(SQLBaseStore):

self._event_fetch_list = []
self._event_fetch_ongoing = 0
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "events":
self._stream_id_gen.advance(token)
elif stream_name == "backfill":
self._backfill_id_gen.advance(-token)
super().process_replication_rows(stream_name, instance_name, token, rows)
def get_received_ts(self, event_id):
    """Get received_ts (when it was persisted) for the event.
@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore):
    desc="get_invited_users_in_group",
)

def get_rooms_in_group(self, group_id: str, include_private: bool = False):
"""Retrieve the rooms that belong to a given group. Does not return rooms that
lack members.
Args:
group_id: The ID of the group to query for rooms
include_private: Whether to return private rooms in results
Returns:
Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
form of:
{
"room_id": "!a_room_id:example.com", # The ID of the room
"is_public": False # Whether this is a public room or not
}
"""
# TODO: Pagination

def _get_rooms_in_group_txn(txn):
    sql = """
    SELECT room_id, is_public FROM group_rooms
        WHERE group_id = ?
        AND room_id IN (
            SELECT group_rooms.room_id FROM group_rooms
            LEFT JOIN room_stats_current ON
                group_rooms.room_id = room_stats_current.room_id
                AND joined_members > 0
                AND local_users_in_room > 0
            LEFT JOIN rooms ON
                group_rooms.room_id = rooms.room_id
                AND (room_version <> '') = ?
        )
    """

    args = [group_id, False]

    if not include_private:
        sql += " AND is_public = ?"
        args += [True]

    txn.execute(sql, args)

    return [
        {"room_id": room_id, "is_public": is_public} for room_id, is_public in txn
    ]

return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)

def get_rooms_for_summary_by_category(
    self, group_id: str, include_private: bool = False,
):
"""Get the rooms and categories that should be included in a summary request

Args:
group_id: The ID of the group to query the summary for
include_private: Whether to return private rooms in results
Returns:
Deferred[Tuple[List, Dict]]: A tuple containing:
* A list of dictionaries with the keys:
* "room_id": str, the room ID
* "is_public": bool, whether the room is public
* "category_id": str|None, the category ID if set, else None
* "order": int, the sort order of rooms
* A dictionary with the key:
* category_id (str): a dictionary with the keys:
* "is_public": bool, whether the category is public
* "profile": str, the category profile
* "order": int, the sort order of rooms in this category
"""

def _get_rooms_for_summary_txn(txn):
@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore):
SELECT room_id, is_public, category_id, room_order
FROM group_summary_rooms
WHERE group_id = ?
AND room_id IN (
SELECT group_rooms.room_id FROM group_rooms
LEFT JOIN room_stats_current ON
group_rooms.room_id = room_stats_current.room_id
AND joined_members > 0
AND local_users_in_room > 0
LEFT JOIN rooms ON
group_rooms.room_id = rooms.room_id
AND (room_version <> '') = ?
)
"""

if not include_private:
    sql += " AND is_public = ?"
    txn.execute(sql, (group_id, False, True))
else:
    txn.execute(sql, (group_id, False))

rooms = [
    {
@ -17,8 +17,6 @@
import itertools
import logging
import six
from signedjson.key import decode_verify_key_bytes

from synapse.storage._base import SQLBaseStore

@ -28,11 +26,7 @@ from synapse.util.iterutils import batch_iter

logger = logging.getLogger(__name__)
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview
@ -110,7 +110,7 @@ class ProfileStore(ProfileWorkerStore):
return self.db.simple_update(
    table="remote_profile_cache",
    keyvalues={"user_id": user_id},
    updatevalues={
        "displayname": displayname,
        "avatar_url": avatar_url,
        "last_check": self._clock.time_msec(),
@ -16,19 +16,23 @@
import abc
import logging
from typing import Union

from canonicaljson import json

from twisted.internet import defer

from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
from synapse.storage.database import Database
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import ChainedIdGenerator
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -64,6 +68,7 @@ class PushRulesWorkerStore(
    ReceiptsWorkerStore,
    PusherWorkerStore,
    RoomMemberWorkerStore,
    EventsWorkerStore,
    SQLBaseStore,
):
    """This is an abstract base class where subclasses must implement
@ -77,6 +82,15 @@ class PushRulesWorkerStore(
def __init__(self, database: Database, db_conn, hs):
    super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._push_rules_stream_id_gen = ChainedIdGenerator(
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
) # type: Union[ChainedIdGenerator, SlavedIdTracker]
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
)
push_rules_prefill, push_rules_id = self.db.get_cache_dict(
    db_conn,
    "push_rules_stream",
@ -45,7 +45,6 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.metrics import Measure
from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@ -179,7 +178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
    """
    txn.execute(sql, (room_id, Membership.JOIN))
    return [r[0] for r in txn]

@cached(max_entries=100000)
def get_room_summary(self, room_id):
@ -223,7 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (room_id,))

res = {}
for count, membership in txn:
    summary = res.setdefault(membership, MemberSummary([], count))

# we order by membership and then fairly arbitrarily by event_id so
# heroes are consistent
@ -255,11 +254,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
for user_id, membership, event_id in txn:
    summary = res[membership]

    # we will always have a summary for this membership type at this
    # point given the summary currently contains the counts.
    members = summary.members
    members.append((user_id, event_id))

return res
@ -584,13 +583,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
ev_entry = event_map.get(event_id)
if ev_entry:
    if ev_entry.event.membership == Membership.JOIN:
        users_in_room[ev_entry.event.state_key] = ProfileInfo(
            display_name=ev_entry.event.content.get("displayname", None),
            avatar_url=ev_entry.event.content.get("avatar_url", None),
        )
else:
    missing_member_event_ids.append(event_id)
@ -604,9 +599,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if event is not None and event.type == EventTypes.Member:
    if event.membership == Membership.JOIN:
        if event.event_id in member_event_ids:
            users_in_room[event.state_key] = ProfileInfo(
                display_name=event.content.get("displayname", None),
                avatar_url=event.content.get("avatar_url", None),
            )

return users_in_room
@ -37,7 +37,55 @@ SearchEntry = namedtuple(
)

class SearchWorkerStore(SQLBaseStore):
def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table
Args:
txn (cursor):
entries (iterable[SearchEntry]):
entries to be added to the table
"""
if not self.hs.config.enable_search:
return
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search"
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
args = (
(
entry.event_id,
entry.room_id,
entry.key,
entry.value,
entry.stream_ordering,
entry.origin_server_ts,
)
for entry in entries
)
txn.executemany(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
args = (
(entry.event_id, entry.room_id, entry.key, entry.value)
for entry in entries
)
txn.executemany(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"

@ -296,52 +344,6 @@ class SearchBackgroundUpdateStore(SQLBaseStore):

    return num_rows
def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table
Args:
txn (cursor):
entries (iterable[SearchEntry]):
entries to be added to the table
"""
if not self.hs.config.enable_search:
return
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search"
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
args = (
(
entry.event_id,
entry.room_id,
entry.key,
entry.value,
entry.stream_ordering,
entry.origin_server_ts,
)
for entry in entries
)
txn.executemany(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
args = (
(entry.event_id, entry.room_id, entry.key, entry.value)
for entry in entries
)
txn.executemany(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
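`store_search_entries_txn` (moved above into `SearchWorkerStore`) relies on `executemany` accepting any iterable of parameter tuples, so large batches need not be materialised as a list. A standalone sketch of the SQLite branch (the table and rows here are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE event_search (event_id TEXT, room_id TEXT, key TEXT, value TEXT)"
)

entries = [
    ("$ev1", "!room:test", "content.body", "hello"),
    ("$ev2", "!room:test", "content.body", "world"),
]

# executemany accepts a generator of parameter tuples, mirroring the
# `args = (... for entry in entries)` pattern in the store.
conn.executemany(
    "INSERT INTO event_search (event_id, room_id, key, value) VALUES (?,?,?,?)",
    (row for row in entries),
)
assert conn.execute("SELECT COUNT(*) FROM event_search").fetchone()[0] == 2
```

On PostgreSQL the store uses the same shape but routes the value through `to_tsvector` to populate the full-text index column.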
class SearchStore(SearchBackgroundUpdateStore):
    def __init__(self, database: Database, db_conn, hs):
@ -29,7 +29,6 @@ from synapse.storage.database import Database
from synapse.storage.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@ -185,9 +184,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
        (room_id,),
    )
    return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}

return self.db.runInteraction(
    "get_current_state_ids", _get_current_state_ids_txn
@ -16,8 +16,6 @@
import logging
from collections import namedtuple
import six
from canonicaljson import encode_canonical_json

from twisted.internet import defer

@ -27,11 +25,6 @@ from synapse.storage._base import SQLBaseStore, db_to_json

from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview

logger = logging.getLogger(__name__)
@ -50,7 +50,6 @@ from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.types import Collection
from synapse.util.stringutils import exception_to_unicode
logger = logging.getLogger(__name__)
@ -424,20 +423,14 @@ class Database(object):
# This can happen if the database disappears mid
# transaction.
logger.warning(
    "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
)
if i < N:
    i += 1
    try:
        conn.rollback()
    except self.engine.module.Error as e1:
        logger.warning("[TXN EROLL] {%s} %s", name, e1)
    continue
raise
except self.engine.module.DatabaseError as e:

@ -449,9 +442,7 @@ class Database(object):

        conn.rollback()
    except self.engine.module.Error as e1:
        logger.warning(
            "[TXN EROLL] {%s} %s", name, e1,
        )
    continue
raise
@ -166,6 +166,7 @@ class ChainedIdGenerator(object):
def __init__(self, chained_generator, db_conn, table, column):
    self.chained_generator = chained_generator
    self._table = table
    self._lock = threading.Lock()
    self._current_max = _load_current_id(db_conn, table, column)
    self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]

@ -204,6 +205,16 @@ class ChainedIdGenerator(object):

    return self._current_max, self.chained_generator.get_current_token()
def advance(self, token: int):
"""Stub implementation for advancing the token when receiving updates
over replication; raises an exception as this instance should be the
only source of updates.
"""
raise Exception(
    "Attempted to advance token on source for table %r" % (self._table,)
)
class MultiWriterIdGenerator:
    """An ID generator that tracks a stream that can have multiple writers.
@ -15,11 +15,9 @@
# limitations under the License.

import logging
from sys import intern
from typing import Callable, Dict, Optional
import six
from six.moves import intern
import attr
from prometheus_client.core import Gauge
@ -154,9 +152,6 @@ def intern_string(string):
    return None

try:
if six.PY2:
string = string.encode("ascii")
    return intern(string)
except UnicodeEncodeError:
    return string
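With the `six.moves` shim gone, `sys.intern` is used directly. Interning deduplicates equal strings, so frequently repeated keys (state keys, event types) share one object and can be compared cheaply by identity:

```python
from sys import intern

# Build an equal-but-distinct string, then intern both copies: the
# interned results are the same object, not just equal values.
a = intern("".join(["state", "_", "key"]))
b = intern("state_key")
assert a is b
assert a == "state_key"
```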
@ -65,5 +65,5 @@ def _handle_frozendict(obj):
)

# A JSONEncoder which is capable of encoding frozendicts without barfing
frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict)
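The `default=` hook is only consulted for objects the encoder cannot serialise natively. A standalone sketch of the same pattern; `frozendict` is a third-party type, so `types.MappingProxyType` stands in here as an immutable mapping that the stock encoder rejects:

```python
import json
from types import MappingProxyType

def _handle_frozendict(obj):
    # Unwrap immutable mappings into plain dicts so the encoder can
    # serialise them; anything else is a genuine error.
    if isinstance(obj, MappingProxyType):
        return dict(obj)
    raise TypeError(
        "Object of type %s is not JSON serializable" % type(obj).__name__
    )

frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict)

frozen = MappingProxyType({"a": 1})
assert frozendict_json_encoder.encode(frozen) == '{"a": 1}'
```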
@ -19,10 +19,6 @@ import re
import string
from collections import Iterable
import six
from six import PY2, PY3
from six.moves import range
from synapse.api.errors import Codes, SynapseError

_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@ -47,8 +43,6 @@ def random_string_with_symbols(length):
def is_ascii(s):
if PY3:
if isinstance(s, bytes):
    try:
        s.decode("ascii").encode("ascii")
@ -58,68 +52,6 @@ def is_ascii(s):
        return False
    return True
try:
s.encode("ascii")
except UnicodeEncodeError:
return False
except UnicodeDecodeError:
return False
else:
return True
def to_ascii(s):
"""Converts a string to ascii if it is ascii, otherwise leave it alone.
If given None then will return None.
"""
if PY3:
return s
if s is None:
return None
try:
return s.encode("ascii")
except UnicodeEncodeError:
return s
def exception_to_unicode(e):
"""Helper function to extract the text of an exception as a unicode string
Args:
e (Exception): exception to be stringified
Returns:
unicode
"""
# urgh, this is a mess. The basic problem here is that psycopg2 constructs its
# exceptions with PyErr_SetString, with a (possibly non-ascii) argument. str() will
# then produce the raw byte sequence. Under Python 2, this will then cause another
# error if it gets mixed with a `unicode` object, as per
# https://github.com/matrix-org/synapse/issues/4252
# First of all, if we're under python3, everything is fine because it will sort this
# nonsense out for us.
if not PY2:
return str(e)
# otherwise let's have a stab at decoding the exception message. We'll circumvent
# Exception.__str__(), which would explode if someone raised Exception(u'non-ascii')
# and instead look at what is in the args member.
if len(e.args) == 0:
return ""
elif len(e.args) > 1:
return six.text_type(repr(e.args))
msg = e.args[0]
if isinstance(msg, bytes):
return msg.decode("utf-8", errors="replace")
else:
return msg
def assert_valid_client_secret(client_secret):
    """Validate that a given string matches the client_secret regex defined by the spec"""
@ -156,7 +156,7 @@ class PruneEventTestCase(unittest.TestCase):
        "signatures": {},
        "unsigned": {},
    },
    room_version=RoomVersions.V6,
)
@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from unittest import TestCase

from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
@ -207,3 +210,65 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)

return join_event
class EventFromPduTestCase(TestCase):
def test_valid_json(self):
"""Valid JSON should be turned into an event."""
ev = event_from_pdu_json(
{
"type": EventTypes.Message,
"content": {"bool": True, "null": None, "int": 1, "str": "foobar"},
"room_id": "!room:test",
"sender": "@user:test",
"depth": 1,
"prev_events": [],
"auth_events": [],
"origin_server_ts": 1234,
},
RoomVersions.V6,
)
self.assertIsInstance(ev, EventBase)
def test_invalid_numbers(self):
"""Invalid values for an integer should be rejected, all floats should be rejected."""
for value in [
-(2 ** 53),
2 ** 53,
1.0,
float("inf"),
float("-inf"),
float("nan"),
]:
with self.assertRaises(SynapseError):
event_from_pdu_json(
{
"type": EventTypes.Message,
"content": {"foo": value},
"room_id": "!room:test",
"sender": "@user:test",
"depth": 1,
"prev_events": [],
"auth_events": [],
"origin_server_ts": 1234,
},
RoomVersions.V6,
)
def test_invalid_nested(self):
"""List and dictionaries are recursively searched."""
with self.assertRaises(SynapseError):
event_from_pdu_json(
{
"type": EventTypes.Message,
"content": {"foo": [{"bar": 2 ** 56}]},
"room_id": "!room:test",
"sender": "@user:test",
"depth": 1,
"prev_events": [],
"auth_events": [],
"origin_server_ts": 1234,
},
RoomVersions.V6,
)
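These tests exercise a recursive number check: integers must fit the canonical-JSON range and floats are rejected outright, with lists and dicts searched depth-first. A sketch of that rule (the bounds follow canonical JSON; this helper is illustrative, not Synapse's actual validator inside `event_from_pdu_json`):

```python
def validate_canonical_numbers(value):
    """Recursively reject floats and integers outside the canonical-JSON
    range [-(2**53) + 1, 2**53 - 1], searching lists and dicts."""
    if isinstance(value, bool):
        return  # bools are ints in Python, but always in range
    if isinstance(value, float):
        raise ValueError("floats are not allowed: %r" % (value,))
    if isinstance(value, int):
        if not -(2 ** 53) + 1 <= value <= 2 ** 53 - 1:
            raise ValueError("integer out of range: %r" % (value,))
    elif isinstance(value, list):
        for item in value:
            validate_canonical_numbers(item)
    elif isinstance(value, dict):
        for item in value.values():
            validate_canonical_numbers(item)


# Matches the test cases above: valid content passes, bad numbers raise.
validate_canonical_numbers({"bool": True, "null": None, "int": 1, "str": "foobar"})
for bad in [-(2 ** 53), 2 ** 53, 1.0, float("nan")]:
    try:
        validate_canonical_numbers({"foo": bad})
    except ValueError:
        continue
    raise AssertionError("should have rejected %r" % (bad,))
```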
@ -292,11 +292,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
@defer.inlineCallbacks
def test_redirect_request(self):
    """The redirect request has the right arguments & generates a valid session cookie."""
    req = Mock(spec=["addCookie"])
    url = yield defer.ensureDeferred(
        self.handler.handle_redirect_request(req, b"http://client/redirect")
    )
    url = urlparse(url)
    auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@ -382,7 +381,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
nonce = "nonce" nonce = "nonce"
client_redirect_url = "http://client/redirect" client_redirect_url = "http://client/redirect"
session = self.handler._generate_oidc_session_token( session = self.handler._generate_oidc_session_token(
state=state, nonce=nonce, client_redirect_url=client_redirect_url, state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
) )
request.getCookie.return_value = session request.getCookie.return_value = session
@ -472,7 +474,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Mismatching session # Mismatching session
session = self.handler._generate_oidc_session_token( session = self.handler._generate_oidc_session_token(
state="state", nonce="nonce", client_redirect_url="http://client/redirect", state="state",
nonce="nonce",
client_redirect_url="http://client/redirect",
ui_auth_session_id=None,
) )
request.args = {} request.args = {}
request.args[b"state"] = [b"mismatching state"] request.args[b"state"] = [b"mismatching state"]
@@ -17,11 +17,12 @@ from canonicaljson import encode_canonical_json
 from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
-from synapse.events.snapshot import EventContext
 from synapse.handlers.room import RoomEventSource
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.storage.roommember import RoomsForUser
+from tests.server import FakeTransport
+
 from ._base import BaseSlavedStoreTestCase

 USER_ID = "@feeling:test"
@@ -240,6 +241,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         # limit the replication rate
         repl_transport = self._server_transport
+        assert isinstance(repl_transport, FakeTransport)
         repl_transport.autoflush = False

         # build the join and message events and persist them in the same batch.
@@ -322,7 +324,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             type="m.room.message",
             key=None,
             internal={},
-            state=None,
             depth=None,
             prev_events=[],
             auth_events=[],
@@ -362,13 +363,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         event = make_event_from_dict(event_dict, internal_metadata_dict=internal)
         self.event_id += 1

-        if state is not None:
-            state_ids = {key: e.event_id for key, e in state.items()}
-            context = EventContext.with_state(
-                state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids
-            )
-        else:
-            state_handler = self.hs.get_state_handler()
-            context = self.get_success(state_handler.compute_event_context(event))
+        state_handler = self.hs.get_state_handler()
+        context = self.get_success(state_handler.compute_event_context(event))
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.replication.tcp.streams._base import (
+    _STREAM_UPDATE_TARGET_ROW_COUNT,
+    AccountDataStream,
+)
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class AccountDataStreamTestCase(BaseStreamTestCase):
+    def test_update_function_room_account_data_limit(self):
+        """Test replication with many room account data updates
+        """
+        store = self.hs.get_datastore()
+
+        # generate lots of account data updates
+        updates = []
+        for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+            update = "m.test_type.%i" % (i,)
+            self.get_success(
+                store.add_account_data_to_room("test_user", "test_room", update, {})
+            )
+            updates.append(update)
+
+        # also one global update
+        self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
+
+        # tell the notifier to catch up to avoid duplicate rows.
+        # workaround for https://github.com/matrix-org/synapse/issues/7360
+        # FIXME remove this when the above is fixed
+        self.replicate()
+
+        # check we're testing what we think we are: no rows should yet have been
+        # received
+        self.assertEqual([], self.test_handler.received_rdata_rows)
+
+        # now reconnect to pull the updates
+        self.reconnect()
+        self.replicate()
+
+        # we should have received all the expected rows in the right order
+        received_rows = self.test_handler.received_rdata_rows
+        for t in updates:
+            (stream_name, token, row) = received_rows.pop(0)
+            self.assertEqual(stream_name, AccountDataStream.NAME)
+            self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+            self.assertEqual(row.data_type, t)
+            self.assertEqual(row.room_id, "test_room")
+
+        (stream_name, token, row) = received_rows.pop(0)
+        self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+        self.assertEqual(row.data_type, "m.global")
+        self.assertIsNone(row.room_id)
+
+        self.assertEqual([], received_rows)
+
+    def test_update_function_global_account_data_limit(self):
+        """Test replication with many global account data updates
+        """
+        store = self.hs.get_datastore()
+
+        # generate lots of account data updates
+        updates = []
+        for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+            update = "m.test_type.%i" % (i,)
+            self.get_success(store.add_account_data_for_user("test_user", update, {}))
+            updates.append(update)
+
+        # also one per-room update
+        self.get_success(
+            store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
+        )
+
+        # tell the notifier to catch up to avoid duplicate rows.
+        # workaround for https://github.com/matrix-org/synapse/issues/7360
+        # FIXME remove this when the above is fixed
+        self.replicate()
+
+        # check we're testing what we think we are: no rows should yet have been
+        # received
+        self.assertEqual([], self.test_handler.received_rdata_rows)
+
+        # now reconnect to pull the updates
+        self.reconnect()
+        self.replicate()
+
+        # we should have received all the expected rows in the right order
+        received_rows = self.test_handler.received_rdata_rows
+        for t in updates:
+            (stream_name, token, row) = received_rows.pop(0)
+            self.assertEqual(stream_name, AccountDataStream.NAME)
+            self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+            self.assertEqual(row.data_type, t)
+            self.assertIsNone(row.room_id)
+
+        (stream_name, token, row) = received_rows.pop(0)
+        self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+        self.assertEqual(row.data_type, "m.per_room")
+        self.assertEqual(row.room_id, "test_room")
+
+        self.assertEqual([], received_rows)
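These tests generate `_STREAM_UPDATE_TARGET_ROW_COUNT + 5` updates precisely to force the stream's update function to split its output across more than one batch. As a rough sketch of that batching contract (names like `TARGET_ROW_COUNT` and `get_updates` are stand-ins for illustration, not Synapse's actual API):

```python
from typing import List, Tuple

TARGET_ROW_COUNT = 100  # stand-in for _STREAM_UPDATE_TARGET_ROW_COUNT


def get_updates(
    rows: List[Tuple[int, str]], from_token: int, upto_token: int
) -> Tuple[List[Tuple[int, str]], int, bool]:
    """Return up to TARGET_ROW_COUNT rows with from_token < token <= upto_token,
    the token actually reached, and whether the batch was truncated."""
    selected = [(tok, data) for tok, data in rows if from_token < tok <= upto_token]
    limited = len(selected) > TARGET_ROW_COUNT
    selected = selected[:TARGET_ROW_COUNT]
    # if we were limited, report the last token we returned so the caller
    # can ask again from there; otherwise we have reached upto_token
    new_token = selected[-1][0] if limited else upto_token
    return selected, new_token, limited
```

A caller then loops, feeding `new_token` back in as `from_token`, until `new_token == upto_token`, which is why the test can still expect to receive every row in order despite the per-batch cap.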
@@ -30,7 +30,7 @@ class ParseCommandTestCase(TestCase):
     def test_parse_rdata(self):
         line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
         cmd = parse_command_from_line(line)
-        self.assertIsInstance(cmd, RdataCommand)
+        assert isinstance(cmd, RdataCommand)
         self.assertEqual(cmd.stream_name, "events")
         self.assertEqual(cmd.instance_name, "master")
         self.assertEqual(cmd.token, 6287863)
@@ -38,7 +38,7 @@ class ParseCommandTestCase(TestCase):
     def test_parse_rdata_batch(self):
         line = 'RDATA presence master batch ["@foo:example.com", "online"]'
         cmd = parse_command_from_line(line)
-        self.assertIsInstance(cmd, RdataCommand)
+        assert isinstance(cmd, RdataCommand)
         self.assertEqual(cmd.stream_name, "presence")
         self.assertEqual(cmd.instance_name, "master")
         self.assertIsNone(cmd.token)
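Swapping `self.assertIsInstance` for a bare `assert isinstance` looks cosmetic but matters for type checking: mypy narrows a variable's type after a bare `isinstance` assert, while a unittest helper call tells it nothing. (This fits the tox.ini change elsewhere in this merge that puts `tests/replication` under mypy.) A small illustration, with a stub parser standing in for the real `parse_command_from_line`:

```python
from typing import Union


class RdataCommand:
    def __init__(self, stream_name: str) -> None:
        self.stream_name = stream_name


class PingCommand:
    pass


def parse(line: str) -> Union[RdataCommand, PingCommand]:
    # stub for illustration only; always returns an RdataCommand
    return RdataCommand("events")


cmd = parse('RDATA events master 6287863 [...]')

# With only assertIsInstance, mypy still sees `cmd` as the Union type and
# flags `cmd.stream_name` as an error.  A bare assert narrows `cmd` to
# RdataCommand for the rest of the scope:
assert isinstance(cmd, RdataCommand)
print(cmd.stream_name)
```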
@@ -136,21 +136,18 @@ class EventAuthTestCase(unittest.TestCase):
         # creator should be able to send aliases
         event_auth.check(
-            RoomVersions.MSC2432_DEV,
-            _alias_event(creator),
-            auth_events,
-            do_sig_check=False,
+            RoomVersions.V6, _alias_event(creator), auth_events, do_sig_check=False,
         )

         # No particular checks are done on the state key.
         event_auth.check(
-            RoomVersions.MSC2432_DEV,
+            RoomVersions.V6,
             _alias_event(creator, state_key=""),
             auth_events,
             do_sig_check=False,
         )
         event_auth.check(
-            RoomVersions.MSC2432_DEV,
+            RoomVersions.V6,
             _alias_event(creator, state_key="test.com"),
             auth_events,
             do_sig_check=False,
@@ -159,10 +156,7 @@ class EventAuthTestCase(unittest.TestCase):
         # Per standard auth rules, the member must be in the room.
         with self.assertRaises(AuthError):
             event_auth.check(
-                RoomVersions.MSC2432_DEV,
-                _alias_event(other),
-                auth_events,
-                do_sig_check=False,
+                RoomVersions.V6, _alias_event(other), auth_events, do_sig_check=False,
             )

     def test_msc2209(self):
@@ -192,7 +186,7 @@ class EventAuthTestCase(unittest.TestCase):
         # But an MSC2209 room rejects this change.
         with self.assertRaises(AuthError):
             event_auth.check(
-                RoomVersions.MSC2209_DEV,
+                RoomVersions.V6,
                 _power_levels_event(pleb, {"notifications": {"room": 100}}),
                 auth_events,
                 do_sig_check=False,
@@ -180,6 +180,7 @@ commands = mypy \
             synapse/api \
             synapse/appservice \
             synapse/config \
+            synapse/event_auth.py \
             synapse/events/spamcheck.py \
             synapse/federation \
             synapse/handlers/auth.py \
@@ -187,6 +188,8 @@ commands = mypy \
             synapse/handlers/directory.py \
             synapse/handlers/oidc_handler.py \
             synapse/handlers/presence.py \
+            synapse/handlers/room_member.py \
+            synapse/handlers/room_member_worker.py \
             synapse/handlers/saml_handler.py \
             synapse/handlers/sync.py \
             synapse/handlers/ui_auth \
@@ -204,7 +207,7 @@ commands = mypy \
             synapse/storage/util \
             synapse/streams \
             synapse/util/caches/stream_change_cache.py \
-            tests/replication/tcp/streams \
+            tests/replication \
             tests/test_utils \
             tests/rest/client/v2_alpha/test_auth.py \
             tests/util/test_stream_change_cache.py