Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
commit
a2b6ee7b00
|
@ -0,0 +1 @@
|
||||||
|
Allow server admins to require a manual approval process before new accounts can be used (using [MSC3866](https://github.com/matrix-org/matrix-spec-proposals/pull/3866)).
|
|
@ -0,0 +1 @@
|
||||||
|
Send invite push notifications for invites over federation.
|
|
@ -0,0 +1 @@
|
||||||
|
Optimise get rooms for user calls. Contributed by Nick @ Beeper (@fizzadar).
|
|
@ -0,0 +1 @@
|
||||||
|
Port push rules to Rust.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix unstable MSC3882 endpoint being incorrectly available on stable API versions.
|
|
@ -0,0 +1 @@
|
||||||
|
Only pull relevant backfill points from the database based on the current depth and limit (instead of all) every time we want to `/backfill`.
|
|
@ -0,0 +1 @@
|
||||||
|
Improve backfill robustness by trying more servers when we get a `4xx` error back.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug introduced in 1.66 where some required fields in the push rules sent to clients were no longer present. Contributed by Nico.
|
|
@ -0,0 +1 @@
|
||||||
|
Faster remote room joins: correctly handle remote device list updates during a partial join.
|
|
@ -0,0 +1 @@
|
||||||
|
Update an inaccurate comment in Synapse's upsert database helper.
|
|
@ -0,0 +1 @@
|
||||||
|
Add instructions to the contributing guide for running unit tests in parallel. Contributed by @ashfame.
|
|
@ -0,0 +1 @@
|
||||||
|
Update the man page for the `hash_password` script to correct the default number of bcrypt rounds performed.
|
|
@ -0,0 +1 @@
|
||||||
|
Clarify that the `auto_join_rooms` config option can also be used with Space aliases.
|
|
@ -0,0 +1 @@
|
||||||
|
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
|
|
@ -0,0 +1 @@
|
||||||
|
Correctly handle sending local device list updates to remote servers during a partial join.
|
|
@ -0,0 +1 @@
|
||||||
|
Exponentially back off from backfilling the same event over and over.
|
|
@ -0,0 +1 @@
|
||||||
|
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
|
|
@ -0,0 +1 @@
|
||||||
|
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
|
|
@ -0,0 +1 @@
|
||||||
|
Add cache invalidation across workers to the module API.
|
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug introduced in v1.68.0 where Synapse would require `setuptools_rust` at runtime, even though the package is only required at build time.
|
|
@ -0,0 +1 @@
|
||||||
|
Ask mail servers receiving emails from Synapse not to send automatic replies (e.g. out-of-office responses).
|
|
@ -0,0 +1 @@
|
||||||
|
Fix a performance regression in the `get_users_in_room` database query. Introduced in v1.67.0.
|
|
@ -0,0 +1 @@
|
||||||
|
Speed up calculating push actions in large rooms.
|
|
@ -0,0 +1 @@
|
||||||
|
Add some cross-references to the worker documentation.
|
|
@ -10,7 +10,7 @@
|
||||||
.P
|
.P
|
||||||
\fBhash_password\fR takes a password as a parameter, either on the command line or via \fBSTDIN\fR if not supplied\.
|
\fBhash_password\fR takes a password as a parameter, either on the command line or via \fBSTDIN\fR if not supplied\.
|
||||||
.P
|
.P
|
||||||
It accepts a YAML file which can be used to specify parameters such as the number of rounds for bcrypt and a password_config section containing the pepper value used for the hashing\. By default \fBbcrypt_rounds\fR is set to \fB10\fR\.
|
It accepts a YAML file which can be used to specify parameters such as the number of rounds for bcrypt and a password_config section containing the pepper value used for the hashing\. By default \fBbcrypt_rounds\fR is set to \fB12\fR\.
|
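For illustration, a minimal sketch of the YAML file the paragraph above describes might look as follows; the values are placeholders rather than script defaults:

```yaml
# Hypothetical config passed to hash_password; keys follow the man page above.
bcrypt_rounds: 12            # number of bcrypt rounds to perform
password_config:
  pepper: "EXAMPLE_PEPPER"   # pepper mixed into the hash, as described above
```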
||||||
.P
|
.P
|
||||||
The hashed password is written to \fBSTDOUT\fR\.
|
The hashed password is written to \fBSTDOUT\fR\.
|
||||||
.SH "FILES"
|
.SH "FILES"
|
||||||
|
|
|
@ -167,6 +167,12 @@ was broken. They are slower than the linters but will typically catch more error
|
||||||
poetry run trial tests
|
poetry run trial tests
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can run unit tests in parallel by specifying the `-jX` argument to `trial`, where `X` is the number of parallel runners you want. To use 4 CPU cores, you would run them like:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
poetry run trial -j4 tests
|
||||||
|
```
|
||||||
|
|
||||||
If you wish to only run *some* unit tests, you may specify
|
If you wish to only run *some* unit tests, you may specify
|
||||||
another module instead of `tests` - or a test class or a method:
|
another module instead of `tests` - or a test class or a method:
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
worker_app: synapse.app.media_repository
|
||||||
|
worker_name: media_worker
|
||||||
|
|
||||||
|
# The replication listener on the main synapse process.
|
||||||
|
worker_replication_host: 127.0.0.1
|
||||||
|
worker_replication_http_port: 9093
|
||||||
|
|
||||||
|
worker_listeners:
|
||||||
|
- type: http
|
||||||
|
port: 8085
|
||||||
|
resources:
|
||||||
|
- names: [media]
|
||||||
|
|
||||||
|
worker_log_config: /etc/matrix-synapse/media-worker-log.yaml
|
|
@ -88,6 +88,18 @@ process, for example:
|
||||||
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# Upgrading to v1.69.0
|
||||||
|
|
||||||
|
## Changes to the receipts replication streams
|
||||||
|
|
||||||
|
Synapse now includes information indicating if a receipt applies to a thread when
|
||||||
|
replicating it to other workers. This is a forwards- and backwards-incompatible
|
||||||
|
change: v1.68 and older workers cannot process receipts replicated by v1.69 workers, and
|
||||||
|
vice versa.
|
||||||
|
|
||||||
|
Once all workers are upgraded to v1.69 (or downgraded to v1.68), receipts
|
||||||
|
replication will resume as normal.
|
||||||
|
|
||||||
# Upgrading to v1.68.0
|
# Upgrading to v1.68.0
|
||||||
|
|
||||||
Two changes announced in the upgrade notes for v1.67.0 have now landed in v1.68.0.
|
Two changes announced in the upgrade notes for v1.67.0 have now landed in v1.68.0.
|
||||||
|
|
|
@ -2229,6 +2229,9 @@ homeserver. If the room already exists, make certain it is a publicly joinable
|
||||||
room, i.e. the join rule of the room must be set to 'public'. You can find more options
|
room, i.e. the join rule of the room must be set to 'public'. You can find more options
|
||||||
relating to auto-joining rooms below.
|
relating to auto-joining rooms below.
|
||||||
|
|
||||||
|
As Spaces are just rooms under the hood, Space aliases may also be
|
||||||
|
used.
|
||||||
|
|
||||||
Example configuration:
|
Example configuration:
|
||||||
```yaml
|
```yaml
|
||||||
auto_join_rooms:
|
auto_join_rooms:
|
||||||
|
@ -2240,7 +2243,7 @@ auto_join_rooms:
|
||||||
|
|
||||||
Where `auto_join_rooms` are specified, setting this flag ensures that
|
Where `auto_join_rooms` are specified, setting this flag ensures that
|
||||||
the rooms exist by creating them when the first user on the
|
the rooms exist by creating them when the first user on the
|
||||||
homeserver registers.
|
homeserver registers. This option will not create Spaces.
|
||||||
|
|
||||||
By default the auto-created rooms are publicly joinable from any federated
|
By default the auto-created rooms are publicly joinable from any federated
|
||||||
server. Use the `autocreate_auto_join_rooms_federated` and
|
server. Use the `autocreate_auto_join_rooms_federated` and
|
||||||
|
@ -2258,7 +2261,7 @@ autocreate_auto_join_rooms: false
|
||||||
---
|
---
|
||||||
### `autocreate_auto_join_rooms_federated`
|
### `autocreate_auto_join_rooms_federated`
|
||||||
|
|
||||||
Whether the rooms listen in `auto_join_rooms` that are auto-created are available
|
Whether the rooms listed in `auto_join_rooms` that are auto-created are available
|
||||||
via federation. Only has an effect if `autocreate_auto_join_rooms` is true.
|
via federation. Only has an effect if `autocreate_auto_join_rooms` is true.
|
||||||
|
|
||||||
Note that whether a room is federated cannot be modified after
|
Note that whether a room is federated cannot be modified after
|
||||||
|
|
|
@ -93,7 +93,6 @@ listener" for the main process; and secondly, you need to enable redis-based
|
||||||
replication. Optionally, a shared secret can be used to authenticate HTTP
|
replication. Optionally, a shared secret can be used to authenticate HTTP
|
||||||
traffic between workers. For example:
|
traffic between workers. For example:
|
||||||
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# extend the existing `listeners` section. This defines the ports that the
|
# extend the existing `listeners` section. This defines the ports that the
|
||||||
# main process will listen on.
|
# main process will listen on.
|
||||||
|
@ -129,7 +128,8 @@ In the config file for each worker, you must specify:
|
||||||
* The HTTP replication endpoint that it should talk to on the main synapse process
|
* The HTTP replication endpoint that it should talk to on the main synapse process
|
||||||
(`worker_replication_host` and `worker_replication_http_port`)
|
(`worker_replication_host` and `worker_replication_http_port`)
|
||||||
* If handling HTTP requests, a `worker_listeners` option with an `http`
|
* If handling HTTP requests, a `worker_listeners` option with an `http`
|
||||||
listener, in the same way as the `listeners` option in the shared config.
|
listener, in the same way as the [`listeners`](usage/configuration/config_documentation.md#listeners)
|
||||||
|
option in the shared config.
|
||||||
* If handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for
|
* If handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for
|
||||||
the main process (`worker_main_http_uri`).
|
the main process (`worker_main_http_uri`).
|
||||||
|
|
||||||
|
@ -285,8 +285,9 @@ For multiple workers not handling the SSO endpoints properly, see
|
||||||
[#7530](https://github.com/matrix-org/synapse/issues/7530) and
|
[#7530](https://github.com/matrix-org/synapse/issues/7530) and
|
||||||
[#9427](https://github.com/matrix-org/synapse/issues/9427).
|
[#9427](https://github.com/matrix-org/synapse/issues/9427).
|
||||||
|
|
||||||
Note that a HTTP listener with `client` and `federation` resources must be
|
Note that a [HTTP listener](usage/configuration/config_documentation.md#listeners)
|
||||||
configured in the `worker_listeners` option in the worker config.
|
with `client` and `federation` `resources` must be configured in the `worker_listeners`
|
||||||
|
option in the worker config.
|
||||||
|
|
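For reference, a sketch of such a worker listener, mirroring the structure of the `media_worker.yaml` example included in this change (the port number is illustrative):

```yaml
worker_listeners:
  - type: http
    port: 8083
    resources:
      - names: [client, federation]
```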
||||||
#### Load balancing
|
#### Load balancing
|
||||||
|
|
||||||
|
@ -326,7 +327,8 @@ effects of bursts of events from that bridge on events sent by normal users.
|
||||||
Additionally, the writing of specific streams (such as events) can be moved off
|
Additionally, the writing of specific streams (such as events) can be moved off
|
||||||
of the main process to a particular worker.
|
of the main process to a particular worker.
|
||||||
|
|
||||||
To enable this, the worker must have a HTTP replication listener configured,
|
To enable this, the worker must have a
|
||||||
|
[HTTP `replication` listener](usage/configuration/config_documentation.md#listeners) configured,
|
||||||
have a `worker_name` and be listed in the `instance_map` config. The same worker
|
have a `worker_name` and be listed in the `instance_map` config. The same worker
|
||||||
can handle multiple streams, but unless otherwise documented, each stream can only
|
can handle multiple streams, but unless otherwise documented, each stream can only
|
||||||
have a single writer.
|
have a single writer.
|
||||||
|
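As a rough sketch of the shared-configuration entries this paragraph refers to (the worker name, host and port are illustrative assumptions):

```yaml
# In the shared configuration: map the worker's name to its replication
# listener, then assign it as the writer for a stream.
instance_map:
  event_persister1:
    host: localhost
    port: 8034

stream_writers:
  events: event_persister1
```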
@ -410,7 +412,7 @@ the stream writer for the `presence` stream:
|
||||||
There is also support for moving background tasks to a separate
|
There is also support for moving background tasks to a separate
|
||||||
worker. Background tasks are run periodically or started via replication. Exactly
|
worker. Background tasks are run periodically or started via replication. Exactly
|
||||||
which tasks are configured to run depends on your Synapse configuration (e.g. if
|
which tasks are configured to run depends on your Synapse configuration (e.g. if
|
||||||
stats is enabled).
|
stats is enabled). This worker doesn't handle any REST endpoints itself.
|
||||||
|
|
||||||
To enable this, the worker must have a `worker_name` and can be configured to run
|
To enable this, the worker must have a `worker_name` and can be configured to run
|
||||||
background tasks. For example, to move background tasks to a dedicated worker,
|
background tasks. For example, to move background tasks to a dedicated worker,
|
||||||
|
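The documentation's own example is truncated by this hunk; assuming the shared-config option is `run_background_tasks_on` (not shown in this diff), a minimal sketch would be:

```yaml
run_background_tasks_on: background_worker
```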
@ -457,8 +459,8 @@ worker application type.
|
||||||
#### Notifying Application Services
|
#### Notifying Application Services
|
||||||
|
|
||||||
You can designate one generic worker to send output traffic to Application Services.
|
You can designate one generic worker to send output traffic to Application Services.
|
||||||
|
This worker doesn't handle any REST endpoints itself, but you should specify its name in the
|
||||||
Specify its name in the shared configuration as follows:
|
shared configuration as follows:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
notify_appservices_from_worker: worker_name
|
notify_appservices_from_worker: worker_name
|
||||||
|
@ -536,16 +538,12 @@ file to stop the main synapse running background jobs related to managing the
|
||||||
media repository. Note that doing so will prevent the main process from being
|
media repository. Note that doing so will prevent the main process from being
|
||||||
able to handle the above endpoints.
|
able to handle the above endpoints.
|
||||||
|
|
||||||
In the `media_repository` worker configuration file, configure the http listener to
|
In the `media_repository` worker configuration file, configure the
|
||||||
|
[HTTP listener](usage/configuration/config_documentation.md#listeners) to
|
||||||
expose the `media` resource. For example:
|
expose the `media` resource. For example:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
worker_listeners:
|
{{#include systemd-with-workers/workers/media_worker.yaml}}
|
||||||
- type: http
|
|
||||||
port: 8085
|
|
||||||
resources:
|
|
||||||
- names:
|
|
||||||
- media
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Note that if running multiple media repositories they must be on the same server
|
Note that if running multiple media repositories they must be on the same server
|
||||||
|
|
|
@ -11,7 +11,9 @@ rust-version = "1.58.1"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
name = "synapse"
|
name = "synapse"
|
||||||
crate-type = ["cdylib"]
|
# We generate a `cdylib` for Python and a standard `lib` for running
|
||||||
|
# tests/benchmarks.
|
||||||
|
crate-type = ["lib", "cdylib"]
|
||||||
|
|
||||||
[package.metadata.maturin]
|
[package.metadata.maturin]
|
||||||
# This is where we tell maturin where to place the built library.
|
# This is where we tell maturin where to place the built library.
|
||||||
|
|
|
@ -0,0 +1,149 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#![feature(test)]
|
||||||
|
use synapse::push::{
|
||||||
|
evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules,
|
||||||
|
};
|
||||||
|
use test::Bencher;
|
||||||
|
|
||||||
|
extern crate test;
|
||||||
|
|
||||||
|
#[bench]
|
||||||
|
fn bench_match_exact(b: &mut Bencher) {
|
||||||
|
let flattened_keys = [
|
||||||
|
("type".to_string(), "m.text".to_string()),
|
||||||
|
("room_id".to_string(), "!room:server".to_string()),
|
||||||
|
("content.body".to_string(), "test message".to_string()),
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let eval = PushRuleEvaluator::py_new(
|
||||||
|
flattened_keys,
|
||||||
|
10,
|
||||||
|
0,
|
||||||
|
Default::default(),
|
||||||
|
Default::default(),
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let condition = Condition::Known(synapse::push::KnownCondition::EventMatch(
|
||||||
|
EventMatchCondition {
|
||||||
|
key: "room_id".into(),
|
||||||
|
pattern: Some("!room:server".into()),
|
||||||
|
pattern_type: None,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
|
||||||
|
let matched = eval.match_condition(&condition, None, None).unwrap();
|
||||||
|
assert!(matched, "Didn't match");
|
||||||
|
|
||||||
|
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[bench]
|
||||||
|
fn bench_match_word(b: &mut Bencher) {
|
||||||
|
let flattened_keys = [
|
||||||
|
("type".to_string(), "m.text".to_string()),
|
||||||
|
("room_id".to_string(), "!room:server".to_string()),
|
||||||
|
("content.body".to_string(), "test message".to_string()),
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let eval = PushRuleEvaluator::py_new(
|
||||||
|
flattened_keys,
|
||||||
|
10,
|
||||||
|
0,
|
||||||
|
Default::default(),
|
||||||
|
Default::default(),
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let condition = Condition::Known(synapse::push::KnownCondition::EventMatch(
|
||||||
|
EventMatchCondition {
|
||||||
|
key: "content.body".into(),
|
||||||
|
pattern: Some("test".into()),
|
||||||
|
pattern_type: None,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
|
||||||
|
let matched = eval.match_condition(&condition, None, None).unwrap();
|
||||||
|
assert!(matched, "Didn't match");
|
||||||
|
|
||||||
|
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[bench]
|
||||||
|
fn bench_match_word_miss(b: &mut Bencher) {
|
||||||
|
let flattened_keys = [
|
||||||
|
("type".to_string(), "m.text".to_string()),
|
||||||
|
("room_id".to_string(), "!room:server".to_string()),
|
||||||
|
("content.body".to_string(), "test message".to_string()),
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let eval = PushRuleEvaluator::py_new(
|
||||||
|
flattened_keys,
|
||||||
|
10,
|
||||||
|
0,
|
||||||
|
Default::default(),
|
||||||
|
Default::default(),
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let condition = Condition::Known(synapse::push::KnownCondition::EventMatch(
|
||||||
|
EventMatchCondition {
|
||||||
|
key: "content.body".into(),
|
||||||
|
pattern: Some("foobar".into()),
|
||||||
|
pattern_type: None,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
|
||||||
|
let matched = eval.match_condition(&condition, None, None).unwrap();
|
||||||
|
assert!(!matched, "Matched unexpectedly");
|
||||||
|
|
||||||
|
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[bench]
|
||||||
|
fn bench_eval_message(b: &mut Bencher) {
|
||||||
|
let flattened_keys = [
|
||||||
|
("type".to_string(), "m.text".to_string()),
|
||||||
|
("room_id".to_string(), "!room:server".to_string()),
|
||||||
|
("content.body".to_string(), "test message".to_string()),
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let eval = PushRuleEvaluator::py_new(
|
||||||
|
flattened_keys,
|
||||||
|
10,
|
||||||
|
0,
|
||||||
|
Default::default(),
|
||||||
|
Default::default(),
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let rules =
|
||||||
|
FilteredPushRules::py_new(PushRules::new(Vec::new()), Default::default(), false, false);
|
||||||
|
|
||||||
|
b.iter(|| eval.run(&rules, Some("bob"), Some("person")));
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#![feature(test)]
|
||||||
|
|
||||||
|
use synapse::push::utils::{glob_to_regex, GlobMatchType};
|
||||||
|
use test::Bencher;
|
||||||
|
|
||||||
|
extern crate test;
|
||||||
|
|
||||||
|
#[bench]
|
||||||
|
fn bench_whole(b: &mut Bencher) {
|
||||||
|
b.iter(|| glob_to_regex("test", GlobMatchType::Whole));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[bench]
|
||||||
|
fn bench_word(b: &mut Bencher) {
|
||||||
|
b.iter(|| glob_to_regex("test", GlobMatchType::Word));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[bench]
|
||||||
|
fn bench_whole_wildcard_run(b: &mut Bencher) {
|
||||||
|
b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Whole));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[bench]
|
||||||
|
fn bench_word_wildcard_run(b: &mut Bencher) {
|
||||||
|
b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Word));
|
||||||
|
}
|
|
@ -22,7 +22,7 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
|
|
||||||
for entry in entries {
|
for entry in entries {
|
||||||
if entry.is_dir() {
|
if entry.is_dir() {
|
||||||
dirs.push(entry)
|
dirs.push(entry);
|
||||||
} else {
|
} else {
|
||||||
paths.push(entry.to_str().expect("valid rust paths").to_string());
|
paths.push(entry.to_str().expect("valid rust paths").to_string());
|
||||||
}
|
}
|
||||||
|
|
|
@ -262,6 +262,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
|
||||||
priority_class: 1,
|
priority_class: 1,
|
||||||
conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch {
|
conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch {
|
||||||
rel_type: Cow::Borrowed("m.thread"),
|
rel_type: Cow::Borrowed("m.thread"),
|
||||||
|
event_type_pattern: None,
|
||||||
sender: None,
|
sender: None,
|
||||||
sender_type: Some(Cow::Borrowed("user_id")),
|
sender_type: Some(Cow::Borrowed("user_id")),
|
||||||
})]),
|
})]),
|
||||||
|
|
|
@ -0,0 +1,374 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
borrow::Cow,
|
||||||
|
collections::{BTreeMap, BTreeSet},
|
||||||
|
};
|
||||||
|
|
||||||
|
use anyhow::{Context, Error};
|
||||||
|
use lazy_static::lazy_static;
|
||||||
|
use log::warn;
|
||||||
|
use pyo3::prelude::*;
|
||||||
|
use regex::Regex;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType},
|
||||||
|
Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition,
|
||||||
|
};
|
||||||
|
|
||||||
|
lazy_static! {
|
||||||
|
/// Used to parse the `is` clause in the room member count condition.
|
||||||
|
static ref INEQUALITY_EXPR: Regex = Regex::new(r"^([=<>]*)([0-9]+)$").expect("valid regex");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allows running a set of push rules against a particular event.
|
||||||
|
#[pyclass]
|
||||||
|
pub struct PushRuleEvaluator {
|
||||||
|
/// A mapping of "flattened" keys to string values in the event, e.g.
|
||||||
|
/// includes things like "type" and "content.msgtype".
|
||||||
|
flattened_keys: BTreeMap<String, String>,
|
||||||
|
|
||||||
|
/// The "content.body", if any.
|
||||||
|
body: String,
|
||||||
|
|
||||||
|
/// The number of users in the room.
|
||||||
|
room_member_count: u64,
|
||||||
|
|
||||||
|
/// The `notifications` section of the current power levels in the room.
|
||||||
|
notification_power_levels: BTreeMap<String, i64>,
|
||||||
|
|
||||||
|
/// The relations related to the event as a mapping from relation type to
|
||||||
|
/// set of sender/event type 2-tuples.
|
||||||
|
relations: BTreeMap<String, BTreeSet<(String, String)>>,
|
||||||
|
|
||||||
|
/// Is running "relation" conditions enabled?
|
||||||
|
relation_match_enabled: bool,
|
||||||
|
|
||||||
|
/// The power level of the sender of the event, or None if event is an
|
||||||
|
/// outlier.
|
||||||
|
sender_power_level: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymethods]
|
||||||
|
impl PushRuleEvaluator {
|
||||||
|
/// Create a new `PushRuleEvaluator`. See struct docstring for details.
|
||||||
|
#[new]
|
||||||
|
pub fn py_new(
|
||||||
|
flattened_keys: BTreeMap<String, String>,
|
||||||
|
room_member_count: u64,
|
||||||
|
sender_power_level: Option<i64>,
|
||||||
|
notification_power_levels: BTreeMap<String, i64>,
|
||||||
|
relations: BTreeMap<String, BTreeSet<(String, String)>>,
|
||||||
|
relation_match_enabled: bool,
|
||||||
|
) -> Result<Self, Error> {
|
||||||
|
let body = flattened_keys
|
||||||
|
.get("content.body")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
Ok(PushRuleEvaluator {
|
||||||
|
flattened_keys,
|
||||||
|
body,
|
||||||
|
room_member_count,
|
||||||
|
notification_power_levels,
|
||||||
|
relations,
|
||||||
|
relation_match_enabled,
|
||||||
|
sender_power_level,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the evaluator with the given push rules, for the given user ID and
|
||||||
|
/// display name of the user.
|
||||||
|
///
|
||||||
|
/// Passing in None will skip evaluating rules matching user ID and display
|
||||||
|
/// name.
|
||||||
|
///
|
||||||
|
/// Returns the set of actions, if any, that match (filtering out any
|
||||||
|
/// `dont_notify` actions).
|
||||||
|
pub fn run(
|
||||||
|
&self,
|
||||||
|
push_rules: &FilteredPushRules,
|
||||||
|
user_id: Option<&str>,
|
||||||
|
display_name: Option<&str>,
|
||||||
|
) -> Vec<Action> {
|
||||||
|
'outer: for (push_rule, enabled) in push_rules.iter() {
|
||||||
|
if !enabled {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for condition in push_rule.conditions.iter() {
|
||||||
|
match self.match_condition(condition, user_id, display_name) {
|
||||||
|
Ok(true) => {}
|
||||||
|
Ok(false) => continue 'outer,
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Condition match failed {err}");
|
||||||
|
continue 'outer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let actions = push_rule
|
||||||
|
.actions
|
||||||
|
.iter()
|
||||||
|
// Filter out "dont_notify" actions, as we don't store them.
|
||||||
|
.filter(|a| **a != Action::DontNotify)
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
return actions;
|
||||||
|
}
|
||||||
|
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the given condition matches.
|
||||||
|
fn matches(
|
||||||
|
&self,
|
||||||
|
condition: Condition,
|
||||||
|
user_id: Option<&str>,
|
||||||
|
display_name: Option<&str>,
|
||||||
|
) -> bool {
|
||||||
|
match self.match_condition(&condition, user_id, display_name) {
|
||||||
|
Ok(true) => true,
|
||||||
|
Ok(false) => false,
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Condition match failed {err}");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PushRuleEvaluator {
|
||||||
|
/// Match a given `Condition` for a push rule.
|
||||||
|
pub fn match_condition(
|
||||||
|
&self,
|
||||||
|
condition: &Condition,
|
||||||
|
user_id: Option<&str>,
|
||||||
|
display_name: Option<&str>,
|
||||||
|
) -> Result<bool, Error> {
|
||||||
|
let known_condition = match condition {
|
||||||
|
Condition::Known(known) => known,
|
||||||
|
Condition::Unknown(_) => {
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = match known_condition {
|
||||||
|
KnownCondition::EventMatch(event_match) => {
|
||||||
|
self.match_event_match(event_match, user_id)?
|
||||||
|
}
|
||||||
|
KnownCondition::ContainsDisplayName => {
|
||||||
|
if let Some(dn) = display_name {
|
||||||
|
if !dn.is_empty() {
|
||||||
|
get_glob_matcher(dn, GlobMatchType::Word)?.is_match(&self.body)?
|
||||||
|
} else {
|
||||||
|
// We specifically ignore empty display names, as otherwise
|
||||||
|
// they would always match.
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
KnownCondition::RoomMemberCount { is } => {
|
||||||
|
if let Some(is) = is {
|
||||||
|
self.match_member_count(is)?
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
KnownCondition::SenderNotificationPermission { key } => {
|
||||||
|
if let Some(sender_power_level) = &self.sender_power_level {
|
||||||
|
let required_level = self
|
||||||
|
.notification_power_levels
|
||||||
|
.get(key.as_ref())
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(50);
|
||||||
|
|
||||||
|
*sender_power_level >= required_level
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
KnownCondition::RelationMatch {
|
||||||
|
rel_type,
|
||||||
|
event_type_pattern,
|
||||||
|
sender,
|
||||||
|
sender_type,
|
||||||
|
} => {
|
||||||
|
self.match_relations(rel_type, sender, sender_type, user_id, event_type_pattern)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Evaluates a relation condition.
|
||||||
|
fn match_relations(
|
||||||
|
&self,
|
||||||
|
rel_type: &str,
|
||||||
|
sender: &Option<Cow<str>>,
|
||||||
|
sender_type: &Option<Cow<str>>,
|
||||||
|
user_id: Option<&str>,
|
||||||
|
event_type_pattern: &Option<Cow<str>>,
|
||||||
|
) -> Result<bool, Error> {
|
||||||
|
// First check if relation matching is enabled...
|
||||||
|
if !self.relation_match_enabled {
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ... and if there are any relations to match against.
|
||||||
|
let relations = if let Some(relations) = self.relations.get(rel_type) {
|
||||||
|
relations
|
||||||
|
} else {
|
||||||
|
return Ok(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract the sender pattern from the condition
|
||||||
|
let sender_pattern = if let Some(sender) = sender {
|
||||||
|
Some(sender.as_ref())
|
||||||
|
} else if let Some(sender_type) = sender_type {
|
||||||
|
if sender_type == "user_id" {
|
||||||
|
if let Some(user_id) = user_id {
|
||||||
|
Some(user_id)
|
||||||
|
} else {
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!("Unrecognized sender_type: {sender_type}");
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut sender_compiled_pattern = if let Some(pattern) = sender_pattern {
|
||||||
|
Some(get_glob_matcher(pattern, GlobMatchType::Whole)?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut type_compiled_pattern = if let Some(pattern) = event_type_pattern {
|
||||||
|
Some(get_glob_matcher(pattern, GlobMatchType::Whole)?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
for (relation_sender, event_type) in relations {
|
||||||
|
if let Some(pattern) = &mut sender_compiled_pattern {
|
||||||
|
if !pattern.is_match(relation_sender)? {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(pattern) = &mut type_compiled_pattern {
|
||||||
|
if !pattern.is_match(event_type)? {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Evaluates a `event_match` condition.
|
||||||
|
fn match_event_match(
|
||||||
|
&self,
|
||||||
|
event_match: &EventMatchCondition,
|
||||||
|
user_id: Option<&str>,
|
||||||
|
) -> Result<bool, Error> {
|
||||||
|
let pattern = if let Some(pattern) = &event_match.pattern {
|
||||||
|
pattern
|
||||||
|
} else if let Some(pattern_type) = &event_match.pattern_type {
|
||||||
|
// The `pattern_type` can either be "user_id" or "user_localpart",
|
||||||
|
// either way if we don't have a `user_id` then the condition can't
|
||||||
|
// match.
|
||||||
|
let user_id = if let Some(user_id) = user_id {
|
||||||
|
user_id
|
||||||
|
} else {
|
||||||
|
return Ok(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
match &**pattern_type {
|
||||||
|
"user_id" => user_id,
|
||||||
|
"user_localpart" => get_localpart_from_id(user_id)?,
|
||||||
|
_ => return Ok(false),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Ok(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
let haystack = if let Some(haystack) = self.flattened_keys.get(&*event_match.key) {
|
||||||
|
haystack
|
||||||
|
} else {
|
||||||
|
return Ok(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
// For the content.body we match against "words", but for everything
|
||||||
|
// else we match against the entire value.
|
||||||
|
let match_type = if event_match.key == "content.body" {
|
||||||
|
GlobMatchType::Word
|
||||||
|
} else {
|
||||||
|
GlobMatchType::Whole
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut compiled_pattern = get_glob_matcher(pattern, match_type)?;
|
||||||
|
compiled_pattern.is_match(haystack)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Match the member count against an 'is' condition
|
||||||
|
/// The `is` condition can be things like '>2', '==3' or even just '4'.
|
||||||
|
fn match_member_count(&self, is: &str) -> Result<bool, Error> {
|
||||||
|
let captures = INEQUALITY_EXPR.captures(is).context("bad 'is' clause")?;
|
||||||
|
let ineq = captures.get(1).map_or("==", |m| m.as_str());
|
||||||
|
let rhs: u64 = captures
|
||||||
|
.get(2)
|
||||||
|
.context("missing number")?
|
||||||
|
.as_str()
|
||||||
|
.parse()?;
|
||||||
|
|
||||||
|
let matches = match ineq {
|
||||||
|
"" | "==" => self.room_member_count == rhs,
|
||||||
|
"<" => self.room_member_count < rhs,
|
||||||
|
">" => self.room_member_count > rhs,
|
||||||
|
">=" => self.room_member_count >= rhs,
|
||||||
|
"<=" => self.room_member_count <= rhs,
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(matches)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn push_rule_evaluator() {
|
||||||
|
let mut flattened_keys = BTreeMap::new();
|
||||||
|
flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
|
||||||
|
let evaluator = PushRuleEvaluator::py_new(
|
||||||
|
flattened_keys,
|
||||||
|
10,
|
||||||
|
Some(0),
|
||||||
|
BTreeMap::new(),
|
||||||
|
BTreeMap::new(),
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
|
||||||
|
assert_eq!(result.len(), 3);
|
||||||
|
}
|
|
@ -42,7 +42,6 @@
|
||||||
//!
|
//!
|
||||||
//! The set of "base rules" are the list of rules that every user has by default. A
|
//! The set of "base rules" are the list of rules that every user has by default. A
|
||||||
//! user can modify their copy of the push rules in one of three ways:
|
//! user can modify their copy of the push rules in one of three ways:
|
||||||
//!
|
|
||||||
//! 1. Adding a new push rule of a certain kind
|
//! 1. Adding a new push rule of a certain kind
|
||||||
//! 2. Changing the actions of a base rule
|
//! 2. Changing the actions of a base rule
|
||||||
//! 3. Enabling/disabling a base rule.
|
//! 3. Enabling/disabling a base rule.
|
||||||
|
@ -58,12 +57,16 @@ use std::collections::{BTreeMap, HashMap, HashSet};
|
||||||
use anyhow::{Context, Error};
|
use anyhow::{Context, Error};
|
||||||
use log::warn;
|
use log::warn;
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pythonize::pythonize;
|
use pythonize::{depythonize, pythonize};
|
||||||
use serde::de::Error as _;
|
use serde::de::Error as _;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use self::evaluator::PushRuleEvaluator;
|
||||||
|
|
||||||
mod base_rules;
|
mod base_rules;
|
||||||
|
pub mod evaluator;
|
||||||
|
pub mod utils;
|
||||||
|
|
||||||
/// Called when registering modules with python.
|
/// Called when registering modules with python.
|
||||||
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
|
@ -71,6 +74,7 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
child_module.add_class::<PushRule>()?;
|
child_module.add_class::<PushRule>()?;
|
||||||
child_module.add_class::<PushRules>()?;
|
child_module.add_class::<PushRules>()?;
|
||||||
child_module.add_class::<FilteredPushRules>()?;
|
child_module.add_class::<FilteredPushRules>()?;
|
||||||
|
child_module.add_class::<PushRuleEvaluator>()?;
|
||||||
child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?;
|
child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?;
|
||||||
|
|
||||||
m.add_submodule(child_module)?;
|
m.add_submodule(child_module)?;
|
||||||
|
@ -274,6 +278,8 @@ pub enum KnownCondition {
|
||||||
#[serde(rename = "org.matrix.msc3772.relation_match")]
|
#[serde(rename = "org.matrix.msc3772.relation_match")]
|
||||||
RelationMatch {
|
RelationMatch {
|
||||||
rel_type: Cow<'static, str>,
|
rel_type: Cow<'static, str>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none", rename = "type")]
|
||||||
|
event_type_pattern: Option<Cow<'static, str>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
sender: Option<Cow<'static, str>>,
|
sender: Option<Cow<'static, str>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
@ -287,20 +293,26 @@ impl IntoPy<PyObject> for Condition {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'source> FromPyObject<'source> for Condition {
|
||||||
|
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||||
|
Ok(depythonize(ob)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// The body of a [`Condition::EventMatch`]
|
/// The body of a [`Condition::EventMatch`]
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct EventMatchCondition {
|
pub struct EventMatchCondition {
|
||||||
key: Cow<'static, str>,
|
pub key: Cow<'static, str>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pattern: Option<Cow<'static, str>>,
|
pub pattern: Option<Cow<'static, str>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pattern_type: Option<Cow<'static, str>>,
|
pub pattern_type: Option<Cow<'static, str>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The collection of push rules for a user.
|
/// The collection of push rules for a user.
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
#[pyclass(frozen)]
|
#[pyclass(frozen)]
|
||||||
struct PushRules {
|
pub struct PushRules {
|
||||||
/// Custom push rules that override a base rule.
|
/// Custom push rules that override a base rule.
|
||||||
overridden_base_rules: HashMap<Cow<'static, str>, PushRule>,
|
overridden_base_rules: HashMap<Cow<'static, str>, PushRule>,
|
||||||
|
|
||||||
|
@ -319,7 +331,7 @@ struct PushRules {
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl PushRules {
|
impl PushRules {
|
||||||
#[new]
|
#[new]
|
||||||
fn new(rules: Vec<PushRule>) -> PushRules {
|
pub fn new(rules: Vec<PushRule>) -> PushRules {
|
||||||
let mut push_rules: PushRules = Default::default();
|
let mut push_rules: PushRules = Default::default();
|
||||||
|
|
||||||
for rule in rules {
|
for rule in rules {
|
||||||
|
@ -396,7 +408,7 @@ pub struct FilteredPushRules {
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl FilteredPushRules {
|
impl FilteredPushRules {
|
||||||
#[new]
|
#[new]
|
||||||
fn py_new(
|
pub fn py_new(
|
||||||
push_rules: PushRules,
|
push_rules: PushRules,
|
||||||
enabled_map: BTreeMap<String, bool>,
|
enabled_map: BTreeMap<String, bool>,
|
||||||
msc3786_enabled: bool,
|
msc3786_enabled: bool,
|
||||||
|
|
|
@ -0,0 +1,215 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use anyhow::bail;
|
||||||
|
use anyhow::Context;
|
||||||
|
use anyhow::Error;
|
||||||
|
use lazy_static::lazy_static;
|
||||||
|
use regex;
|
||||||
|
use regex::Regex;
|
||||||
|
use regex::RegexBuilder;
|
||||||
|
|
||||||
|
lazy_static! {
|
||||||
|
/// Matches runs of non-wildcard characters followed by wildcard characters.
|
||||||
|
static ref WILDCARD_RUN: Regex = Regex::new(r"([^\?\*]*)([\?\*]*)").expect("valid regex");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the localpart from a Matrix style ID
|
||||||
|
pub(crate) fn get_localpart_from_id(id: &str) -> Result<&str, Error> {
|
||||||
|
let (localpart, _) = id
|
||||||
|
.split_once(':')
|
||||||
|
.with_context(|| format!("ID does not contain colon: {id}"))?;
|
||||||
|
|
||||||
|
// We need to strip off the first character, which is the ID type.
|
||||||
|
if localpart.is_empty() {
|
||||||
|
bail!("Invalid ID {id}");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(&localpart[1..])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Used by `glob_to_regex` to specify what to match the regex against.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum GlobMatchType {
|
||||||
|
/// The generated regex will match against the entire input.
|
||||||
|
Whole,
|
||||||
|
/// The generated regex will match against words.
|
||||||
|
Word,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert a "glob" style expression to a regex, anchoring either to the entire
|
||||||
|
/// input or to individual words.
|
||||||
|
pub fn glob_to_regex(glob: &str, match_type: GlobMatchType) -> Result<Regex, Error> {
|
||||||
|
let mut chunks = Vec::new();
|
||||||
|
|
||||||
|
// Patterns with wildcards must be simplified to avoid performance cliffs
|
||||||
|
// - The glob `?**?**?` is equivalent to the glob `???*`
|
||||||
|
// - The glob `???*` is equivalent to the regex `.{3,}`
|
||||||
|
for captures in WILDCARD_RUN.captures_iter(glob) {
|
||||||
|
if let Some(chunk) = captures.get(1) {
|
||||||
|
chunks.push(regex::escape(chunk.as_str()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(wildcards) = captures.get(2) {
|
||||||
|
if wildcards.as_str() == "" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let question_marks = wildcards.as_str().chars().filter(|c| *c == '?').count();
|
||||||
|
|
||||||
|
if wildcards.as_str().contains('*') {
|
||||||
|
chunks.push(format!(".{{{question_marks},}}"));
|
||||||
|
} else {
|
||||||
|
chunks.push(format!(".{{{question_marks}}}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let joined = chunks.join("");
|
||||||
|
|
||||||
|
let regex_str = match match_type {
|
||||||
|
GlobMatchType::Whole => format!(r"\A{joined}\z"),
|
||||||
|
|
||||||
|
// `^|\W` and `\W|$` handle the case where `pattern` starts or ends with a non-word
|
||||||
|
// character.
|
||||||
|
GlobMatchType::Word => format!(r"(?:^|\b|\W){joined}(?:\b|\W|$)"),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(RegexBuilder::new(®ex_str)
|
||||||
|
.case_insensitive(true)
|
||||||
|
.build()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compiles the glob into a `Matcher`.
|
||||||
|
pub fn get_glob_matcher(glob: &str, match_type: GlobMatchType) -> Result<Matcher, Error> {
|
||||||
|
// There are a number of shortcuts we can make if the glob doesn't contain a
|
||||||
|
// wild card.
|
||||||
|
let matcher = if glob.contains(['*', '?']) {
|
||||||
|
let regex = glob_to_regex(glob, match_type)?;
|
||||||
|
Matcher::Regex(regex)
|
||||||
|
} else if match_type == GlobMatchType::Whole {
|
||||||
|
// If there aren't any wildcards and we're matching the whole thing,
|
||||||
|
// then we simply can do a case-insensitive string match.
|
||||||
|
Matcher::Whole(glob.to_lowercase())
|
||||||
|
} else {
|
||||||
|
// Otherwise, if we're matching against words then we can first check
|
||||||
|
// if the haystack contains the glob at all.
|
||||||
|
Matcher::Word {
|
||||||
|
word: glob.to_lowercase(),
|
||||||
|
regex: None,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(matcher)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Matches against a glob
|
||||||
|
pub enum Matcher {
|
||||||
|
/// Plain regex matching.
|
||||||
|
Regex(Regex),
|
||||||
|
|
||||||
|
/// Case-insensitive equality.
|
||||||
|
Whole(String),
|
||||||
|
|
||||||
|
/// Word matching. `regex` is a cache of calling [`glob_to_regex`] on word.
|
||||||
|
Word { word: String, regex: Option<Regex> },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Matcher {
|
||||||
|
/// Checks if the glob matches the given haystack.
|
||||||
|
pub fn is_match(&mut self, haystack: &str) -> Result<bool, Error> {
|
||||||
|
// We want to do case-insensitive matching, so we convert to
|
||||||
|
// lowercase first.
|
||||||
|
let haystack = haystack.to_lowercase();
|
||||||
|
|
||||||
|
match self {
|
||||||
|
Matcher::Regex(regex) => Ok(regex.is_match(&haystack)),
|
||||||
|
Matcher::Whole(whole) => Ok(whole == &haystack),
|
||||||
|
Matcher::Word { word, regex } => {
|
||||||
|
// If we're looking for a literal word, then we first check if
|
||||||
|
// the haystack contains the word as a substring.
|
||||||
|
if !haystack.contains(&*word) {
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it does contain the word as a substring, then we need to
|
||||||
|
// check if it is an actual word by testing it against the regex.
|
||||||
|
let regex = if let Some(regex) = regex {
|
||||||
|
regex
|
||||||
|
} else {
|
||||||
|
let compiled_regex = glob_to_regex(word, GlobMatchType::Word)?;
|
||||||
|
regex.insert(compiled_regex)
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(regex.is_match(&haystack))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_localpart_from_id() {
|
||||||
|
get_localpart_from_id("").unwrap_err();
|
||||||
|
get_localpart_from_id(":").unwrap_err();
|
||||||
|
get_localpart_from_id(":asd").unwrap_err();
|
||||||
|
get_localpart_from_id("::as::asad").unwrap_err();
|
||||||
|
|
||||||
|
assert_eq!(get_localpart_from_id("@test:foo").unwrap(), "test");
|
||||||
|
assert_eq!(get_localpart_from_id("@:").unwrap(), "");
|
||||||
|
assert_eq!(get_localpart_from_id("@test:foo:907").unwrap(), "test");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_glob() -> Result<(), Error> {
|
||||||
|
assert_eq!(
|
||||||
|
glob_to_regex("simple", GlobMatchType::Whole)?.as_str(),
|
||||||
|
r"\Asimple\z"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
glob_to_regex("simple*", GlobMatchType::Whole)?.as_str(),
|
||||||
|
r"\Asimple.{0,}\z"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
glob_to_regex("simple?", GlobMatchType::Whole)?.as_str(),
|
||||||
|
r"\Asimple.{1}\z"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
glob_to_regex("simple?*?*", GlobMatchType::Whole)?.as_str(),
|
||||||
|
r"\Asimple.{2,}\z"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
glob_to_regex("simple???", GlobMatchType::Whole)?.as_str(),
|
||||||
|
r"\Asimple.{3}\z"
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
glob_to_regex("escape.", GlobMatchType::Whole)?.as_str(),
|
||||||
|
r"\Aescape\.\z"
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simple"));
|
||||||
|
assert!(!glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simples"));
|
||||||
|
assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simples"));
|
||||||
|
assert!(glob_to_regex("simple?", GlobMatchType::Whole)?.is_match("simples"));
|
||||||
|
assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simple"));
|
||||||
|
|
||||||
|
assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("some simple."));
|
||||||
|
assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("simple"));
|
||||||
|
assert!(!glob_to_regex("simple", GlobMatchType::Word)?.is_match("simples"));
|
||||||
|
|
||||||
|
assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("Some @user:foo test"));
|
||||||
|
assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("@user:foo"));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Collection, Dict, Mapping, Sequence, Tuple, Union
|
from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union
|
||||||
|
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
@ -35,3 +35,20 @@ class FilteredPushRules:
|
||||||
def rules(self) -> Collection[Tuple[PushRule, bool]]: ...
|
def rules(self) -> Collection[Tuple[PushRule, bool]]: ...
|
||||||
|
|
||||||
def get_base_rule_ids() -> Collection[str]: ...
|
def get_base_rule_ids() -> Collection[str]: ...
|
||||||
|
|
||||||
|
class PushRuleEvaluator:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
flattened_keys: Mapping[str, str],
|
||||||
|
room_member_count: int,
|
||||||
|
sender_power_level: Optional[int],
|
||||||
|
notification_power_levels: Mapping[str, int],
|
||||||
|
relations: Mapping[str, Set[Tuple[str, str]]],
|
||||||
|
relation_match_enabled: bool,
|
||||||
|
): ...
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
push_rules: FilteredPushRules,
|
||||||
|
user_id: Optional[str],
|
||||||
|
display_name: Optional[str],
|
||||||
|
) -> Collection[dict]: ...
|
||||||
|
|
|
@ -107,7 +107,7 @@ BOOLEAN_COLUMNS = {
|
||||||
"redactions": ["have_censored"],
|
"redactions": ["have_censored"],
|
||||||
"room_stats_state": ["is_federatable"],
|
"room_stats_state": ["is_federatable"],
|
||||||
"local_media_repository": ["safe_from_quarantine"],
|
"local_media_repository": ["safe_from_quarantine"],
|
||||||
"users": ["shadow_banned"],
|
"users": ["shadow_banned", "approved"],
|
||||||
"e2e_fallback_keys_json": ["used"],
|
"e2e_fallback_keys_json": ["used"],
|
||||||
"access_tokens": ["used"],
|
"access_tokens": ["used"],
|
||||||
"device_lists_changes_in_room": ["converted_to_destinations"],
|
"device_lists_changes_in_room": ["converted_to_destinations"],
|
||||||
|
|
|
@ -269,3 +269,14 @@ class PublicRoomsFilterFields:
|
||||||
|
|
||||||
GENERIC_SEARCH_TERM: Final = "generic_search_term"
|
GENERIC_SEARCH_TERM: Final = "generic_search_term"
|
||||||
ROOM_TYPES: Final = "room_types"
|
ROOM_TYPES: Final = "room_types"
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalNoticeMedium:
|
||||||
|
"""Identifier for the medium this server will use to serve notice of approval for a
|
||||||
|
specific user's registration.
|
||||||
|
|
||||||
|
As defined in https://github.com/matrix-org/matrix-spec-proposals/blob/babolivier/m_not_approved/proposals/3866-user-not-approved-error.md
|
||||||
|
"""
|
||||||
|
|
||||||
|
NONE = "org.matrix.msc3866.none"
|
||||||
|
EMAIL = "org.matrix.msc3866.email"
|
||||||
|
|
|
@ -106,6 +106,8 @@ class Codes(str, Enum):
|
||||||
# Part of MSC3895.
|
# Part of MSC3895.
|
||||||
UNABLE_DUE_TO_PARTIAL_STATE = "ORG.MATRIX.MSC3895_UNABLE_DUE_TO_PARTIAL_STATE"
|
UNABLE_DUE_TO_PARTIAL_STATE = "ORG.MATRIX.MSC3895_UNABLE_DUE_TO_PARTIAL_STATE"
|
||||||
|
|
||||||
|
USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL"
|
||||||
|
|
||||||
|
|
||||||
class CodeMessageException(RuntimeError):
|
class CodeMessageException(RuntimeError):
|
||||||
"""An exception with integer code and message string attributes.
|
"""An exception with integer code and message string attributes.
|
||||||
|
@ -566,6 +568,20 @@ class UnredactedContentDeletedError(SynapseError):
|
||||||
return cs_error(self.msg, self.errcode, **extra)
|
return cs_error(self.msg, self.errcode, **extra)
|
||||||
|
|
||||||
|
|
||||||
|
class NotApprovedError(SynapseError):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
msg: str,
|
||||||
|
approval_notice_medium: str,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
code=403,
|
||||||
|
msg=msg,
|
||||||
|
errcode=Codes.USER_AWAITING_APPROVAL,
|
||||||
|
additional_fields={"approval_notice_medium": approval_notice_medium},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict":
|
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict":
|
||||||
"""Utility method for constructing an error response for client-server
|
"""Utility method for constructing an error response for client-server
|
||||||
interactions.
|
interactions.
|
||||||
|
|
|
@ -14,10 +14,25 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
from synapse.config._base import Config
|
from synapse.config._base import Config
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||||
|
class MSC3866Config:
|
||||||
|
"""Configuration for MSC3866 (mandating approval for new users)"""
|
||||||
|
|
||||||
|
# Whether the base support for the approval process is enabled. This includes the
|
||||||
|
# ability for administrators to check and update the approval of users, even if no
|
||||||
|
# approval is currently required.
|
||||||
|
enabled: bool = False
|
||||||
|
# Whether to require that new users are approved by an admin before their account
|
||||||
|
# can be used. Note that this setting is ignored if 'enabled' is false.
|
||||||
|
require_approval_for_new_accounts: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ExperimentalConfig(Config):
|
class ExperimentalConfig(Config):
|
||||||
"""Config section for enabling experimental features"""
|
"""Config section for enabling experimental features"""
|
||||||
|
|
||||||
|
@ -97,6 +112,10 @@ class ExperimentalConfig(Config):
|
||||||
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
|
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
|
||||||
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
|
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
|
||||||
|
|
||||||
|
# MSC3866: M_USER_AWAITING_APPROVAL error code
|
||||||
|
raw_msc3866_config = experimental.get("msc3866", {})
|
||||||
|
self.msc3866 = MSC3866Config(**raw_msc3866_config)
|
||||||
|
|
||||||
# MSC3881: Remotely toggle push notifications for another client
|
# MSC3881: Remotely toggle push notifications for another client
|
||||||
self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)
|
self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)
|
||||||
|
|
||||||
|
|
|
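A hedged sketch of the corresponding homeserver configuration, assuming the experimental section is keyed `experimental_features` (that name is not shown in this diff); the field names come from `MSC3866Config` above:

```yaml
experimental_features:
  msc3866:
    enabled: true
    require_approval_for_new_accounts: true
```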
@ -289,6 +289,10 @@ class _EventInternalMetadata:
|
||||||
"""
|
"""
|
||||||
return self._dict.get("historical", False)
|
return self._dict.get("historical", False)
|
||||||
|
|
||||||
|
def is_notifiable(self) -> bool:
|
||||||
|
"""Whether this event can trigger a push notification"""
|
||||||
|
return not self.is_outlier() or self.is_out_of_band_membership()
|
||||||
|
|
||||||
|
|
||||||
class EventBase(metaclass=abc.ABCMeta):
|
class EventBase(metaclass=abc.ABCMeta):
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -32,6 +32,7 @@ class AdminHandler:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self._storage_controllers = hs.get_storage_controllers()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self._state_storage_controller = self._storage_controllers.state
|
self._state_storage_controller = self._storage_controllers.state
|
||||||
|
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
|
||||||
|
|
||||||
async def get_whois(self, user: UserID) -> JsonDict:
|
async def get_whois(self, user: UserID) -> JsonDict:
|
||||||
connections = []
|
connections = []
|
||||||
|
@ -75,6 +76,10 @@ class AdminHandler:
|
||||||
"is_guest",
|
"is_guest",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self._msc3866_enabled:
|
||||||
|
# Only include the approved flag if support for MSC3866 is enabled.
|
||||||
|
user_info_to_return.add("approved")
|
||||||
|
|
||||||
# Restrict returned keys to a known set.
|
# Restrict returned keys to a known set.
|
||||||
user_info_dict = {
|
user_info_dict = {
|
||||||
key: value
|
key: value
|
||||||
|
|
|
@ -1009,6 +1009,17 @@ class AuthHandler:
|
||||||
                return res[0]
        return None

    async def is_user_approved(self, user_id: str) -> bool:
        """Checks if a user is approved and therefore can be allowed to log in.

        Args:
            user_id: the user to check the approval status of.

        Returns:
            A boolean that is True if the user is approved, False otherwise.
        """
        return await self.store.is_user_approved(user_id)

    async def _find_user_id_and_pwd_hash(
        self, user_id: str
    ) -> Optional[Tuple[str, str]]:
|
@ -273,11 +273,9 @@ class DeviceWorkerHandler:
|
||||||
possibly_left = possibly_changed | possibly_left
|
possibly_left = possibly_changed | possibly_left
|
||||||
|
|
||||||
# Double check if we still share rooms with the given user.
|
# Double check if we still share rooms with the given user.
|
||||||
users_rooms = await self.store.get_rooms_for_users_with_stream_ordering(
|
users_rooms = await self.store.get_rooms_for_users(possibly_left)
|
||||||
possibly_left
|
|
||||||
)
|
|
||||||
for changed_user_id, entries in users_rooms.items():
|
for changed_user_id, entries in users_rooms.items():
|
||||||
if any(e.room_id in room_ids for e in entries):
|
if any(rid in room_ids for rid in entries):
|
||||||
possibly_left.discard(changed_user_id)
|
possibly_left.discard(changed_user_id)
|
||||||
else:
|
else:
|
||||||
possibly_joined.discard(changed_user_id)
|
possibly_joined.discard(changed_user_id)
|
||||||
|
@ -309,6 +307,17 @@ class DeviceWorkerHandler:
|
||||||
"self_signing_key": self_signing_key,
|
"self_signing_key": self_signing_key,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def handle_room_un_partial_stated(self, room_id: str) -> None:
|
||||||
|
"""Handles sending appropriate device list updates in a room that has
|
||||||
|
gone from partial to full state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO(faster_joins): worker mode support
|
||||||
|
# https://github.com/matrix-org/synapse/issues/12994
|
||||||
|
logger.error(
|
||||||
|
"Trying handling device list state for partial join: not supported on workers."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DeviceHandler(DeviceWorkerHandler):
|
class DeviceHandler(DeviceWorkerHandler):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
@ -746,6 +755,95 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
finally:
|
finally:
|
||||||
self._handle_new_device_update_is_processing = False
|
self._handle_new_device_update_is_processing = False
|
||||||
|
|
||||||
|
async def handle_room_un_partial_stated(self, room_id: str) -> None:
|
||||||
|
"""Handles sending appropriate device list updates in a room that has
|
||||||
|
gone from partial to full state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We defer to the device list updater to handle pending remote device
|
||||||
|
# list updates.
|
||||||
|
await self.device_list_updater.handle_room_un_partial_stated(room_id)
|
||||||
|
|
||||||
|
# Replay local updates.
|
||||||
|
(
|
||||||
|
join_event_id,
|
||||||
|
device_lists_stream_id,
|
||||||
|
) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the local device list changes that have happened in the room since
|
||||||
|
# we started joining. If there are no updates there's nothing left to do.
|
||||||
|
changes = await self.store.get_device_list_changes_in_room(
|
||||||
|
room_id, device_lists_stream_id
|
||||||
|
)
|
||||||
|
local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
|
||||||
|
if not local_changes:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Note: We have persisted the full state at this point, we just haven't
|
||||||
|
# cleared the `partial_room` flag.
|
||||||
|
join_state_ids = await self._state_storage.get_state_ids_for_event(
|
||||||
|
join_event_id, await_full_state=False
|
||||||
|
)
|
||||||
|
current_state_ids = await self.store.get_partial_current_state_ids(room_id)
|
||||||
|
|
||||||
|
# Now we need to work out all servers that might have been in the room
|
||||||
|
# at any point during our join.
|
||||||
|
|
||||||
|
# First we look for any membership states that have changed between the
|
||||||
|
# initial join and now...
|
||||||
|
all_keys = set(join_state_ids)
|
||||||
|
all_keys.update(current_state_ids)
|
||||||
|
|
||||||
|
potentially_changed_hosts = set()
|
||||||
|
for etype, state_key in all_keys:
|
||||||
|
if etype != EventTypes.Member:
|
||||||
|
continue
|
||||||
|
|
||||||
|
prev = join_state_ids.get((etype, state_key))
|
||||||
|
current = current_state_ids.get((etype, state_key))
|
||||||
|
|
||||||
|
if prev != current:
|
||||||
|
potentially_changed_hosts.add(get_domain_from_id(state_key))
|
||||||
|
|
||||||
|
# ... then we add all the hosts that are currently joined to the room...
|
||||||
|
current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
|
||||||
|
potentially_changed_hosts.update(current_hosts_in_room)
|
||||||
|
|
||||||
|
# ... and finally we remove any hosts that we were told about, as we
|
||||||
|
# will have sent device list updates to those hosts when they happened.
|
||||||
|
known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
potentially_changed_hosts.difference_update(known_hosts_at_join)
|
||||||
|
|
||||||
|
potentially_changed_hosts.discard(self.server_name)
|
||||||
|
|
||||||
|
if not potentially_changed_hosts:
|
||||||
|
# Nothing to do.
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Found %d changed hosts to send device list updates to",
|
||||||
|
len(potentially_changed_hosts),
|
||||||
|
)
|
||||||
|
|
||||||
|
for user_id, device_id in local_changes:
|
||||||
|
await self.store.add_device_list_outbound_pokes(
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
room_id=room_id,
|
||||||
|
stream_id=None,
|
||||||
|
hosts=potentially_changed_hosts,
|
||||||
|
context=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Notify things that device lists need to be sent out.
|
||||||
|
self.notifier.notify_replication()
|
||||||
|
for host in potentially_changed_hosts:
|
||||||
|
self.federation_sender.send_device_messages(host, immediate=False)
|
||||||
|
|
||||||
|
|
||||||
def _update_device_from_client_ips(
|
def _update_device_from_client_ips(
|
||||||
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
|
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
|
||||||
|
@ -836,6 +934,16 @@ class DeviceListUpdater:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Check if we are partially joining any rooms. If so we need to store
|
||||||
|
# all device list updates so that we can handle them correctly once we
|
||||||
|
# know who is in the room.
|
||||||
|
partial_rooms = await self.store.get_partial_state_rooms_and_servers()
|
||||||
|
if partial_rooms:
|
||||||
|
await self.store.add_remote_device_list_to_pending(
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
)
|
||||||
|
|
||||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||||
if not room_ids:
|
if not room_ids:
|
||||||
# We don't share any rooms with this user. Ignore update, as we
|
# We don't share any rooms with this user. Ignore update, as we
|
||||||
|
@ -1175,3 +1283,35 @@ class DeviceListUpdater:
|
||||||
device_ids.append(verify_key.version)
|
device_ids.append(verify_key.version)
|
||||||
|
|
||||||
return device_ids
|
return device_ids
|
||||||
|
|
||||||
|
async def handle_room_un_partial_stated(self, room_id: str) -> None:
|
||||||
|
"""Handles sending appropriate device list updates in a room that has
|
||||||
|
gone from partial to full state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pending_updates = (
|
||||||
|
await self.store.get_pending_remote_device_list_updates_for_room(room_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
for user_id, device_id in pending_updates:
|
||||||
|
logger.info(
|
||||||
|
"Got pending device list update in room %s: %s / %s",
|
||||||
|
room_id,
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
)
|
||||||
|
position = await self.store.add_device_change_to_streams(
|
||||||
|
user_id,
|
||||||
|
[device_id],
|
||||||
|
room_ids=[room_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not position:
|
||||||
|
# This should only happen if there are no updates, which
|
||||||
|
# shouldn't happen when we've passed in a non-empty set of
|
||||||
|
# device IDs.
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.device_handler.notifier.on_new_event(
|
||||||
|
StreamKeyType.DEVICE_LIST, position, rooms=[room_id]
|
||||||
|
)
|
||||||
|
|
|
@ -38,7 +38,7 @@ from signedjson.sign import verify_signed_json
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
|
|
||||||
from synapse import event_auth
|
from synapse import event_auth
|
||||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
CodeMessageException,
|
CodeMessageException,
|
||||||
|
@ -149,6 +149,8 @@ class FederationHandler:
|
||||||
self.http_client = hs.get_proxied_blacklisted_http_client()
|
self.http_client = hs.get_proxied_blacklisted_http_client()
|
||||||
self._replication = hs.get_replication_data_handler()
|
self._replication = hs.get_replication_data_handler()
|
||||||
self._federation_event_handler = hs.get_federation_event_handler()
|
self._federation_event_handler = hs.get_federation_event_handler()
|
||||||
|
self._device_handler = hs.get_device_handler()
|
||||||
|
self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
|
||||||
|
|
||||||
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
|
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
|
||||||
hs
|
hs
|
||||||
|
@ -209,7 +211,7 @@ class FederationHandler:
|
||||||
current_depth: int,
|
current_depth: int,
|
||||||
limit: int,
|
limit: int,
|
||||||
*,
|
*,
|
||||||
processing_start_time: int,
|
processing_start_time: Optional[int],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Checks whether the `current_depth` is at or approaching any backfill
|
Checks whether the `current_depth` is at or approaching any backfill
|
||||||
|
@ -221,12 +223,23 @@ class FederationHandler:
|
||||||
room_id: The room to backfill in.
|
room_id: The room to backfill in.
|
||||||
current_depth: The depth to check at for any upcoming backfill points.
|
current_depth: The depth to check at for any upcoming backfill points.
|
||||||
limit: The max number of events to request from the remote federated server.
|
limit: The max number of events to request from the remote federated server.
|
||||||
processing_start_time: The time when `maybe_backfill` started
|
processing_start_time: The time when `maybe_backfill` started processing.
|
||||||
processing. Only used for timing.
|
Only used for timing. If `None`, no timing observation will be made.
|
||||||
"""
|
"""
|
||||||
backwards_extremities = [
|
backwards_extremities = [
|
||||||
_BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
|
_BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
|
||||||
for event_id, depth in await self.store.get_backfill_points_in_room(room_id)
|
for event_id, depth in await self.store.get_backfill_points_in_room(
|
||||||
|
room_id=room_id,
|
||||||
|
current_depth=current_depth,
|
||||||
|
# We only need to end up with 5 extremities combined with the
|
||||||
|
# insertion event extremities to make the `/backfill` request
|
||||||
|
# but fetch an order of magnitude more to make sure there is
|
||||||
|
# enough even after we filter them by whether visible in the
|
||||||
|
# history. This isn't fool-proof as all backfill points within
|
||||||
|
# our limit could be filtered out but seems like a good amount
|
||||||
|
# to try with at least.
|
||||||
|
limit=50,
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
insertion_events_to_be_backfilled: List[_BackfillPoint] = []
|
insertion_events_to_be_backfilled: List[_BackfillPoint] = []
|
||||||
|
@ -234,7 +247,12 @@ class FederationHandler:
|
||||||
insertion_events_to_be_backfilled = [
|
insertion_events_to_be_backfilled = [
|
||||||
_BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT)
|
_BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT)
|
||||||
for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room(
|
for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room(
|
||||||
room_id
|
room_id=room_id,
|
||||||
|
current_depth=current_depth,
|
||||||
|
# We only need to end up with 5 extremities combined with
|
||||||
|
# the backfill points to make the `/backfill` request ...
|
||||||
|
# (see the other comment above for more context).
|
||||||
|
limit=50,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -243,10 +261,6 @@ class FederationHandler:
|
||||||
insertion_events_to_be_backfilled,
|
insertion_events_to_be_backfilled,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not backwards_extremities and not insertion_events_to_be_backfilled:
|
|
||||||
logger.debug("Not backfilling as no extremeties found.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# we now have a list of potential places to backpaginate from. We prefer to
|
# we now have a list of potential places to backpaginate from. We prefer to
|
||||||
# start with the most recent (ie, max depth), so let's sort the list.
|
# start with the most recent (ie, max depth), so let's sort the list.
|
||||||
sorted_backfill_points: List[_BackfillPoint] = sorted(
|
sorted_backfill_points: List[_BackfillPoint] = sorted(
|
||||||
|
@ -267,6 +281,33 @@ class FederationHandler:
|
||||||
sorted_backfill_points,
|
sorted_backfill_points,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If we have no backfill points lower than the `current_depth` then
|
||||||
|
# either we can a) bail or b) still attempt to backfill. We opt to try
|
||||||
|
# backfilling anyway just in case we do get relevant events.
|
||||||
|
if not sorted_backfill_points and current_depth != MAX_DEPTH:
|
||||||
|
logger.debug(
|
||||||
|
"_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points."
|
||||||
|
)
|
||||||
|
return await self._maybe_backfill_inner(
|
||||||
|
room_id=room_id,
|
||||||
|
# We use `MAX_DEPTH` so that we find all backfill points next
|
||||||
|
# time (all events are below the `MAX_DEPTH`)
|
||||||
|
current_depth=MAX_DEPTH,
|
||||||
|
limit=limit,
|
||||||
|
# We don't want to start another timing observation from this
|
||||||
|
# nested recursive call. The top-most call can record the time
|
||||||
|
# overall otherwise the smaller one will throw off the results.
|
||||||
|
processing_start_time=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Even after recursing with `MAX_DEPTH`, we didn't find any
|
||||||
|
# backward extremities to backfill from.
|
||||||
|
if not sorted_backfill_points:
|
||||||
|
logger.debug(
|
||||||
|
"_maybe_backfill_inner: Not backfilling as no backward extremeties found."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
# If we're approaching an extremity we trigger a backfill, otherwise we
|
# If we're approaching an extremity we trigger a backfill, otherwise we
|
||||||
# no-op.
|
# no-op.
|
||||||
#
|
#
|
||||||
|
@ -276,47 +317,16 @@ class FederationHandler:
|
||||||
# chose more than one times the limit in case of failure, but choosing a
|
# chose more than one times the limit in case of failure, but choosing a
|
||||||
# much larger factor will result in triggering a backfill request much
|
# much larger factor will result in triggering a backfill request much
|
||||||
# earlier than necessary.
|
# earlier than necessary.
|
||||||
#
|
max_depth_of_backfill_points = sorted_backfill_points[0].depth
|
||||||
# XXX: shouldn't we do this *after* the filter by depth below? Again, we don't
|
if current_depth - 2 * limit > max_depth_of_backfill_points:
|
||||||
# care about events that have happened after our current position.
|
|
||||||
#
|
|
||||||
max_depth = sorted_backfill_points[0].depth
|
|
||||||
if current_depth - 2 * limit > max_depth:
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Not backfilling as we don't need to. %d < %d - 2 * %d",
|
"Not backfilling as we don't need to. %d < %d - 2 * %d",
|
||||||
max_depth,
|
max_depth_of_backfill_points,
|
||||||
current_depth,
|
current_depth,
|
||||||
limit,
|
limit,
|
||||||
)
|
)
|
||||||
return False
|
return False
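A quick worked example of the heuristic above, with illustrative numbers:

```python
# Illustrative numbers only: the closest (deepest) backfill point is more than
# two pages (2 * limit) behind our current position, so we skip backfilling.
current_depth = 500
limit = 100
max_depth_of_backfill_points = 250
assert current_depth - 2 * limit > max_depth_of_backfill_points  # 300 > 250
```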
|
||||||
|
|
||||||
# We ignore extremities that have a greater depth than our current depth
|
|
||||||
# as:
|
|
||||||
# 1. we don't really care about getting events that have happened
|
|
||||||
# after our current position; and
|
|
||||||
# 2. we have likely previously tried and failed to backfill from that
|
|
||||||
# extremity, so to avoid getting "stuck" requesting the same
|
|
||||||
# backfill repeatedly we drop those extremities.
|
|
||||||
#
|
|
||||||
# However, we need to check that the filtered extremities are non-empty.
|
|
||||||
# If they are empty then either we can a) bail or b) still attempt to
|
|
||||||
# backfill. We opt to try backfilling anyway just in case we do get
|
|
||||||
# relevant events.
|
|
||||||
#
|
|
||||||
filtered_sorted_backfill_points = [
|
|
||||||
t for t in sorted_backfill_points if t.depth <= current_depth
|
|
||||||
]
|
|
||||||
if filtered_sorted_backfill_points:
|
|
||||||
logger.debug(
|
|
||||||
"_maybe_backfill_inner: backfill points before current depth: %s",
|
|
||||||
filtered_sorted_backfill_points,
|
|
||||||
)
|
|
||||||
sorted_backfill_points = filtered_sorted_backfill_points
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"_maybe_backfill_inner: all backfill points are *after* current depth. Backfilling anyway."
|
|
||||||
)
|
|
||||||
|
|
||||||
# For performance's sake, we only want to paginate from a particular extremity
|
# For performance's sake, we only want to paginate from a particular extremity
|
||||||
# if we can actually see the events we'll get. Otherwise, we'd just spend a lot
|
# if we can actually see the events we'll get. Otherwise, we'd just spend a lot
|
||||||
# of resources to get redacted events. We check each extremity in turn and
|
# of resources to get redacted events. We check each extremity in turn and
|
||||||
|
@ -402,11 +412,22 @@ class FederationHandler:
|
||||||
# First we try hosts that are already in the room.
|
# First we try hosts that are already in the room.
|
||||||
# TODO: HEURISTIC ALERT.
|
# TODO: HEURISTIC ALERT.
|
||||||
likely_domains = (
|
likely_domains = (
|
||||||
await self._storage_controllers.state.get_current_hosts_in_room(room_id)
|
await self._storage_controllers.state.get_current_hosts_in_room_ordered(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def try_backfill(domains: Collection[str]) -> bool:
|
async def try_backfill(domains: Collection[str]) -> bool:
|
||||||
# TODO: Should we try multiple of these at a time?
|
# TODO: Should we try multiple of these at a time?
|
||||||
|
|
||||||
|
# Number of contacted remote homeservers that have denied our backfill
|
||||||
|
# request with a 4xx code.
|
||||||
|
denied_count = 0
|
||||||
|
|
||||||
|
# Maximum number of contacted remote homeservers that can deny our
|
||||||
|
# backfill request with 4xx codes before we give up.
|
||||||
|
max_denied_count = 5
|
||||||
|
|
||||||
for dom in domains:
|
for dom in domains:
|
||||||
# We don't want to ask our own server for information we don't have
|
# We don't want to ask our own server for information we don't have
|
||||||
if dom == self.server_name:
|
if dom == self.server_name:
|
||||||
|
@ -425,13 +446,33 @@ class FederationHandler:
|
||||||
continue
|
continue
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
if 400 <= e.code < 500:
|
if 400 <= e.code < 500:
|
||||||
raise e.to_synapse_error()
|
logger.warning(
|
||||||
|
"Backfill denied from %s because %s [%d/%d]",
|
||||||
|
dom,
|
||||||
|
e,
|
||||||
|
denied_count,
|
||||||
|
max_denied_count,
|
||||||
|
)
|
||||||
|
denied_count += 1
|
||||||
|
if denied_count >= max_denied_count:
|
||||||
|
return False
|
||||||
|
continue
|
||||||
|
|
||||||
logger.info("Failed to backfill from %s because %s", dom, e)
|
logger.info("Failed to backfill from %s because %s", dom, e)
|
||||||
continue
|
continue
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
if 400 <= e.code < 500:
|
if 400 <= e.code < 500:
|
||||||
raise
|
logger.warning(
|
||||||
|
"Backfill denied from %s because %s [%d/%d]",
|
||||||
|
dom,
|
||||||
|
e,
|
||||||
|
denied_count,
|
||||||
|
max_denied_count,
|
||||||
|
)
|
||||||
|
denied_count += 1
|
||||||
|
if denied_count >= max_denied_count:
|
||||||
|
return False
|
||||||
|
continue
|
||||||
|
|
||||||
logger.info("Failed to backfill from %s because %s", dom, e)
|
logger.info("Failed to backfill from %s because %s", dom, e)
|
||||||
continue
|
continue
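A condensed sketch of the retry policy in this hunk: keep trying other servers when one denies the backfill request with a 4xx response, but give up once five of them have refused. The `try_backfill_from` helper below is a hypothetical stand-in for the real federation request, used only to keep the sketch short:

```python
from synapse.api.errors import HttpResponseException

async def try_backfill(domains, try_backfill_from) -> bool:
    """Hedged condensation of the loop above, not the exact Synapse control flow."""
    denied_count = 0
    max_denied_count = 5
    for dom in domains:
        try:
            await try_backfill_from(dom)
            return True
        except HttpResponseException as e:
            if 400 <= e.code < 500:
                # The remote refused our request; move on to the next server,
                # but stop entirely once max_denied_count servers have denied us.
                denied_count += 1
                if denied_count >= max_denied_count:
                    return False
            continue
    return False
```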
|
||||||
|
@ -450,6 +491,11 @@ class FederationHandler:
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# If we have the `processing_start_time`, then we can make an
|
||||||
|
# observation. We wouldn't have the `processing_start_time` in the case
|
||||||
|
# where `_maybe_backfill_inner` is recursively called to find any
|
||||||
|
# backfill points regardless of `current_depth`.
|
||||||
|
if processing_start_time is not None:
|
||||||
processing_end_time = self.clock.time_msec()
|
processing_end_time = self.clock.time_msec()
|
||||||
backfill_processing_before_timer.observe(
|
backfill_processing_before_timer.observe(
|
||||||
(processing_end_time - processing_start_time) / 1000
|
(processing_end_time - processing_start_time) / 1000
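A minimal sketch of the guarded timing observation described above, using the prometheus_client API directly; the metric name and helper are illustrative, not the exact Synapse definitions:

```python
from typing import Optional
from prometheus_client import Histogram

backfill_processing_before_timer = Histogram(
    "synapse_federation_backfill_processing_before_time_seconds",
    "Time spent processing before making the /backfill request",
)

def observe_processing_time(processing_start_time: Optional[int], now_msec: int) -> None:
    # Nested recursive calls pass processing_start_time=None, so only the
    # top-most call records a timing observation.
    if processing_start_time is not None:
        backfill_processing_before_timer.observe(
            (now_msec - processing_start_time) / 1000
        )
```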
|
||||||
|
@ -956,9 +1002,15 @@ class FederationHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
context = EventContext.for_outlier(self._storage_controllers)
|
context = EventContext.for_outlier(self._storage_controllers)
|
||||||
|
|
||||||
|
await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context)
|
||||||
|
try:
|
||||||
await self._federation_event_handler.persist_events_and_notify(
|
await self._federation_event_handler.persist_events_and_notify(
|
||||||
event.room_id, [(event, context)]
|
event.room_id, [(event, context)]
|
||||||
)
|
)
|
||||||
|
except Exception:
|
||||||
|
await self.store.remove_push_actions_from_staging(event.event_id)
|
||||||
|
raise
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
@ -1624,6 +1676,9 @@ class FederationHandler:
|
||||||
# https://github.com/matrix-org/synapse/issues/12994
|
# https://github.com/matrix-org/synapse/issues/12994
|
||||||
await self.state_handler.update_current_state(room_id)
|
await self.state_handler.update_current_state(room_id)
|
||||||
|
|
||||||
|
logger.info("Handling any pending device list updates")
|
||||||
|
await self._device_handler.handle_room_un_partial_stated(room_id)
|
||||||
|
|
||||||
logger.info("Clearing partial-state flag for %s", room_id)
|
logger.info("Clearing partial-state flag for %s", room_id)
|
||||||
success = await self.store.clear_partial_state_room(room_id)
|
success = await self.store.clear_partial_state_room(room_id)
|
||||||
if success:
|
if success:
|
||||||
|
|
|
@ -2170,6 +2170,7 @@ class FederationEventHandler:
|
||||||
if instance != self._instance_name:
|
if instance != self._instance_name:
|
||||||
# Limit the number of events sent over replication. We choose 200
|
# Limit the number of events sent over replication. We choose 200
|
||||||
# here as that is what we default to in `max_request_body_size(..)`
|
# here as that is what we default to in `max_request_body_size(..)`
|
||||||
|
result = {}
|
||||||
try:
|
try:
|
||||||
for batch in batch_iter(event_and_contexts, 200):
|
for batch in batch_iter(event_and_contexts, 200):
|
||||||
result = await self._send_events(
|
result = await self._send_events(
|
||||||
|
|
|
@ -220,6 +220,7 @@ class RegistrationHandler:
|
||||||
by_admin: bool = False,
|
by_admin: bool = False,
|
||||||
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
|
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
|
||||||
auth_provider_id: Optional[str] = None,
|
auth_provider_id: Optional[str] = None,
|
||||||
|
approved: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Registers a new client on the server.
|
"""Registers a new client on the server.
|
||||||
|
|
||||||
|
@ -246,6 +247,8 @@ class RegistrationHandler:
|
||||||
user_agent_ips: Tuples of user-agents and IP addresses used
|
user_agent_ips: Tuples of user-agents and IP addresses used
|
||||||
during the registration process.
|
during the registration process.
|
||||||
auth_provider_id: The SSO IdP the user used, if any.
|
auth_provider_id: The SSO IdP the user used, if any.
|
||||||
|
approved: True if the new user should be considered already
|
||||||
|
approved by an administrator.
|
||||||
Returns:
|
Returns:
|
||||||
The registered user_id.
|
The registered user_id.
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -307,6 +310,7 @@ class RegistrationHandler:
|
||||||
user_type=user_type,
|
user_type=user_type,
|
||||||
address=address,
|
address=address,
|
||||||
shadow_banned=shadow_banned,
|
shadow_banned=shadow_banned,
|
||||||
|
approved=approved,
|
||||||
)
|
)
|
||||||
|
|
||||||
profile = await self.store.get_profileinfo(localpart)
|
profile = await self.store.get_profileinfo(localpart)
|
||||||
|
@ -695,6 +699,7 @@ class RegistrationHandler:
|
||||||
user_type: Optional[str] = None,
|
user_type: Optional[str] = None,
|
||||||
address: Optional[str] = None,
|
address: Optional[str] = None,
|
||||||
shadow_banned: bool = False,
|
shadow_banned: bool = False,
|
||||||
|
approved: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register user in the datastore.
|
"""Register user in the datastore.
|
||||||
|
|
||||||
|
@ -713,6 +718,7 @@ class RegistrationHandler:
|
||||||
api.constants.UserTypes, or None for a normal user.
|
api.constants.UserTypes, or None for a normal user.
|
||||||
address: the IP address used to perform the registration.
|
address: the IP address used to perform the registration.
|
||||||
shadow_banned: Whether to shadow-ban the user
|
shadow_banned: Whether to shadow-ban the user
|
||||||
|
approved: Whether to mark the user as approved by an administrator
|
||||||
"""
|
"""
|
||||||
if self.hs.config.worker.worker_app:
|
if self.hs.config.worker.worker_app:
|
||||||
await self._register_client(
|
await self._register_client(
|
||||||
|
@ -726,6 +732,7 @@ class RegistrationHandler:
|
||||||
user_type=user_type,
|
user_type=user_type,
|
||||||
address=address,
|
address=address,
|
||||||
shadow_banned=shadow_banned,
|
shadow_banned=shadow_banned,
|
||||||
|
approved=approved,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self.store.register_user(
|
await self.store.register_user(
|
||||||
|
@ -738,6 +745,7 @@ class RegistrationHandler:
|
||||||
admin=admin,
|
admin=admin,
|
||||||
user_type=user_type,
|
user_type=user_type,
|
||||||
shadow_banned=shadow_banned,
|
shadow_banned=shadow_banned,
|
||||||
|
approved=approved,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only call the account validity module(s) on the main process, to avoid
|
# Only call the account validity module(s) on the main process, to avoid
|
||||||
|
|
|
@ -1540,7 +1540,9 @@ class TimestampLookupHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
likely_domains = (
|
likely_domains = (
|
||||||
await self._storage_controllers.state.get_current_hosts_in_room(room_id)
|
await self._storage_controllers.state.get_current_hosts_in_room_ordered(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Loop through each homeserver candidate until we get a succesful response
|
# Loop through each homeserver candidate until we get a succesful response
|
||||||
|
|
|
@ -187,6 +187,19 @@ class SendEmailHandler:
|
||||||
multipart_msg["To"] = email_address
|
multipart_msg["To"] = email_address
|
||||||
multipart_msg["Date"] = email.utils.formatdate()
|
multipart_msg["Date"] = email.utils.formatdate()
|
||||||
multipart_msg["Message-ID"] = email.utils.make_msgid()
|
multipart_msg["Message-ID"] = email.utils.make_msgid()
|
||||||
|
# Discourage automatic responses to Synapse's emails.
|
||||||
|
# Per RFC 3834, automatic responses should not be sent if the "Auto-Submitted"
|
||||||
|
# header is present with any value other than "no". See
|
||||||
|
# https://www.rfc-editor.org/rfc/rfc3834.html#section-5.1
|
||||||
|
multipart_msg["Auto-Submitted"] = "auto-generated"
|
||||||
|
# Also include a Microsoft-Exchange specific header:
|
||||||
|
# https://learn.microsoft.com/en-us/openspecs/exchange_server_protocols/ms-oxcmail/ced68690-498a-4567-9d14-5c01f974d8b1
|
||||||
|
# which suggests it can take the value "All" to "suppress all auto-replies",
|
||||||
|
# or a comma separated list of auto-reply classes to suppress.
|
||||||
|
# The following stack overflow question has a little more context:
|
||||||
|
# https://stackoverflow.com/a/25324691/5252017
|
||||||
|
# https://stackoverflow.com/a/61646381/5252017
|
||||||
|
multipart_msg["X-Auto-Response-Suppress"] = "All"
|
||||||
multipart_msg.attach(text_part)
|
multipart_msg.attach(text_part)
|
||||||
multipart_msg.attach(html_part)
|
multipart_msg.attach(html_part)
|
||||||
|
|
||||||
|
|
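A self-contained illustration of the two anti-auto-reply headers added above, using only the standard library; the addresses and body are placeholders and this is not Synapse's actual send path:

```python
import email.utils
from email.message import EmailMessage

msg = EmailMessage()
msg["From"] = "noreply@example.org"   # placeholder address
msg["To"] = "user@example.org"        # placeholder address
msg["Date"] = email.utils.formatdate()
msg["Message-ID"] = email.utils.make_msgid()
# RFC 3834: well-behaved mail servers will not send automatic replies
# (e.g. out-of-office notices) to mail marked as auto-generated.
msg["Auto-Submitted"] = "auto-generated"
# Microsoft Exchange equivalent: suppress every class of automatic reply.
msg["X-Auto-Response-Suppress"] = "All"
msg.set_content("You have unread notifications.")
```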
|
@ -1490,16 +1490,14 @@ class SyncHandler:
|
||||||
since_token.device_list_key
|
since_token.device_list_key
|
||||||
)
|
)
|
||||||
if changed_users is not None:
|
if changed_users is not None:
|
||||||
result = await self.store.get_rooms_for_users_with_stream_ordering(
|
result = await self.store.get_rooms_for_users(changed_users)
|
||||||
changed_users
|
|
||||||
)
|
|
||||||
|
|
||||||
for changed_user_id, entries in result.items():
|
for changed_user_id, entries in result.items():
|
||||||
# Check if the changed user shares any rooms with the user,
|
# Check if the changed user shares any rooms with the user,
|
||||||
# or if the changed user is the syncing user (as we always
|
# or if the changed user is the syncing user (as we always
|
||||||
# want to include device list updates of their own devices).
|
# want to include device list updates of their own devices).
|
||||||
if user_id == changed_user_id or any(
|
if user_id == changed_user_id or any(
|
||||||
e.room_id in joined_rooms for e in entries
|
rid in joined_rooms for rid in entries
|
||||||
):
|
):
|
||||||
users_that_have_changed.add(changed_user_id)
|
users_that_have_changed.add(changed_user_id)
|
||||||
else:
|
else:
|
||||||
|
@ -1533,13 +1531,9 @@ class SyncHandler:
|
||||||
newly_left_users.update(left_users)
|
newly_left_users.update(left_users)
|
||||||
|
|
||||||
# Remove any users that we still share a room with.
|
# Remove any users that we still share a room with.
|
||||||
left_users_rooms = (
|
left_users_rooms = await self.store.get_rooms_for_users(newly_left_users)
|
||||||
await self.store.get_rooms_for_users_with_stream_ordering(
|
|
||||||
newly_left_users
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for user_id, entries in left_users_rooms.items():
|
for user_id, entries in left_users_rooms.items():
|
||||||
if any(e.room_id in joined_rooms for e in entries):
|
if any(rid in joined_rooms for rid in entries):
|
||||||
newly_left_users.discard(user_id)
|
newly_left_users.discard(user_id)
|
||||||
|
|
||||||
return DeviceListUpdates(
|
return DeviceListUpdates(
|
||||||
|
|
|
@ -842,6 +842,8 @@ class ModuleApi:
|
||||||
however invalidation that needs to go to other workers needs to call `invalidate_cache`
|
however invalidation that needs to go to other workers needs to call `invalidate_cache`
|
||||||
on the module API instead.
|
on the module API instead.
|
||||||
|
|
||||||
|
Added in Synapse v1.69.0.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cached_function: The cached function that will be registered to receive invalidation
|
cached_function: The cached function that will be registered to receive invalidation
|
||||||
locally and from other workers.
|
locally and from other workers.
|
||||||
|
@ -856,6 +858,8 @@ class ModuleApi:
|
||||||
"""Invalidate a cache entry of a cached function across workers. The cached function
|
"""Invalidate a cache entry of a cached function across workers. The cached function
|
||||||
needs to be registered on all workers first with `register_cached_function`.
|
needs to be registered on all workers first with `register_cached_function`.
|
||||||
|
|
||||||
|
Added in Synapse v1.69.0.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cached_function: The cached function that needs an invalidation
|
cached_function: The cached function that needs an invalidation
|
||||||
keys: keys of the entry to invalidate, usually matching the arguments of the
|
keys: keys of the entry to invalidate, usually matching the arguments of the
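A hedged sketch of how a module might combine the two methods documented above. The module scaffolding and the `cached` decorator import are assumptions about how such a module could be wired up, not something this diff shows:

```python
from synapse.module_api import ModuleApi
from synapse.util.caches.descriptors import cached  # internal decorator; assumed here


class ExampleModule:
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        # Register the cached function so invalidations issued by other
        # workers are applied to this worker's cache as well.
        self._api.register_cached_function(self.get_thing)

    @cached()
    async def get_thing(self, key: str) -> str:
        ...  # expensive lookup

    async def set_thing(self, key: str, value: str) -> None:
        ...  # persist the new value somewhere
        # Then drop the stale cache entry for `key` on every worker.
        await self._api.invalidate_cache(self.get_thing, (key,))
```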
|
||||||
|
|
|
@ -17,6 +17,7 @@ import itertools
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
@ -37,13 +38,11 @@ from synapse.events.snapshot import EventContext
|
||||||
from synapse.state import POWER_KEY
|
from synapse.state import POWER_KEY
|
||||||
from synapse.storage.databases.main.roommember import EventIdMembership
|
from synapse.storage.databases.main.roommember import EventIdMembership
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.synapse_rust.push import FilteredPushRules, PushRule
|
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator
|
||||||
from synapse.util.caches import register_cache
|
from synapse.util.caches import register_cache
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.visibility import filter_event_for_clients_with_state
|
from synapse.visibility import filter_event_for_clients_with_state
|
||||||
|
|
||||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
@ -173,7 +172,11 @@ class BulkPushRuleEvaluator:
|
||||||
|
|
||||||
async def _get_power_levels_and_sender_level(
|
async def _get_power_levels_and_sender_level(
|
||||||
self, event: EventBase, context: EventContext
|
self, event: EventBase, context: EventContext
|
||||||
) -> Tuple[dict, int]:
|
) -> Tuple[dict, Optional[int]]:
|
||||||
|
# There are no power levels and sender levels possible to get from outlier
|
||||||
|
if event.internal_metadata.is_outlier():
|
||||||
|
return {}, None
|
||||||
|
|
||||||
event_types = auth_types_for_event(event.room_version, event)
|
event_types = auth_types_for_event(event.room_version, event)
|
||||||
prev_state_ids = await context.get_prev_state_ids(
|
prev_state_ids = await context.get_prev_state_ids(
|
||||||
StateFilter.from_types(event_types)
|
StateFilter.from_types(event_types)
|
||||||
|
@ -250,8 +253,8 @@ class BulkPushRuleEvaluator:
|
||||||
should increment the unread count, and insert the results into the
|
should increment the unread count, and insert the results into the
|
||||||
event_push_actions_staging table.
|
event_push_actions_staging table.
|
||||||
"""
|
"""
|
||||||
if event.internal_metadata.is_outlier():
|
if not event.internal_metadata.is_notifiable():
|
||||||
# This can happen due to out of band memberships
|
# Push rules for events that aren't notifiable can't be processed by this
|
||||||
return
|
return
|
||||||
|
|
||||||
# Disable counting as unread unless the experimental configuration is
|
# Disable counting as unread unless the experimental configuration is
|
||||||
|
@ -286,11 +289,11 @@ class BulkPushRuleEvaluator:
|
||||||
if relation.rel_type == RelationTypes.THREAD:
|
if relation.rel_type == RelationTypes.THREAD:
|
||||||
thread_id = relation.parent_id
|
thread_id = relation.parent_id
|
||||||
|
|
||||||
evaluator = PushRuleEvaluatorForEvent(
|
evaluator = PushRuleEvaluator(
|
||||||
event,
|
_flatten_dict(event),
|
||||||
room_member_count,
|
room_member_count,
|
||||||
sender_power_level,
|
sender_power_level,
|
||||||
power_levels,
|
power_levels.get("notifications", {}),
|
||||||
relations,
|
relations,
|
||||||
self._relations_match_enabled,
|
self._relations_match_enabled,
|
||||||
)
|
)
|
||||||
|
@ -300,20 +303,10 @@ class BulkPushRuleEvaluator:
|
||||||
event.room_id, users
|
event.room_id, users
|
||||||
)
|
)
|
||||||
|
|
||||||
# This is a check for the case where user joins a room without being
|
|
||||||
# allowed to see history, and then the server receives a delayed event
|
|
||||||
# from before the user joined, which they should not be pushed for
|
|
||||||
uids_with_visibility = await filter_event_for_clients_with_state(
|
|
||||||
self.store, users, event, context
|
|
||||||
)
|
|
||||||
|
|
||||||
for uid, rules in rules_by_user.items():
|
for uid, rules in rules_by_user.items():
|
||||||
if event.sender == uid:
|
if event.sender == uid:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if uid not in uids_with_visibility:
|
|
||||||
continue
|
|
||||||
|
|
||||||
display_name = None
|
display_name = None
|
||||||
profile = profiles.get(uid)
|
profile = profiles.get(uid)
|
||||||
if profile:
|
if profile:
|
||||||
|
@ -334,17 +327,25 @@ class BulkPushRuleEvaluator:
|
||||||
# current user, it'll be added to the dict later.
|
# current user, it'll be added to the dict later.
|
||||||
actions_by_user[uid] = []
|
actions_by_user[uid] = []
|
||||||
|
|
||||||
for rule, enabled in rules.rules():
|
actions = evaluator.run(rules, uid, display_name)
|
||||||
if not enabled:
|
if "notify" in actions:
|
||||||
continue
|
|
||||||
|
|
||||||
matches = evaluator.check_conditions(rule.conditions, uid, display_name)
|
|
||||||
if matches:
|
|
||||||
actions = [x for x in rule.actions if x != "dont_notify"]
|
|
||||||
if actions and "notify" in actions:
|
|
||||||
# Push rules say we should notify the user of this event
|
# Push rules say we should notify the user of this event
|
||||||
actions_by_user[uid] = actions
|
actions_by_user[uid] = actions
|
||||||
break
|
|
||||||
|
# This is a check for the case where user joins a room without being
|
||||||
|
# allowed to see history, and then the server receives a delayed event
|
||||||
|
# from before the user joined, which they should not be pushed for
|
||||||
|
#
|
||||||
|
# We do this *after* calculating the push actions as a) its unlikely
|
||||||
|
# that we'll filter anyone out and b) for large rooms its likely that
|
||||||
|
# most users will have push disabled and so the set of users to check is
|
||||||
|
# much smaller.
|
||||||
|
uids_with_visibility = await filter_event_for_clients_with_state(
|
||||||
|
self.store, actions_by_user.keys(), event, context
|
||||||
|
)
|
||||||
|
|
||||||
|
for user_id in set(actions_by_user).difference(uids_with_visibility):
|
||||||
|
actions_by_user.pop(user_id, None)
|
||||||
|
|
||||||
# Mark in the DB staging area the push actions for users who should be
|
# Mark in the DB staging area the push actions for users who should be
|
||||||
# notified for this event. (This will then get handled when we persist
|
# notified for this event. (This will then get handled when we persist
|
||||||
|
@ -361,3 +362,21 @@ MemberMap = Dict[str, Optional[EventIdMembership]]
|
||||||
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]


def _flatten_dict(
    d: Union[EventBase, Mapping[str, Any]],
    prefix: Optional[List[str]] = None,
    result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    if prefix is None:
        prefix = []
    if result is None:
        result = {}
    for key, value in d.items():
        if isinstance(value, str):
            result[".".join(prefix + [key])] = value.lower()
        elif isinstance(value, Mapping):
            _flatten_dict(value, prefix=(prefix + [key]), result=result)

    return result
|
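A quick illustration of what `_flatten_dict` produces for a nested mapping: keys are joined with `.`, string values are lower-cased, and non-string leaves are dropped.

```python
content = {
    "msgtype": "m.text",
    "body": "Hello World",
    "m.relates_to": {"rel_type": "m.thread", "event_id": "$parent"},
}
_flatten_dict(content)
# -> {'msgtype': 'm.text',
#     'body': 'hello world',
#     'm.relates_to.rel_type': 'm.thread',
#     'm.relates_to.event_id': '$parent'}
```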
@ -102,10 +102,8 @@ def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]:
|
||||||
# with PRIORITY_CLASS_INVERSE_MAP.
|
# with PRIORITY_CLASS_INVERSE_MAP.
|
||||||
raise ValueError("Unexpected template_name: %s" % (template_name,))
|
raise ValueError("Unexpected template_name: %s" % (template_name,))
|
||||||
|
|
||||||
if unscoped_rule_id:
|
|
||||||
templaterule["rule_id"] = unscoped_rule_id
|
templaterule["rule_id"] = unscoped_rule_id
|
||||||
if rule.default:
|
templaterule["default"] = rule.default
|
||||||
templaterule["default"] = True
|
|
||||||
return templaterule
|
return templaterule
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.push import Pusher, PusherConfig, PusherConfigException
|
from synapse.push import Pusher, PusherConfig, PusherConfigException
|
||||||
from synapse.storage.databases.main.event_push_actions import HttpPushAction
|
from synapse.storage.databases.main.event_push_actions import HttpPushAction
|
||||||
|
|
||||||
from . import push_rule_evaluator, push_tools
|
from . import push_tools
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -56,6 +56,39 @@ http_badges_failed_counter = Counter(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
    """
    Converts a list of actions into a `tweaks` dict (which can then be passed to
    the push gateway).

    This function ignores all actions other than `set_tweak` actions, and treats
    absent `value`s as `True`, which agrees with the only spec-defined treatment
    of absent `value`s (namely, for `highlight` tweaks).

    Args:
        actions: list of actions
            e.g. [
                {"set_tweak": "a", "value": "AAA"},
                {"set_tweak": "b", "value": "BBB"},
                {"set_tweak": "highlight"},
                "notify"
            ]

    Returns:
        dictionary of tweaks for those actions
            e.g. {"a": "AAA", "b": "BBB", "highlight": True}
    """
    tweaks = {}
    for a in actions:
        if not isinstance(a, dict):
            continue
        if "set_tweak" in a:
            # value is allowed to be absent in which case the value assumed
            # should be True.
            tweaks[a["set_tweak"]] = a.get("value", True)
    return tweaks
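A concrete call of the relocated helper, matching the docstring above:

```python
actions = [
    "notify",
    {"set_tweak": "sound", "value": "default"},
    {"set_tweak": "highlight"},  # absent value is treated as True
]
tweaks_for_actions(actions)
# -> {"sound": "default", "highlight": True}
```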
|
||||||
|
|
||||||
|
|
||||||
class HttpPusher(Pusher):
|
class HttpPusher(Pusher):
|
||||||
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
|
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
|
||||||
MAX_BACKOFF_SEC = 60 * 60
|
MAX_BACKOFF_SEC = 60 * 60
|
||||||
|
@ -286,7 +319,7 @@ class HttpPusher(Pusher):
|
||||||
if "notify" not in push_action.actions:
|
if "notify" not in push_action.actions:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions)
|
tweaks = tweaks_for_actions(push_action.actions)
|
||||||
badge = await push_tools.get_badge_count(
|
badge = await push_tools.get_badge_count(
|
||||||
self.hs.get_datastores().main,
|
self.hs.get_datastores().main,
|
||||||
self.user_id,
|
self.user_id,
|
||||||
|
|
|
@ -1,361 +0,0 @@
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
|
||||||
# Copyright 2017 New Vector Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Pattern,
|
|
||||||
Sequence,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from matrix_common.regex import glob_to_regex, to_word_pattern
|
|
||||||
|
|
||||||
from synapse.events import EventBase
|
|
||||||
from synapse.types import UserID
|
|
||||||
from synapse.util.caches.lrucache import LruCache
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]")
|
|
||||||
IS_GLOB = re.compile(r"[\?\*\[\]]")
|
|
||||||
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
|
||||||
|
|
||||||
|
|
||||||
def _room_member_count(
|
|
||||||
ev: EventBase, condition: Mapping[str, Any], room_member_count: int
|
|
||||||
) -> bool:
|
|
||||||
return _test_ineq_condition(condition, room_member_count)
|
|
||||||
|
|
||||||
|
|
||||||
def _sender_notification_permission(
|
|
||||||
ev: EventBase,
|
|
||||||
condition: Mapping[str, Any],
|
|
||||||
sender_power_level: int,
|
|
||||||
power_levels: Dict[str, Union[int, Dict[str, int]]],
|
|
||||||
) -> bool:
|
|
||||||
notif_level_key = condition.get("key")
|
|
||||||
if notif_level_key is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
notif_levels = power_levels.get("notifications", {})
|
|
||||||
assert isinstance(notif_levels, dict)
|
|
||||||
room_notif_level = notif_levels.get(notif_level_key, 50)
|
|
||||||
|
|
||||||
return sender_power_level >= room_notif_level
|
|
||||||
|
|
||||||
|
|
||||||
def _test_ineq_condition(condition: Mapping[str, Any], number: int) -> bool:
|
|
||||||
if "is" not in condition:
|
|
||||||
return False
|
|
||||||
m = INEQUALITY_EXPR.match(condition["is"])
|
|
||||||
if not m:
|
|
||||||
return False
|
|
||||||
ineq = m.group(1)
|
|
||||||
rhs = m.group(2)
|
|
||||||
if not rhs.isdigit():
|
|
||||||
return False
|
|
||||||
rhs_int = int(rhs)
|
|
||||||
|
|
||||||
if ineq == "" or ineq == "==":
|
|
||||||
return number == rhs_int
|
|
||||||
elif ineq == "<":
|
|
||||||
return number < rhs_int
|
|
||||||
elif ineq == ">":
|
|
||||||
return number > rhs_int
|
|
||||||
elif ineq == ">=":
|
|
||||||
return number >= rhs_int
|
|
||||||
elif ineq == "<=":
|
|
||||||
return number <= rhs_int
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Converts a list of actions into a `tweaks` dict (which can then be passed to
|
|
||||||
the push gateway).
|
|
||||||
|
|
||||||
This function ignores all actions other than `set_tweak` actions, and treats
|
|
||||||
absent `value`s as `True`, which agrees with the only spec-defined treatment
|
|
||||||
of absent `value`s (namely, for `highlight` tweaks).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
actions: list of actions
|
|
||||||
e.g. [
|
|
||||||
{"set_tweak": "a", "value": "AAA"},
|
|
||||||
{"set_tweak": "b", "value": "BBB"},
|
|
||||||
{"set_tweak": "highlight"},
|
|
||||||
"notify"
|
|
||||||
]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dictionary of tweaks for those actions
|
|
||||||
e.g. {"a": "AAA", "b": "BBB", "highlight": True}
|
|
||||||
"""
|
|
||||||
tweaks = {}
|
|
||||||
for a in actions:
|
|
||||||
if not isinstance(a, dict):
|
|
||||||
continue
|
|
||||||
if "set_tweak" in a:
|
|
||||||
# value is allowed to be absent in which case the value assumed
|
|
||||||
# should be True.
|
|
||||||
tweaks[a["set_tweak"]] = a.get("value", True)
|
|
||||||
return tweaks
|
|
||||||
|
|
||||||
|
|
||||||
class PushRuleEvaluatorForEvent:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
event: EventBase,
|
|
||||||
room_member_count: int,
|
|
||||||
sender_power_level: int,
|
|
||||||
power_levels: Dict[str, Union[int, Dict[str, int]]],
|
|
||||||
relations: Dict[str, Set[Tuple[str, str]]],
|
|
||||||
relations_match_enabled: bool,
|
|
||||||
):
|
|
||||||
self._event = event
|
|
||||||
self._room_member_count = room_member_count
|
|
||||||
self._sender_power_level = sender_power_level
|
|
||||||
self._power_levels = power_levels
|
|
||||||
self._relations = relations
|
|
||||||
self._relations_match_enabled = relations_match_enabled
|
|
||||||
|
|
||||||
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
|
||||||
self._value_cache = _flatten_dict(event)
|
|
||||||
|
|
||||||
# Maps cache keys to final values.
|
|
||||||
self._condition_cache: Dict[str, bool] = {}
|
|
||||||
|
|
||||||
def check_conditions(
|
|
||||||
self, conditions: Sequence[Mapping], uid: str, display_name: Optional[str]
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Returns true if a user's conditions/user ID/display name match the event.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conditions: The user's conditions to match.
|
|
||||||
uid: The user's MXID.
|
|
||||||
display_name: The display name.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if all conditions match the event, False otherwise.
|
|
||||||
"""
|
|
||||||
for cond in conditions:
|
|
||||||
_cache_key = cond.get("_cache_key", None)
|
|
||||||
if _cache_key:
|
|
||||||
res = self._condition_cache.get(_cache_key, None)
|
|
||||||
if res is False:
|
|
||||||
return False
|
|
||||||
elif res is True:
|
|
||||||
continue
|
|
||||||
|
|
||||||
res = self.matches(cond, uid, display_name)
|
|
||||||
if _cache_key:
|
|
||||||
self._condition_cache[_cache_key] = bool(res)
|
|
||||||
|
|
||||||
if not res:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def matches(
|
|
||||||
self, condition: Mapping[str, Any], user_id: str, display_name: Optional[str]
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Returns true if a user's condition/user ID/display name match the event.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
condition: The user's condition to match.
|
|
||||||
uid: The user's MXID.
|
|
||||||
display_name: The display name, or None if there is not one.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the condition matches the event, False otherwise.
|
|
||||||
"""
|
|
||||||
if condition["kind"] == "event_match":
|
|
||||||
return self._event_match(condition, user_id)
|
|
||||||
elif condition["kind"] == "contains_display_name":
|
|
||||||
return self._contains_display_name(display_name)
|
|
||||||
elif condition["kind"] == "room_member_count":
|
|
||||||
return _room_member_count(self._event, condition, self._room_member_count)
|
|
||||||
elif condition["kind"] == "sender_notification_permission":
|
|
||||||
return _sender_notification_permission(
|
|
||||||
self._event, condition, self._sender_power_level, self._power_levels
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
condition["kind"] == "org.matrix.msc3772.relation_match"
|
|
||||||
and self._relations_match_enabled
|
|
||||||
):
|
|
||||||
return self._relation_match(condition, user_id)
|
|
||||||
else:
|
|
||||||
# XXX This looks incorrect -- we have reached an unknown condition
|
|
||||||
# kind and are unconditionally returning that it matches. Note
|
|
||||||
# that it seems possible to provide a condition to the /pushrules
|
|
||||||
# endpoint with an unknown kind, see _rule_tuple_from_request_object.
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _event_match(self, condition: Mapping, user_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check an "event_match" push rule condition.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
condition: The "event_match" push rule condition to match.
|
|
||||||
user_id: The user's MXID.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the condition matches the event, False otherwise.
|
|
||||||
"""
|
|
||||||
pattern = condition.get("pattern", None)
|
|
||||||
|
|
||||||
if not pattern:
|
|
||||||
pattern_type = condition.get("pattern_type", None)
|
|
||||||
if pattern_type == "user_id":
|
|
||||||
pattern = user_id
|
|
||||||
elif pattern_type == "user_localpart":
|
|
||||||
pattern = UserID.from_string(user_id).localpart
|
|
||||||
|
|
||||||
if not pattern:
|
|
||||||
logger.warning("event_match condition with no pattern")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# XXX: optimisation: cache our pattern regexps
|
|
||||||
if condition["key"] == "content.body":
|
|
||||||
body = self._event.content.get("body", None)
|
|
||||||
if not body or not isinstance(body, str):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return _glob_matches(pattern, body, word_boundary=True)
|
|
||||||
else:
|
|
||||||
haystack = self._value_cache.get(condition["key"], None)
|
|
||||||
if haystack is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return _glob_matches(pattern, haystack)
|
|
||||||
|
|
||||||
    def _contains_display_name(self, display_name: Optional[str]) -> bool:
        """
        Check a "contains_display_name" push rule condition.

        Args:
            display_name: The display name, or None if there is not one.

        Returns:
            True if the display name is found in the event body, False otherwise.
        """
        if not display_name:
            return False

        body = self._event.content.get("body", None)
        if not body or not isinstance(body, str):
            return False

        # Similar to _glob_matches, but do not treat display_name as a glob.
        r = regex_cache.get((display_name, False, True), None)
        if not r:
            r1 = re.escape(display_name)
            r1 = to_word_pattern(r1)
            r = re.compile(r1, flags=re.IGNORECASE)
            regex_cache[(display_name, False, True)] = r

        return bool(r.search(body))

    def _relation_match(self, condition: Mapping, user_id: str) -> bool:
        """
        Check a "relation_match" push rule condition.

        Args:
            condition: The "relation_match" push rule condition to match.
            user_id: The user's MXID.

        Returns:
            True if the condition matches the event, False otherwise.
        """
        rel_type = condition.get("rel_type")
        if not rel_type:
            logger.warning("relation_match condition missing rel_type")
            return False

        sender_pattern = condition.get("sender")
        if sender_pattern is None:
            sender_type = condition.get("sender_type")
            if sender_type == "user_id":
                sender_pattern = user_id
        type_pattern = condition.get("type")

        # If any of the relations match, return True.
        for sender, event_type in self._relations.get(rel_type, ()):
            if sender_pattern and not _glob_matches(sender_pattern, sender):
                continue
            if type_pattern and not _glob_matches(type_pattern, event_type):
                continue
            # All values must have matched.
            return True

        # No relations matched.
        return False

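    # Editor's illustration (added commentary, not original code): the loop in
    # _relation_match expects `self._relations` to map each rel_type to the
    # (sender, event_type) pairs of the events relating to the event being
    # evaluated, e.g.
    #
    #   {"m.thread": {("@alice:example.com", "m.room.message")}}
    #
    # and the condition may carry glob patterns such as "sender": "@alice:*"
    # or "type": "m.room.*" to filter those pairs.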
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
    50000, "regex_push_cache"
)


def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
    """Tests if value matches glob.

    Args:
        glob
        value: String to test against glob.
        word_boundary: Whether to match against word boundaries or entire
            string. Defaults to False.
    """

    try:
        r = regex_cache.get((glob, True, word_boundary), None)
        if not r:
            r = glob_to_regex(glob, word_boundary=word_boundary)
            regex_cache[(glob, True, word_boundary)] = r
        return bool(r.search(value))
    except re.error:
        logger.warning("Failed to parse glob to regex: %r", glob)
        return False

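# Editor's illustration (assumption: this only approximates glob_to_regex from
# matrix-common; the real helper also handles word boundaries). It shows the
# intended glob semantics: "*" matches any run of characters, "?" matches one.
def _naive_glob_matches_example(glob: str, value: str) -> bool:
    import fnmatch

    return re.match(fnmatch.translate(glob), value, re.IGNORECASE) is not None


# e.g. _naive_glob_matches_example("lunch*", "lunch plans?") -> True
# e.g. _naive_glob_matches_example("lunch*", "no plans") -> False
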
def _flatten_dict(
    d: Union[EventBase, Mapping[str, Any]],
    prefix: Optional[List[str]] = None,
    result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    if prefix is None:
        prefix = []
    if result is None:
        result = {}
    for key, value in d.items():
        if isinstance(value, str):
            result[".".join(prefix + [key])] = value.lower()
        elif isinstance(value, Mapping):
            _flatten_dict(value, prefix=(prefix + [key]), result=result)

    return result

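# Editor's illustration (not original code): flattening a typical event content
# produces the dotted, lower-cased keys that "event_match" conditions look up.
# The input below is a made-up example.
if __name__ == "__main__":
    example_content = {"body": "Hello WORLD", "m.relates_to": {"rel_type": "m.thread"}}
    # Prints {'body': 'hello world', 'm.relates_to.rel_type': 'm.thread'}
    print(_flatten_dict(example_content))
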
@ -51,6 +51,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
         user_type: Optional[str],
         address: Optional[str],
         shadow_banned: bool,
+        approved: bool,
     ) -> JsonDict:
         """
         Args:
@ -68,6 +69,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
                 or None for a normal user.
             address: the IP address used to perform the regitration.
             shadow_banned: Whether to shadow-ban the user
+            approved: Whether the user should be considered already approved by an
+                administrator.
         """
         return {
             "password_hash": password_hash,
@ -79,6 +82,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             "user_type": user_type,
             "address": address,
             "shadow_banned": shadow_banned,
+            "approved": approved,
         }

     async def _handle_request(  # type: ignore[override]
@ -99,6 +103,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             user_type=content["user_type"],
             address=content["address"],
             shadow_banned=content["shadow_banned"],
+            approved=content["approved"],
         )

         return 200, {}
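For context, a sketch of the replication payload this servlet now serialises, with the new `approved` flag included. The values below are invented for illustration; only the keys mirror `_serialize_payload` above.

payload = {
    "password_hash": "<bcrypt hash>",
    "user_type": None,
    "address": "10.0.0.1",
    "shadow_banned": False,
    "approved": False,  # MSC3866: account still awaiting admin approval
}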
@ -69,6 +69,7 @@ class UsersRestServletV2(RestServlet):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.admin_handler = hs.get_admin_handler()
|
self.admin_handler = hs.get_admin_handler()
|
||||||
|
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
|
||||||
|
|
||||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
@ -95,6 +96,13 @@ class UsersRestServletV2(RestServlet):
|
||||||
guests = parse_boolean(request, "guests", default=True)
|
guests = parse_boolean(request, "guests", default=True)
|
||||||
deactivated = parse_boolean(request, "deactivated", default=False)
|
deactivated = parse_boolean(request, "deactivated", default=False)
|
||||||
|
|
||||||
|
# If support for MSC3866 is not enabled, apply no filtering based on the
|
||||||
|
# `approved` column.
|
||||||
|
if self._msc3866_enabled:
|
||||||
|
approved = parse_boolean(request, "approved", default=True)
|
||||||
|
else:
|
||||||
|
approved = True
|
||||||
|
|
||||||
order_by = parse_string(
|
order_by = parse_string(
|
||||||
request,
|
request,
|
||||||
"order_by",
|
"order_by",
|
||||||
|
@ -115,8 +123,22 @@ class UsersRestServletV2(RestServlet):
|
||||||
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
|
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
|
||||||
|
|
||||||
users, total = await self.store.get_users_paginate(
|
users, total = await self.store.get_users_paginate(
|
||||||
start, limit, user_id, name, guests, deactivated, order_by, direction
|
start,
|
||||||
|
limit,
|
||||||
|
user_id,
|
||||||
|
name,
|
||||||
|
guests,
|
||||||
|
deactivated,
|
||||||
|
order_by,
|
||||||
|
direction,
|
||||||
|
approved,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If support for MSC3866 is not enabled, don't show the approval flag.
|
||||||
|
if not self._msc3866_enabled:
|
||||||
|
for user in users:
|
||||||
|
del user["approved"]
|
||||||
|
|
||||||
ret = {"users": users, "total": total}
|
ret = {"users": users, "total": total}
|
||||||
if (start + limit) < total:
|
if (start + limit) < total:
|
||||||
ret["next_token"] = str(start + len(users))
|
ret["next_token"] = str(start + len(users))
|
||||||
|
@ -163,6 +185,7 @@ class UserRestServletV2(RestServlet):
|
||||||
self.deactivate_account_handler = hs.get_deactivate_account_handler()
|
self.deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
self.pusher_pool = hs.get_pusherpool()
|
self.pusher_pool = hs.get_pusherpool()
|
||||||
|
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
|
||||||
|
|
||||||
async def on_GET(
|
async def on_GET(
|
||||||
self, request: SynapseRequest, user_id: str
|
self, request: SynapseRequest, user_id: str
|
||||||
|
@ -239,6 +262,15 @@ class UserRestServletV2(RestServlet):
|
||||||
HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
|
HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
approved: Optional[bool] = None
|
||||||
|
if "approved" in body and self._msc3866_enabled:
|
||||||
|
approved = body["approved"]
|
||||||
|
if not isinstance(approved, bool):
|
||||||
|
raise SynapseError(
|
||||||
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
"'approved' parameter is not of type boolean",
|
||||||
|
)
|
||||||
|
|
||||||
# convert List[Dict[str, str]] into List[Tuple[str, str]]
|
# convert List[Dict[str, str]] into List[Tuple[str, str]]
|
||||||
if external_ids is not None:
|
if external_ids is not None:
|
||||||
new_external_ids = [
|
new_external_ids = [
|
||||||
|
@ -343,6 +375,9 @@ class UserRestServletV2(RestServlet):
|
||||||
if "user_type" in body:
|
if "user_type" in body:
|
||||||
await self.store.set_user_type(target_user, user_type)
|
await self.store.set_user_type(target_user, user_type)
|
||||||
|
|
||||||
|
if approved is not None:
|
||||||
|
await self.store.update_user_approval_status(target_user, approved)
|
||||||
|
|
||||||
user = await self.admin_handler.get_user(target_user)
|
user = await self.admin_handler.get_user(target_user)
|
||||||
assert user is not None
|
assert user is not None
|
||||||
|
|
||||||
|
@ -355,6 +390,10 @@ class UserRestServletV2(RestServlet):
|
||||||
if password is not None:
|
if password is not None:
|
||||||
password_hash = await self.auth_handler.hash(password)
|
password_hash = await self.auth_handler.hash(password)
|
||||||
|
|
||||||
|
new_user_approved = True
|
||||||
|
if self._msc3866_enabled and approved is not None:
|
||||||
|
new_user_approved = approved
|
||||||
|
|
||||||
user_id = await self.registration_handler.register_user(
|
user_id = await self.registration_handler.register_user(
|
||||||
localpart=target_user.localpart,
|
localpart=target_user.localpart,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
|
@ -362,6 +401,7 @@ class UserRestServletV2(RestServlet):
|
||||||
default_display_name=displayname,
|
default_display_name=displayname,
|
||||||
user_type=user_type,
|
user_type=user_type,
|
||||||
by_admin=True,
|
by_admin=True,
|
||||||
|
approved=new_user_approved,
|
||||||
)
|
)
|
||||||
|
|
||||||
if threepids is not None:
|
if threepids is not None:
|
||||||
|
@ -550,6 +590,7 @@ class UserRegisterServlet(RestServlet):
|
||||||
user_type=user_type,
|
user_type=user_type,
|
||||||
default_display_name=displayname,
|
default_display_name=displayname,
|
||||||
by_admin=True,
|
by_admin=True,
|
||||||
|
approved=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await register._create_registration_details(user_id, body)
|
result = await register._create_registration_details(user_id, body)
|
||||||
|
|
|
@ -28,7 +28,14 @@ from typing import (
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError
|
from synapse.api.constants import ApprovalNoticeMedium
|
||||||
|
from synapse.api.errors import (
|
||||||
|
Codes,
|
||||||
|
InvalidClientTokenError,
|
||||||
|
LoginError,
|
||||||
|
NotApprovedError,
|
||||||
|
SynapseError,
|
||||||
|
)
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.api.urls import CLIENT_API_PREFIX
|
from synapse.api.urls import CLIENT_API_PREFIX
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
|
@ -55,11 +62,11 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class LoginResponse(TypedDict, total=False):
|
class LoginResponse(TypedDict, total=False):
|
||||||
user_id: str
|
user_id: str
|
||||||
access_token: str
|
access_token: Optional[str]
|
||||||
home_server: str
|
home_server: str
|
||||||
expires_in_ms: Optional[int]
|
expires_in_ms: Optional[int]
|
||||||
refresh_token: Optional[str]
|
refresh_token: Optional[str]
|
||||||
device_id: str
|
device_id: Optional[str]
|
||||||
well_known: Optional[Dict[str, Any]]
|
well_known: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,6 +99,12 @@ class LoginRestServlet(RestServlet):
|
||||||
hs.config.registration.refreshable_access_token_lifetime is not None
|
hs.config.registration.refreshable_access_token_lifetime is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Whether we need to check if the user has been approved or not.
|
||||||
|
self._require_approval = (
|
||||||
|
hs.config.experimental.msc3866.enabled
|
||||||
|
and hs.config.experimental.msc3866.require_approval_for_new_accounts
|
||||||
|
)
|
||||||
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
@ -220,6 +233,14 @@ class LoginRestServlet(RestServlet):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise SynapseError(400, "Missing JSON keys.")
|
raise SynapseError(400, "Missing JSON keys.")
|
||||||
|
|
||||||
|
if self._require_approval:
|
||||||
|
approved = await self.auth_handler.is_user_approved(result["user_id"])
|
||||||
|
if not approved:
|
||||||
|
raise NotApprovedError(
|
||||||
|
msg="This account is pending approval by a server administrator.",
|
||||||
|
approval_notice_medium=ApprovalNoticeMedium.NONE,
|
||||||
|
)
|
||||||
|
|
||||||
well_known_data = self._well_known_builder.get_well_known()
|
well_known_data = self._well_known_builder.get_well_known()
|
||||||
if well_known_data:
|
if well_known_data:
|
||||||
result["well_known"] = well_known_data
|
result["well_known"] = well_known_data
|
||||||
|
@ -356,6 +377,16 @@ class LoginRestServlet(RestServlet):
|
||||||
errcode=Codes.INVALID_PARAM,
|
errcode=Codes.INVALID_PARAM,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._require_approval:
|
||||||
|
approved = await self.auth_handler.is_user_approved(user_id)
|
||||||
|
if not approved:
|
||||||
|
# If the user isn't approved (and needs to be) we won't allow them to
|
||||||
|
# actually log in, so we don't want to create a device/access token.
|
||||||
|
return LoginResponse(
|
||||||
|
user_id=user_id,
|
||||||
|
home_server=self.hs.hostname,
|
||||||
|
)
|
||||||
|
|
||||||
initial_display_name = login_submission.get("initial_device_display_name")
|
initial_display_name = login_submission.get("initial_device_display_name")
|
||||||
(
|
(
|
||||||
device_id,
|
device_id,
|
||||||
|
|
|
@ -47,7 +47,9 @@ class LoginTokenRequestServlet(RestServlet):
         }
     """

-    PATTERNS = client_patterns("/login/token$")
+    PATTERNS = client_patterns(
+        "/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True
+    )

     def __init__(self, hs: "HomeServer"):
         super().__init__()
@ -21,10 +21,15 @@ from twisted.web.server import Request
|
||||||
import synapse
|
import synapse
|
||||||
import synapse.api.auth
|
import synapse.api.auth
|
||||||
import synapse.types
|
import synapse.types
|
||||||
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
|
from synapse.api.constants import (
|
||||||
|
APP_SERVICE_REGISTRATION_TYPE,
|
||||||
|
ApprovalNoticeMedium,
|
||||||
|
LoginType,
|
||||||
|
)
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
Codes,
|
Codes,
|
||||||
InteractiveAuthIncompleteError,
|
InteractiveAuthIncompleteError,
|
||||||
|
NotApprovedError,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
ThreepidValidationError,
|
ThreepidValidationError,
|
||||||
UnrecognizedRequestError,
|
UnrecognizedRequestError,
|
||||||
|
@ -414,6 +419,11 @@ class RegisterRestServlet(RestServlet):
|
||||||
hs.config.registration.inhibit_user_in_use_error
|
hs.config.registration.inhibit_user_in_use_error
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._require_approval = (
|
||||||
|
hs.config.experimental.msc3866.enabled
|
||||||
|
and hs.config.experimental.msc3866.require_approval_for_new_accounts
|
||||||
|
)
|
||||||
|
|
||||||
self._registration_flows = _calculate_registration_flows(
|
self._registration_flows = _calculate_registration_flows(
|
||||||
hs.config, self.auth_handler
|
hs.config, self.auth_handler
|
||||||
)
|
)
|
||||||
|
@ -734,6 +744,12 @@ class RegisterRestServlet(RestServlet):
|
||||||
access_token=return_dict.get("access_token"),
|
access_token=return_dict.get("access_token"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._require_approval:
|
||||||
|
raise NotApprovedError(
|
||||||
|
msg="This account needs to be approved by an administrator before it can be used.",
|
||||||
|
approval_notice_medium=ApprovalNoticeMedium.NONE,
|
||||||
|
)
|
||||||
|
|
||||||
return 200, return_dict
|
return 200, return_dict
|
||||||
|
|
||||||
async def _do_appservice_registration(
|
async def _do_appservice_registration(
|
||||||
|
@ -778,7 +794,9 @@ class RegisterRestServlet(RestServlet):
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
}
|
}
|
||||||
if not params.get("inhibit_login", False):
|
# We don't want to log the user in if we're going to deny them access because
|
||||||
|
# they need to be approved first.
|
||||||
|
if not params.get("inhibit_login", False) and not self._require_approval:
|
||||||
device_id = params.get("device_id")
|
device_id = params.get("device_id")
|
||||||
initial_display_name = params.get("initial_device_display_name")
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
(
|
(
|
||||||
|
|
|
@ -94,6 +94,7 @@ class SQLBaseStore(metaclass=ABCMeta):
             self._attempt_to_invalidate_cache(
                 "get_rooms_for_user_with_stream_ordering", (user_id,)
             )
+            self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,))

             # Purge other caches based on room state.
             self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
@ -423,16 +423,18 @@ class EventsPersistenceStorageController:
         for d in ret_vals:
             replaced_events.update(d)

-        events = []
+        persisted_events = []
         for event, _ in events_and_contexts:
             existing_event_id = replaced_events.get(event.event_id)
             if existing_event_id:
-                events.append(await self.main_store.get_event(existing_event_id))
+                persisted_events.append(
+                    await self.main_store.get_event(existing_event_id)
+                )
             else:
-                events.append(event)
+                persisted_events.append(event)

         return (
-            events,
+            persisted_events,
             self.main_store.get_room_max_token(),
         )
@ -23,7 +23,7 @@ from typing import (
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -529,7 +529,18 @@ class StateStorageController:
|
||||||
)
|
)
|
||||||
return state_map.get(key)
|
return state_map.get(key)
|
||||||
|
|
||||||
async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
|
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||||
|
"""Get current hosts in room based on current state.
|
||||||
|
|
||||||
|
Blocks until we have full state for the given room. This only happens for rooms
|
||||||
|
with partial state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
await self._partial_state_room_tracker.await_full_state(room_id)
|
||||||
|
|
||||||
|
return await self.stores.main.get_current_hosts_in_room(room_id)
|
||||||
|
|
||||||
|
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
|
||||||
"""Get current hosts in room based on current state.
|
"""Get current hosts in room based on current state.
|
||||||
|
|
||||||
Blocks until we have full state for the given room. This only happens for rooms
|
Blocks until we have full state for the given room. This only happens for rooms
|
||||||
|
@ -542,11 +553,11 @@ class StateStorageController:
|
||||||
|
|
||||||
await self._partial_state_room_tracker.await_full_state(room_id)
|
await self._partial_state_room_tracker.await_full_state(room_id)
|
||||||
|
|
||||||
return await self.stores.main.get_current_hosts_in_room(room_id)
|
return await self.stores.main.get_current_hosts_in_room_ordered(room_id)
|
||||||
|
|
||||||
async def get_current_hosts_in_room_or_partial_state_approximation(
|
async def get_current_hosts_in_room_or_partial_state_approximation(
|
||||||
self, room_id: str
|
self, room_id: str
|
||||||
) -> Sequence[str]:
|
) -> Collection[str]:
|
||||||
"""Get approximation of current hosts in room based on current state.
|
"""Get approximation of current hosts in room based on current state.
|
||||||
|
|
||||||
For rooms with full state, this is equivalent to `get_current_hosts_in_room`,
|
For rooms with full state, this is equivalent to `get_current_hosts_in_room`,
|
||||||
|
@ -566,14 +577,9 @@ class StateStorageController:
|
||||||
)
|
)
|
||||||
|
|
||||||
hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
|
hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
|
||||||
hosts_from_state_set = set(hosts_from_state)
|
|
||||||
|
|
||||||
# First take the list of hosts based on the current state.
|
hosts = set(hosts_at_join)
|
||||||
# For rooms with partial state, this will be missing most hosts.
|
hosts.update(hosts_from_state)
|
||||||
hosts = list(hosts_from_state)
|
|
||||||
# Then add in the list of hosts in the room at the time we joined.
|
|
||||||
# This will be an empty list for rooms with full state.
|
|
||||||
hosts.extend(host for host in hosts_at_join if host not in hosts_from_state_set)
|
|
||||||
|
|
||||||
return hosts
|
return hosts
|
||||||
|
|
||||||
|
|
|
@ -1141,17 +1141,57 @@ class DatabasePool:
         desc: str = "simple_upsert",
         lock: bool = True,
     ) -> bool:
-        """
-        `lock` should generally be set to True (the default), but can be set
-        to False if either of the following are true:
-        1. there is a UNIQUE INDEX on the key columns. In this case a conflict
-        will cause an IntegrityError in which case this function will retry
-        the update.
-        2. we somehow know that we are the only thread which will be updating
-        this table.
-        As an additional note, this parameter only matters for old SQLite versions
-        because we will use native upserts otherwise.
+        """Insert a row with values + insertion_values; on conflict, update with values.
+
+        All of our supported databases accept the nonstandard "upsert" statement in
+        their dialect of SQL. We call this a "native upsert". The syntax looks roughly
+        like:
+
+            INSERT INTO table VALUES (values + insertion_values)
+            ON CONFLICT (keyvalues)
+            DO UPDATE SET (values); -- overwrite `values` columns only
+
+        If (values) is empty, the resulting query is slighlty simpler:
+
+            INSERT INTO table VALUES (insertion_values)
+            ON CONFLICT (keyvalues)
+            DO NOTHING; -- do not overwrite any columns
+
+        This function is a helper to build such queries.
+
+        In order for upserts to make sense, the database must be able to determine when
+        an upsert CONFLICTs with an existing row. Postgres and SQLite ensure this by
+        requiring that a unique index exist on the column names used to detect a
+        conflict (i.e. `keyvalues.keys()`).
+
+        If there is no such index, we can "emulate" an upsert with a SELECT followed
+        by either an INSERT or an UPDATE. This is unsafe: we cannot make the same
+        atomicity guarantees that a native upsert can and are very vulnerable to races
+        and crashes. Therefore if we wish to upsert without an appropriate unique index,
+        we must either:
+
+        1. Acquire a table-level lock before the emulated upsert (`lock=True`), or
+        2. VERY CAREFULLY ensure that we are the only thread and worker which will be
+           writing to this table, in which case we can proceed without a lock
+           (`lock=False`).
+
+        Generally speaking, you should use `lock=True`. If the table in question has a
+        unique index[*], this class will use a native upsert (which is atomic and so can
+        ignore the `lock` argument). Otherwise this class will use an emulated upsert,
+        in which case we want the safer option unless we been VERY CAREFUL.
+
+        [*]: Some tables have unique indices added to them in the background. Those
+             tables `T` are keys in the dictionary UNIQUE_INDEX_BACKGROUND_UPDATES,
+             where `T` maps to the background update that adds a unique index to `T`.
+             This dictionary is maintained by hand.
+
+             At runtime, we constantly check to see if each of these background updates
+             has run. If so, we deem the coresponding table safe to upsert into, because
+             we can now use a native insert to do so. If not, we deem the table unsafe
+             to upsert into and require an emulated upsert.
+
+             Tables that do not appear in this dictionary are assumed to have an
+             appropriate unique index and therefore be safe to upsert into.
+
         Args:
             table: The table to upsert into
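A minimal usage sketch of the helper documented above. The table and column names here are only an example, and it assumes the table has a unique index over the key columns so the native upsert path applies.

await self.db_pool.simple_upsert(
    table="example_table",                    # assumed to have a unique index
    keyvalues={"user_id": user_id, "device_id": device_id},
    values={"stream_id": stream_id},          # overwritten on conflict
    insertion_values={"created_ts": now_ms},  # only written on first insert
    desc="example_upsert",
)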
@ -203,6 +203,7 @@ class DataStore(
|
||||||
deactivated: bool = False,
|
deactivated: bool = False,
|
||||||
order_by: str = UserSortOrder.USER_ID.value,
|
order_by: str = UserSortOrder.USER_ID.value,
|
||||||
direction: str = "f",
|
direction: str = "f",
|
||||||
|
approved: bool = True,
|
||||||
) -> Tuple[List[JsonDict], int]:
|
) -> Tuple[List[JsonDict], int]:
|
||||||
"""Function to retrieve a paginated list of users from
|
"""Function to retrieve a paginated list of users from
|
||||||
users list. This will return a json list of users and the
|
users list. This will return a json list of users and the
|
||||||
|
@ -217,6 +218,7 @@ class DataStore(
|
||||||
deactivated: whether to include deactivated users
|
deactivated: whether to include deactivated users
|
||||||
order_by: the sort order of the returned list
|
order_by: the sort order of the returned list
|
||||||
direction: sort ascending or descending
|
direction: sort ascending or descending
|
||||||
|
approved: whether to include approved users
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of a list of mappings from user to information and a count of total users.
|
A tuple of a list of mappings from user to information and a count of total users.
|
||||||
"""
|
"""
|
||||||
|
@ -249,6 +251,11 @@ class DataStore(
|
||||||
if not deactivated:
|
if not deactivated:
|
||||||
filters.append("deactivated = 0")
|
filters.append("deactivated = 0")
|
||||||
|
|
||||||
|
if not approved:
|
||||||
|
# We ignore NULL values for the approved flag because these should only
|
||||||
|
# be already existing users that we consider as already approved.
|
||||||
|
filters.append("approved IS FALSE")
|
||||||
|
|
||||||
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
|
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
|
||||||
|
|
||||||
sql_base = f"""
|
sql_base = f"""
|
||||||
|
@ -262,7 +269,7 @@ class DataStore(
|
||||||
|
|
||||||
sql = f"""
|
sql = f"""
|
||||||
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
|
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
|
||||||
displayname, avatar_url, creation_ts * 1000 as creation_ts
|
displayname, avatar_url, creation_ts * 1000 as creation_ts, approved
|
||||||
{sql_base}
|
{sql_base}
|
||||||
ORDER BY {order_by_column} {order}, u.name ASC
|
ORDER BY {order_by_column} {order}, u.name ASC
|
||||||
LIMIT ? OFFSET ?
|
LIMIT ? OFFSET ?
|
||||||
|
|
|
@ -205,6 +205,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self.get_rooms_for_user_with_stream_ordering.invalidate(
                 (data.state_key,)
             )
+            self.get_rooms_for_user.invalidate((data.state_key,))
         else:
             raise Exception("Unknown events stream row type %s" % (row.type,))

@ -1307,6 +1307,33 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
|
|
||||||
return changes
|
return changes
|
||||||
|
|
||||||
|
async def get_device_list_changes_in_room(
|
||||||
|
self, room_id: str, min_stream_id: int
|
||||||
|
) -> Collection[Tuple[str, str]]:
|
||||||
|
"""Get all device list changes that happened in the room since the given
|
||||||
|
stream ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Collection of user ID/device ID tuples of all devices that have
|
||||||
|
changed
|
||||||
|
"""
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
|
||||||
|
WHERE room_id = ? AND stream_id > ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_device_list_changes_in_room_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Collection[Tuple[str, str]]:
|
||||||
|
txn.execute(sql, (room_id, min_stream_id))
|
||||||
|
return cast(Collection[Tuple[str, str]], txn.fetchall())
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_device_list_changes_in_room",
|
||||||
|
get_device_list_changes_in_room_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DeviceBackgroundUpdateStore(SQLBaseStore):
|
class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -1946,14 +1973,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
user_id: str,
|
user_id: str,
|
||||||
device_id: str,
|
device_id: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
stream_id: int,
|
stream_id: Optional[int],
|
||||||
hosts: Collection[str],
|
hosts: Collection[str],
|
||||||
context: Optional[Dict[str, str]],
|
context: Optional[Dict[str, str]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Queue the device update to be sent to the given set of hosts,
|
"""Queue the device update to be sent to the given set of hosts,
|
||||||
calculated from the room ID.
|
calculated from the room ID.
|
||||||
|
|
||||||
Marks the associated row in `device_lists_changes_in_room` as handled.
|
Marks the associated row in `device_lists_changes_in_room` as handled,
|
||||||
|
if `stream_id` is provided.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def add_device_list_outbound_pokes_txn(
|
def add_device_list_outbound_pokes_txn(
|
||||||
|
@ -1969,6 +1997,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if stream_id:
|
||||||
self.db_pool.simple_update_txn(
|
self.db_pool.simple_update_txn(
|
||||||
txn,
|
txn,
|
||||||
table="device_lists_changes_in_room",
|
table="device_lists_changes_in_room",
|
||||||
|
@ -1995,3 +2024,58 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
add_device_list_outbound_pokes_txn,
|
add_device_list_outbound_pokes_txn,
|
||||||
stream_ids,
|
stream_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def add_remote_device_list_to_pending(
|
||||||
|
self, user_id: str, device_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Add a device list update to the table tracking remote device list
|
||||||
|
updates during partial joins.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
|
||||||
|
await self.db_pool.simple_upsert(
|
||||||
|
table="device_lists_remote_pending",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
},
|
||||||
|
values={"stream_id": stream_id},
|
||||||
|
desc="add_remote_device_list_to_pending",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_pending_remote_device_list_updates_for_room(
|
||||||
|
self, room_id: str
|
||||||
|
) -> Collection[Tuple[str, str]]:
|
||||||
|
"""Get the set of remote device list updates from the pending table for
|
||||||
|
the room.
|
||||||
|
"""
|
||||||
|
|
||||||
|
min_device_stream_id = await self.db_pool.simple_select_one_onecol(
|
||||||
|
table="partial_state_rooms",
|
||||||
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
},
|
||||||
|
retcol="device_lists_stream_id",
|
||||||
|
desc="get_pending_remote_device_list_updates_for_room_device",
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT user_id, device_id FROM device_lists_remote_pending AS d
|
||||||
|
INNER JOIN current_state_events AS c ON
|
||||||
|
type = 'm.room.member'
|
||||||
|
AND state_key = user_id
|
||||||
|
AND membership = 'join'
|
||||||
|
WHERE
|
||||||
|
room_id = ? AND stream_id > ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_pending_remote_device_list_updates_for_room_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Collection[Tuple[str, str]]:
|
||||||
|
txn.execute(sql, (room_id, min_device_stream_id))
|
||||||
|
return cast(Collection[Tuple[str, str]], txn.fetchall())
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_pending_remote_device_list_updates_for_room",
|
||||||
|
get_pending_remote_device_list_updates_for_room_txn,
|
||||||
|
)
|
||||||
|
|
|
@ -73,13 +73,30 @@ pdus_pruned_from_federation_queue = Counter(

 logger = logging.getLogger(__name__)

-BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS: int = int(
-    datetime.timedelta(days=7).total_seconds()
-)
-BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS: int = int(
-    datetime.timedelta(hours=1).total_seconds()
+# Parameters controlling exponential backoff between backfill failures.
+# After the first failure to backfill, we wait 2 hours before trying again. If the
+# second attempt fails, we wait 4 hours before trying again. If the third attempt fails,
+# we wait 8 hours before trying again, ... and so on.
+#
+# Each successive backoff period is twice as long as the last. However we cap this
+# period at a maximum of 2^8 = 256 hours: a little over 10 days. (This is the smallest
+# power of 2 which yields a maximum backoff period of at least 7 days---which was the
+# original maximum backoff period.) Even when we hit this cap, we will continue to
+# make backfill attempts once every 10 days.
+BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS = 8
+BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS = int(
+    datetime.timedelta(hours=1).total_seconds() * 1000
 )
+
+# We need a cap on the power of 2 or else the backoff period
+#   2^N * (milliseconds per hour)
+# will overflow when calcuated within the database. We ensure overflow does not occur
+# by checking that the largest backoff period fits in a 32-bit signed integer.
+_LONGEST_BACKOFF_PERIOD_MILLISECONDS = (
+    2**BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS
+) * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS
+assert 0 < _LONGEST_BACKOFF_PERIOD_MILLISECONDS <= ((2**31) - 1)


 # All the info we need while iterating the DAG while backfilling
 @attr.s(frozen=True, slots=True, auto_attribs=True)
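As a rough, standalone illustration of the backoff schedule these constants and the queries below implement: the N-th failed attempt is retried after (1 << min(N, 8)) hours, i.e. 2h, 4h, 8h, ... capped at 256h.

import datetime

STEP_MS = int(datetime.timedelta(hours=1).total_seconds() * 1000)
MAX_DOUBLING_STEPS = 8

def backoff_ms(num_attempts: int) -> int:
    # Mirrors the "(1 << LEAST(num_attempts, max_steps)) * step" expression used
    # in the backfill SQL further down.
    return (1 << min(num_attempts, MAX_DOUBLING_STEPS)) * STEP_MS

# backoff_ms(1) -> 2 hours, backoff_ms(2) -> 4 hours, ..., backoff_ms(12) -> 256 hours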
||||||
|
@ -726,17 +743,35 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
async def get_backfill_points_in_room(
|
async def get_backfill_points_in_room(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
|
current_depth: int,
|
||||||
|
limit: int,
|
||||||
) -> List[Tuple[str, int]]:
|
) -> List[Tuple[str, int]]:
|
||||||
"""
|
"""
|
||||||
Gets the oldest events(backwards extremities) in the room along with the
|
Get the backward extremities to backfill from in the room along with the
|
||||||
approximate depth. Sorted by depth, highest to lowest (descending).
|
approximate depth.
|
||||||
|
|
||||||
|
Only returns events that are at a depth lower than or
|
||||||
|
equal to the `current_depth`. Sorted by depth, highest to lowest (descending)
|
||||||
|
so the closest events to the `current_depth` are first in the list.
|
||||||
|
|
||||||
|
We ignore extremities that are newer than the user's current scroll position
|
||||||
|
(ie, those with depth greater than `current_depth`) as:
|
||||||
|
1. we don't really care about getting events that have happened
|
||||||
|
after our current position; and
|
||||||
|
2. by the nature of paginating and scrolling back, we have likely
|
||||||
|
previously tried and failed to backfill from that extremity, so
|
||||||
|
to avoid getting "stuck" requesting the same backfill repeatedly
|
||||||
|
we drop those extremities.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id: Room where we want to find the oldest events
|
room_id: Room where we want to find the oldest events
|
||||||
|
current_depth: The depth at the user's current scrollback position
|
||||||
|
limit: The max number of backfill points to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of (event_id, depth) tuples. Sorted by depth, highest to lowest
|
List of (event_id, depth) tuples. Sorted by depth, highest to lowest
|
||||||
(descending)
|
(descending) so the closest events to the `current_depth` are first
|
||||||
|
in the list.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_backfill_points_in_room_txn(
|
def get_backfill_points_in_room_txn(
|
||||||
|
@ -749,7 +784,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
# persisted in our database yet (meaning we don't know their depth
|
# persisted in our database yet (meaning we don't know their depth
|
||||||
# specifically). So we need to look for the approximate depth from
|
# specifically). So we need to look for the approximate depth from
|
||||||
# the events connected to the current backwards extremeties.
|
# the events connected to the current backwards extremeties.
|
||||||
sql = """
|
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
least_function = "LEAST"
|
||||||
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
|
least_function = "MIN"
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unknown database engine")
|
||||||
|
|
||||||
|
sql = f"""
|
||||||
SELECT backward_extrem.event_id, event.depth FROM events AS event
|
SELECT backward_extrem.event_id, event.depth FROM events AS event
|
||||||
/**
|
/**
|
||||||
* Get the edge connections from the event_edges table
|
* Get the edge connections from the event_edges table
|
||||||
|
@ -784,6 +827,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
* necessarily safe to assume that it will have been completed.
|
* necessarily safe to assume that it will have been completed.
|
||||||
*/
|
*/
|
||||||
AND edge.is_state is ? /* False */
|
AND edge.is_state is ? /* False */
|
||||||
|
/**
|
||||||
|
* We only want backwards extremities that are older than or at
|
||||||
|
* the same position of the given `current_depth` (where older
|
||||||
|
* means less than the given depth) because we're looking backwards
|
||||||
|
* from the `current_depth` when backfilling.
|
||||||
|
*
|
||||||
|
* current_depth (ignore events that come after this, ignore 2-4)
|
||||||
|
* |
|
||||||
|
* ▼
|
||||||
|
* <oldest-in-time> [0]<--[1]<--[2]<--[3]<--[4] <newest-in-time>
|
||||||
|
*/
|
||||||
|
AND event.depth <= ? /* current_depth */
|
||||||
/**
|
/**
|
||||||
* Exponential back-off (up to the upper bound) so we don't retry the
|
* Exponential back-off (up to the upper bound) so we don't retry the
|
||||||
* same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
|
* same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
|
||||||
|
@ -795,31 +850,31 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
*/
|
*/
|
||||||
AND (
|
AND (
|
||||||
failed_backfill_attempt_info.event_id IS NULL
|
failed_backfill_attempt_info.event_id IS NULL
|
||||||
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */)
|
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + (
|
||||||
|
(1 << {least_function}(failed_backfill_attempt_info.num_attempts, ? /* max doubling steps */))
|
||||||
|
* ? /* step */
|
||||||
|
)
|
||||||
)
|
)
|
||||||
/**
|
/**
|
||||||
* Sort from highest to the lowest depth. Then tie-break on
|
* Sort from highest (closest to the `current_depth`) to the lowest depth
|
||||||
* alphabetical order of the event_ids so we get a consistent
|
* because the closest are most relevant to backfill from first.
|
||||||
* ordering which is nice when asserting things in tests.
|
* Then tie-break on alphabetical order of the event_ids so we get a
|
||||||
|
* consistent ordering which is nice when asserting things in tests.
|
||||||
*/
|
*/
|
||||||
ORDER BY event.depth DESC, backward_extrem.event_id DESC
|
ORDER BY event.depth DESC, backward_extrem.event_id DESC
|
||||||
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
|
||||||
least_function = "least"
|
|
||||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
|
||||||
least_function = "min"
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Unknown database engine")
|
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
sql % (least_function,),
|
sql,
|
||||||
(
|
(
|
||||||
room_id,
|
room_id,
|
||||||
False,
|
False,
|
||||||
|
current_depth,
|
||||||
self._clock.time_msec(),
|
self._clock.time_msec(),
|
||||||
1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS,
|
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
|
||||||
1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS,
|
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS,
|
||||||
|
limit,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -835,24 +890,47 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
async def get_insertion_event_backward_extremities_in_room(
|
async def get_insertion_event_backward_extremities_in_room(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
|
current_depth: int,
|
||||||
|
limit: int,
|
||||||
) -> List[Tuple[str, int]]:
|
) -> List[Tuple[str, int]]:
|
||||||
"""
|
"""
|
||||||
Get the insertion events we know about that we haven't backfilled yet
|
Get the insertion events we know about that we haven't backfilled yet
|
||||||
along with the approximate depth. Sorted by depth, highest to lowest
|
along with the approximate depth. Only returns insertion events that are
|
||||||
(descending).
|
at a depth lower than or equal to the `current_depth`. Sorted by depth,
|
||||||
|
highest to lowest (descending) so the closest events to the
|
||||||
|
`current_depth` are first in the list.
|
||||||
|
|
||||||
|
We ignore insertion events that are newer than the user's current scroll
|
||||||
|
position (ie, those with depth greater than `current_depth`) as:
|
||||||
|
1. we don't really care about getting events that have happened
|
||||||
|
after our current position; and
|
||||||
|
2. by the nature of paginating and scrolling back, we have likely
|
||||||
|
previously tried and failed to backfill from that insertion event, so
|
||||||
|
to avoid getting "stuck" requesting the same backfill repeatedly
|
||||||
|
we drop those insertion event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id: Room where we want to find the oldest events
|
room_id: Room where we want to find the oldest events
|
||||||
|
current_depth: The depth at the user's current scrollback position
|
||||||
|
limit: The max number of insertion event extremities to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of (event_id, depth) tuples. Sorted by depth, highest to lowest
|
List of (event_id, depth) tuples. Sorted by depth, highest to lowest
|
||||||
(descending)
|
(descending) so the closest events to the `current_depth` are first
|
||||||
|
in the list.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_insertion_event_backward_extremities_in_room_txn(
|
def get_insertion_event_backward_extremities_in_room_txn(
|
||||||
txn: LoggingTransaction, room_id: str
|
txn: LoggingTransaction, room_id: str
|
||||||
) -> List[Tuple[str, int]]:
|
) -> List[Tuple[str, int]]:
|
||||||
sql = """
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
least_function = "LEAST"
|
||||||
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
|
least_function = "MIN"
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unknown database engine")
|
||||||
|
|
||||||
|
sql = f"""
|
||||||
SELECT
|
SELECT
|
||||||
insertion_event_extremity.event_id, event.depth
|
insertion_event_extremity.event_id, event.depth
|
||||||
/* We only want insertion events that are also marked as backwards extremities */
|
/* We only want insertion events that are also marked as backwards extremities */
|
||||||
|
@ -869,6 +947,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
AND failed_backfill_attempt_info.event_id = insertion_event_extremity.event_id
|
AND failed_backfill_attempt_info.event_id = insertion_event_extremity.event_id
|
||||||
WHERE
|
WHERE
|
||||||
insertion_event_extremity.room_id = ?
|
insertion_event_extremity.room_id = ?
|
||||||
|
/**
|
||||||
|
* We only want extremities that are older than or at
|
||||||
|
* the same position of the given `current_depth` (where older
|
||||||
|
* means less than the given depth) because we're looking backwards
|
||||||
|
* from the `current_depth` when backfilling.
|
||||||
|
*
|
||||||
|
* current_depth (ignore events that come after this, ignore 2-4)
|
||||||
|
* |
|
||||||
|
* ▼
|
||||||
|
* <oldest-in-time> [0]<--[1]<--[2]<--[3]<--[4] <newest-in-time>
|
||||||
|
*/
|
||||||
|
AND event.depth <= ? /* current_depth */
|
||||||
/**
|
/**
|
||||||
* Exponential back-off (up to the upper bound) so we don't retry the
|
* Exponential back-off (up to the upper bound) so we don't retry the
|
||||||
* same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc
|
* same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc
|
||||||
|
@ -880,30 +970,30 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
*/
|
*/
|
||||||
AND (
|
AND (
|
||||||
failed_backfill_attempt_info.event_id IS NULL
|
failed_backfill_attempt_info.event_id IS NULL
|
||||||
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */)
|
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + (
|
||||||
|
(1 << {least_function}(failed_backfill_attempt_info.num_attempts, ? /* max doubling steps */))
|
||||||
|
* ? /* step */
|
||||||
|
)
|
||||||
)
|
)
|
||||||
/**
|
/**
|
||||||
* Sort from highest to the lowest depth. Then tie-break on
|
* Sort from highest (closest to the `current_depth`) to the lowest depth
|
||||||
* alphabetical order of the event_ids so we get a consistent
|
* because the closest are most relevant to backfill from first.
|
||||||
* ordering which is nice when asserting things in tests.
|
* Then tie-break on alphabetical order of the event_ids so we get a
|
||||||
|
* consistent ordering which is nice when asserting things in tests.
|
||||||
*/
|
*/
|
||||||
ORDER BY event.depth DESC, insertion_event_extremity.event_id DESC
|
ORDER BY event.depth DESC, insertion_event_extremity.event_id DESC
|
||||||
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
|
||||||
least_function = "least"
|
|
||||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
|
||||||
least_function = "min"
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Unknown database engine")
|
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
sql % (least_function,),
|
sql,
|
||||||
(
|
(
|
||||||
room_id,
|
room_id,
|
||||||
|
current_depth,
|
||||||
self._clock.time_msec(),
|
self._clock.time_msec(),
|
||||||
1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS,
|
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
|
||||||
1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS,
|
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS,
|
||||||
|
limit,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return cast(List[Tuple[str, int]], txn.fetchall())
|
return cast(List[Tuple[str, int]], txn.fetchall())
|
||||||
|
|
|
@ -366,14 +366,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
||||||
user_id: str,
|
user_id: str,
|
||||||
) -> NotifCounts:
|
) -> NotifCounts:
|
||||||
# Get the stream ordering of the user's latest receipt in the room.
|
# Get the stream ordering of the user's latest receipt in the room.
|
||||||
result = self.get_last_receipt_for_user_txn(
|
result = self.get_last_unthreaded_receipt_for_user_txn(
|
||||||
txn,
|
txn,
|
||||||
user_id,
|
user_id,
|
||||||
room_id,
|
room_id,
|
||||||
receipt_types=(
|
receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
|
||||||
ReceiptTypes.READ,
|
|
||||||
ReceiptTypes.READ_PRIVATE,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
|
@ -574,10 +571,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
||||||
receipt_types_clause, args = make_in_list_sql_clause(
|
receipt_types_clause, args = make_in_list_sql_clause(
|
||||||
self.database_engine,
|
self.database_engine,
|
||||||
"receipt_type",
|
"receipt_type",
|
||||||
(
|
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
|
||||||
ReceiptTypes.READ,
|
|
||||||
ReceiptTypes.READ_PRIVATE,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
sql = f"""
|
sql = f"""
|
||||||
|
@ -1074,7 +1068,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
||||||
limit,
|
limit,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
rows = txn.fetchall()
|
rows = cast(List[Tuple[int, str, str, int]], txn.fetchall())
|
||||||
|
|
||||||
# For each new read receipt we delete push actions from before it and
|
# For each new read receipt we delete push actions from before it and
|
||||||
# recalculate the summary.
|
# recalculate the summary.
|
||||||
|
@ -1119,18 +1113,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
||||||
# We always update `event_push_summary_last_receipt_stream_id` to
|
# We always update `event_push_summary_last_receipt_stream_id` to
|
||||||
# ensure that we don't rescan the same receipts for remote users.
|
# ensure that we don't rescan the same receipts for remote users.
|
||||||
|
|
||||||
upper_limit = max_receipts_stream_id
|
receipts_last_processed_stream_id = max_receipts_stream_id
|
||||||
if len(rows) >= limit:
|
if len(rows) >= limit:
|
||||||
# If we pulled out a limited number of rows we only update the
|
# If we pulled out a limited number of rows we only update the
|
||||||
# position to the last receipt we processed, so we continue
|
# position to the last receipt we processed, so we continue
|
||||||
# processing the rest next iteration.
|
# processing the rest next iteration.
|
||||||
upper_limit = rows[-1][0]
|
receipts_last_processed_stream_id = rows[-1][0]
|
||||||
|
|
||||||
self.db_pool.simple_update_txn(
|
self.db_pool.simple_update_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_push_summary_last_receipt_stream_id",
|
table="event_push_summary_last_receipt_stream_id",
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
updatevalues={"stream_id": upper_limit},
|
updatevalues={"stream_id": receipts_last_processed_stream_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
return len(rows) < limit
|
return len(rows) < limit
|
||||||
|
|
|
@ -2134,13 +2134,13 @@ class PersistEventsStore:
|
||||||
appear in events_and_context.
|
appear in events_and_context.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Only non outlier events will have push actions associated with them,
|
# Only notifiable events will have push actions associated with them,
|
||||||
# so let's filter them out. (This makes joining large rooms faster, as
|
# so let's filter them out. (This makes joining large rooms faster, as
|
||||||
# these queries took seconds to process all the state events).
|
# these queries took seconds to process all the state events).
|
||||||
non_outlier_events = [
|
notifiable_events = [
|
||||||
event
|
event
|
||||||
for event, _ in events_and_contexts
|
for event, _ in events_and_contexts
|
||||||
if not event.internal_metadata.is_outlier()
|
if event.internal_metadata.is_notifiable()
|
||||||
]
|
]
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
|
@ -2153,7 +2153,7 @@ class PersistEventsStore:
|
||||||
WHERE event_id = ?
|
WHERE event_id = ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if non_outlier_events:
|
if notifiable_events:
|
||||||
txn.execute_batch(
|
txn.execute_batch(
|
||||||
sql,
|
sql,
|
||||||
(
|
(
|
||||||
|
@ -2163,7 +2163,7 @@ class PersistEventsStore:
|
||||||
event.depth,
|
event.depth,
|
||||||
event.event_id,
|
event.event_id,
|
||||||
)
|
)
|
||||||
for event in non_outlier_events
|
for event in notifiable_events
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@@ -135,34 +135,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """Get the current max stream ID for receipts stream"""
         return self._receipts_id_gen.get_current_token()

-    async def get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_types: Collection[str]
-    ) -> Optional[str]:
-        """
-        Fetch the event ID for the latest receipt in a room with one of the given receipt types.
-
-        Args:
-            user_id: The user to fetch receipts for.
-            room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt types to fetch.
-
-        Returns:
-            The latest receipt, if one exists.
-        """
-        result = await self.db_pool.runInteraction(
-            "get_last_receipt_event_id_for_user",
-            self.get_last_receipt_for_user_txn,
-            user_id,
-            room_id,
-            receipt_types,
-        )
-        if not result:
-            return None
-
-        event_id, _ = result
-        return event_id
-
-    def get_last_receipt_for_user_txn(
+    def get_last_unthreaded_receipt_for_user_txn(
         self,
         txn: LoggingTransaction,
         user_id: str,
@@ -170,13 +143,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_types: Collection[str],
     ) -> Optional[Tuple[str, int]]:
         """
-        Fetch the event ID and stream_ordering for the latest receipt in a room
-        with one of the given receipt types.
+        Fetch the event ID and stream_ordering for the latest unthreaded receipt
+        in a room with one of the given receipt types.

         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt types to fetch.
+            receipt_types: The receipt types to fetch.

         Returns:
             The event ID and stream ordering of the latest receipt, if one exists.
@@ -193,6 +166,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 WHERE {clause}
                 AND user_id = ?
                 AND room_id = ?
+                AND thread_id IS NULL
                 ORDER BY stream_ordering DESC
                 LIMIT 1
             """
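The new `AND thread_id IS NULL` predicate is what limits this lookup to unthreaded receipts. A minimal sketch of the same query shape, against a simplified stand-in table (not Synapse's real schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE receipts ("
    "  event_id TEXT, stream_ordering INTEGER, receipt_type TEXT,"
    "  user_id TEXT, room_id TEXT, thread_id TEXT)"
)
conn.executemany(
    "INSERT INTO receipts VALUES (?, ?, ?, ?, ?, ?)",
    [
        ("$unthreaded", 1, "m.read", "@alice:test", "!room:test", None),
        ("$threaded", 2, "m.read", "@alice:test", "!room:test", "$thread_root"),
    ],
)

# The threaded receipt has the higher stream ordering, but it is skipped:
# only rows with a NULL thread_id are candidates for the "unthreaded" lookup.
latest = conn.execute(
    """
    SELECT event_id, stream_ordering FROM receipts
    WHERE receipt_type IN ('m.read') AND user_id = ? AND room_id = ?
        AND thread_id IS NULL
    ORDER BY stream_ordering DESC
    LIMIT 1
    """,
    ("@alice:test", "!room:test"),
).fetchone()
print(latest)  # ('$unthreaded', 1)
```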
@@ -166,27 +166,49 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     @cached()
     async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
         """Deprecated: use get_userinfo_by_id instead"""
-        return await self.db_pool.simple_select_one(
-            table="users",
-            keyvalues={"name": user_id},
-            retcols=[
-                "name",
-                "password_hash",
-                "is_guest",
-                "admin",
-                "consent_version",
-                "consent_ts",
-                "consent_server_notice_sent",
-                "appservice_id",
-                "creation_ts",
-                "user_type",
-                "deactivated",
-                "shadow_banned",
-            ],
-            allow_none=True,
-            desc="get_user_by_id",
+
+        def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+            # We could technically use simple_select_one here, but it would not perform
+            # the COALESCEs (unless hacked into the column names), which could yield
+            # confusing results.
+            txn.execute(
+                """
+                SELECT
+                    name, password_hash, is_guest, admin, consent_version, consent_ts,
+                    consent_server_notice_sent, appservice_id, creation_ts, user_type,
+                    deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
+                    COALESCE(approved, TRUE) AS approved
+                FROM users
+                WHERE name = ?
+                """,
+                (user_id,),
+            )
+
+            rows = self.db_pool.cursor_to_dict(txn)
+
+            if len(rows) == 0:
+                return None
+
+            return rows[0]
+
+        row = await self.db_pool.runInteraction(
+            desc="get_user_by_id",
+            func=get_user_by_id_txn,
         )
+
+        if row is not None:
+            # If we're using SQLite our boolean values will be integers. Because we
+            # present some of this data as is to e.g. server admins via REST APIs, we
+            # want to make sure we're returning the right type of data.
+            # Note: when adding a column name to this list, be wary of NULLable columns,
+            # since NULL values will be turned into False.
+            boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
+            for column in boolean_columns:
+                if not isinstance(row[column], bool):
+                    row[column] = bool(row[column])
+
+        return row

     async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
         """Get a UserInfo object for a user by user ID.
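One reason for the boolean coercion above: SQLite hands BOOLEAN columns back as integers, while PostgreSQL returns real booleans. A small standalone sketch of the behaviour (plain `sqlite3`, not Synapse's database pool):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (name TEXT, admin BOOLEAN, approved BOOLEAN)")
# A pre-existing user: admin, with a NULL 'approved' column from before the migration.
conn.execute("INSERT INTO users VALUES ('@alice:test', 1, NULL)")

row = conn.execute(
    "SELECT admin, COALESCE(approved, TRUE) AS approved FROM users WHERE name = ?",
    ("@alice:test",),
).fetchone()
print(row)  # (1, 1) -- integers, not booleans

# Coerce to real booleans before handing the dict to e.g. the admin REST API.
admin, approved = (v if isinstance(v, bool) else bool(v) for v in row)
print(admin, approved)  # True True
```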
@@ -1779,6 +1801,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):

         return res if res else False

+    @cached()
+    async def is_user_approved(self, user_id: str) -> bool:
+        """Checks if a user is approved and therefore can be allowed to log in.
+
+        If the user's 'approved' column is NULL, we consider it as true given it means
+        the user was registered when support for an approval flow was either disabled
+        or nonexistent.
+
+        Args:
+            user_id: the user to check the approval status of.
+
+        Returns:
+            A boolean that is True if the user is approved, False otherwise.
+        """
+
+        def is_user_approved_txn(txn: LoggingTransaction) -> bool:
+            txn.execute(
+                """
+                SELECT COALESCE(approved, TRUE) AS approved FROM users WHERE name = ?
+                """,
+                (user_id,),
+            )
+
+            rows = self.db_pool.cursor_to_dict(txn)
+
+            # We cast to bool because the value returned by the database engine might
+            # be an integer if we're using SQLite.
+            return bool(rows[0]["approved"])
+
+        return await self.db_pool.runInteraction(
+            desc="is_user_pending_approval",
+            func=is_user_approved_txn,
+        )
+

 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(
@@ -1916,6 +1972,29 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))

+    def update_user_approval_status_txn(
+        self, txn: LoggingTransaction, user_id: str, approved: bool
+    ) -> None:
+        """Set the user's 'approved' flag to the given value.
+
+        The boolean is turned into an int because the column is a smallint.
+
+        Args:
+            txn: the current database transaction.
+            user_id: the user to update the flag for.
+            approved: the value to set the flag to.
+        """
+        self.db_pool.simple_update_one_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            updatevalues={"approved": approved},
+        )
+
+        # Invalidate the caches of methods that read the value of the 'approved' flag.
+        self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+        self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
+

 class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
     def __init__(
@@ -1933,6 +2012,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")

+        # If support for MSC3866 is enabled and configured to require approval for new
+        # account, we will create new users with an 'approved' flag set to false.
+        self._require_approval = (
+            hs.config.experimental.msc3866.enabled
+            and hs.config.experimental.msc3866.require_approval_for_new_accounts
+        )
+
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -2065,6 +2151,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         admin: bool = False,
         user_type: Optional[str] = None,
         shadow_banned: bool = False,
+        approved: bool = False,
     ) -> None:
         """Attempts to register an account.
@@ -2083,6 +2170,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 or None for a normal user.
             shadow_banned: Whether the user is shadow-banned, i.e. they may be
                 told their requests succeeded but we ignore them.
+            approved: Whether to consider the user has already been approved by an
+                administrator.

         Raises:
             StoreError if the user_id could not be registered.
@@ -2099,6 +2188,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             admin,
             user_type,
             shadow_banned,
+            approved,
         )

     def _register_user(
@@ -2113,11 +2203,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         admin: bool,
         user_type: Optional[str],
         shadow_banned: bool,
+        approved: bool,
     ) -> None:
         user_id_obj = UserID.from_string(user_id)

         now = int(self._clock.time())

+        user_approved = approved or not self._require_approval
+
         try:
             if was_guest:
                 # Ensure that the guest user actually exists
@@ -2143,6 +2236,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
                         "shadow_banned": shadow_banned,
+                        "approved": user_approved,
                     },
                 )
             else:
@@ -2158,6 +2252,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
                         "shadow_banned": shadow_banned,
+                        "approved": user_approved,
                     },
                 )

@@ -2503,6 +2598,25 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             start_or_continue_validation_session_txn,
         )

+    async def update_user_approval_status(
+        self, user_id: UserID, approved: bool
+    ) -> None:
+        """Set the user's 'approved' flag to the given value.
+
+        The boolean will be turned into an int (in update_user_approval_status_txn)
+        because the column is a smallint.
+
+        Args:
+            user_id: the user to update the flag for.
+            approved: the value to set the flag to.
+        """
+        await self.db_pool.runInteraction(
+            "update_user_approval_status",
+            self.update_user_approval_status_txn,
+            user_id.to_string(),
+            approved,
+        )
+

 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """
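The registration path above boils the approval decision down to `approved or not self._require_approval`. As a rough illustration of how that combines with the MSC3866 config (simplified stand-in types, not Synapse's config classes):

```python
from dataclasses import dataclass


@dataclass
class MSC3866Config:
    enabled: bool = False
    require_approval_for_new_accounts: bool = False


def effective_approval(approved: bool, config: MSC3866Config) -> bool:
    # New accounts are only held back for approval when the feature is enabled,
    # approval is required, and the caller did not pre-approve the registration
    # (as the admin API does for admin-created users).
    require_approval = config.enabled and config.require_approval_for_new_accounts
    return approved or not require_approval


assert effective_approval(False, MSC3866Config()) is True
assert effective_approval(False, MSC3866Config(True, True)) is False
assert effective_approval(True, MSC3866Config(True, True)) is True
```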
@@ -1217,6 +1217,26 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         )
         self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))

+        # We now delete anything from `device_lists_remote_pending` with a
+        # stream ID less than the minimum
+        # `partial_state_rooms.device_lists_stream_id`, as we no longer need them.
+        device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn(
+            txn,
+            table="partial_state_rooms",
+            keyvalues={},
+            retcol="MIN(device_lists_stream_id)",
+            allow_none=True,
+        )
+        if device_lists_stream_id is None:
+            # There are no rooms being currently partially joined, so we delete everything.
+            txn.execute("DELETE FROM device_lists_remote_pending")
+        else:
+            sql = """
+                DELETE FROM device_lists_remote_pending
+                WHERE stream_id <= ?
+            """
+            txn.execute(sql, (device_lists_stream_id,))
+
     @cached()
     async def is_partial_state_room(self, room_id: str) -> bool:
         """Checks if this room has partial state.
@@ -1236,6 +1256,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):

         return entry is not None

+    async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
+        self, room_id: str
+    ) -> Tuple[str, int]:
+        """Get the event ID of the initial join that started the partial
+        join, and the device list stream ID at the point we started the partial
+        join.
+        """
+
+        result = await self.db_pool.simple_select_one(
+            table="partial_state_rooms",
+            keyvalues={"room_id": room_id},
+            retcols=("join_event_id", "device_lists_stream_id"),
+            desc="get_join_event_id_for_partial_state",
+        )
+        return result["join_event_id"], result["device_lists_stream_id"]
+

 class _BackgroundUpdates:
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
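The cleanup at the top of this hunk prunes `device_lists_remote_pending` up to the smallest `device_lists_stream_id` still referenced by an in-flight partial join. A standalone sketch of that pruning rule (plain `sqlite3`; the helper name is illustrative only):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE partial_state_rooms (room_id TEXT, device_lists_stream_id BIGINT);
    CREATE TABLE device_lists_remote_pending (stream_id BIGINT, user_id TEXT, device_id TEXT);
    """
)


def prune_pending_device_lists(conn: sqlite3.Connection) -> None:
    row = conn.execute(
        "SELECT MIN(device_lists_stream_id) FROM partial_state_rooms"
    ).fetchone()
    min_stream_id = row[0]
    if min_stream_id is None:
        # No partial joins in flight: nothing will ever be replayed, drop everything.
        conn.execute("DELETE FROM device_lists_remote_pending")
    else:
        # Mirror the store method: anything at or below the minimum stream ID
        # is no longer needed.
        conn.execute(
            "DELETE FROM device_lists_remote_pending WHERE stream_id <= ?",
            (min_stream_id,),
        )
```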
@@ -15,7 +15,6 @@
 import logging
 from typing import (
     TYPE_CHECKING,
-    Callable,
     Collection,
     Dict,
     FrozenSet,
@@ -52,7 +51,6 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
-from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure

@@ -148,42 +146,37 @@ class RoomMemberWorkerStore(EventsWorkerStore):

     @cached(max_entries=100000, iterable=True)
     async def get_users_in_room(self, room_id: str) -> List[str]:
-        """
-        Returns a list of users in the room sorted by longest in the room first
-        (aka. with the lowest depth). This is done to match the sort in
-        `get_current_hosts_in_room()` and so we can re-use the cache but it's
-        not horrible to have here either.
-
-        Uses `m.room.member`s in the room state at the current forward extremities to
-        determine which users are in the room.
+        """Returns a list of users in the room.

         Will return inaccurate results for rooms with partial state, since the state for
         the forward extremities of those rooms will exclude most members. We may also
         calculate room state incorrectly for such rooms and believe that a member is or
         is not in the room when the opposite is true.
         """
-        return await self.db_pool.runInteraction(
-            "get_users_in_room", self.get_users_in_room_txn, room_id
+        return await self.db_pool.simple_select_onecol(
+            table="current_state_events",
+            keyvalues={
+                "type": EventTypes.Member,
+                "room_id": room_id,
+                "membership": Membership.JOIN,
+            },
+            retcol="state_key",
+            desc="get_users_in_room",
         )

     def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
-        """
-        Returns a list of users in the room sorted by longest in the room first
-        (aka. with the lowest depth). This is done to match the sort in
-        `get_current_hosts_in_room()` and so we can re-use the cache but it's
-        not horrible to have here either.
-        """
-        sql = """
-            SELECT c.state_key FROM current_state_events as c
-            /* Get the depth of the event from the events table */
-            INNER JOIN events AS e USING (event_id)
-            WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ?
-            /* Sorted by lowest depth first */
-            ORDER BY e.depth ASC;
-        """
-
-        txn.execute(sql, (room_id, Membership.JOIN))
-        return [r[0] for r in txn]
+        """Returns a list of users in the room."""
+        return self.db_pool.simple_select_onecol_txn(
+            txn,
+            table="current_state_events",
+            keyvalues={
+                "type": EventTypes.Member,
+                "room_id": room_id,
+                "membership": Membership.JOIN,
+            },
+            retcol="state_key",
+        )

     @cached()
     def get_user_in_room_with_profile(
@@ -600,58 +593,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             for room_id, instance, stream_id in txn
         )

-    @cachedList(
-        cached_method_name="get_rooms_for_user_with_stream_ordering",
-        list_name="user_ids",
-    )
-    async def get_rooms_for_users_with_stream_ordering(
-        self, user_ids: Collection[str]
-    ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
-        """A batched version of `get_rooms_for_user_with_stream_ordering`.
-
-        Returns:
-            Map from user_id to set of rooms that is currently in.
-        """
-        return await self.db_pool.runInteraction(
-            "get_rooms_for_users_with_stream_ordering",
-            self._get_rooms_for_users_with_stream_ordering_txn,
-            user_ids,
-        )
-
-    def _get_rooms_for_users_with_stream_ordering_txn(
-        self, txn: LoggingTransaction, user_ids: Collection[str]
-    ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
-
-        clause, args = make_in_list_sql_clause(
-            self.database_engine,
-            "c.state_key",
-            user_ids,
-        )
-
-        sql = f"""
-            SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
-            FROM current_state_events AS c
-            INNER JOIN events AS e USING (room_id, event_id)
-            WHERE
-                c.type = 'm.room.member'
-                AND c.membership = ?
-                AND {clause}
-        """
-
-        txn.execute(sql, [Membership.JOIN] + args)
-
-        result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
-            user_id: set() for user_id in user_ids
-        }
-        for user_id, room_id, instance, stream_id in txn:
-            result[user_id].add(
-                GetRoomsForUserWithStreamOrdering(
-                    room_id, PersistedEventPosition(instance, stream_id)
-                )
-            )
-
-        return {user_id: frozenset(v) for user_id, v in result.items()}
-
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
     ) -> Set[str]:
@@ -693,20 +634,69 @@ class RoomMemberWorkerStore(EventsWorkerStore):

         return {row[0] for row in txn}

-    @cancellable
-    async def get_rooms_for_user(
-        self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
-    ) -> FrozenSet[str]:
+    @cached(max_entries=500000, iterable=True)
+    async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]:
         """Returns a set of room_ids the user is currently joined to.

         If a remote user only returns rooms this server is currently
         participating in.
         """
-        rooms = await self.get_rooms_for_user_with_stream_ordering(
-            user_id, on_invalidate=on_invalidate
+        rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate(
+            (user_id,),
+            None,
+            update_metrics=False,
         )
-        return frozenset(r.room_id for r in rooms)
+        if rooms:
+            return frozenset(r.room_id for r in rooms)
+
+        room_ids = await self.db_pool.simple_select_onecol(
+            table="current_state_events",
+            keyvalues={
+                "type": EventTypes.Member,
+                "membership": Membership.JOIN,
+                "state_key": user_id,
+            },
+            retcol="room_id",
+            desc="get_rooms_for_user",
+        )
+
+        return frozenset(room_ids)
+
+    @cachedList(
+        cached_method_name="get_rooms_for_user",
+        list_name="user_ids",
+    )
+    async def get_rooms_for_users(
+        self, user_ids: Collection[str]
+    ) -> Dict[str, FrozenSet[str]]:
+        """A batched version of `get_rooms_for_user`.
+
+        Returns:
+            Map from user_id to set of rooms that is currently in.
+        """
+
+        rows = await self.db_pool.simple_select_many_batch(
+            table="current_state_events",
+            column="state_key",
+            iterable=user_ids,
+            retcols=(
+                "state_key",
+                "room_id",
+            ),
+            keyvalues={
+                "type": EventTypes.Member,
+                "membership": Membership.JOIN,
+            },
+            desc="get_rooms_for_users",
+        )
+
+        user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
+
+        for row in rows:
+            user_rooms[row["state_key"]].add(row["room_id"])
+
+        return {key: frozenset(rooms) for key, rooms in user_rooms.items()}

     @cached(max_entries=10000)
     async def does_pair_of_users_share_a_room(
         self, user_id: str, other_user_id: str
@@ -936,7 +926,44 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return True

     @cached(iterable=True, max_entries=10000)
-    async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+        """Get current hosts in room based on current state."""
+
+        # First we check if we already have `get_users_in_room` in the cache, as
+        # we can just calculate result from that
+        users = self.get_users_in_room.cache.get_immediate(
+            (room_id,), None, update_metrics=False
+        )
+        if users is not None:
+            return {get_domain_from_id(u) for u in users}
+
+        if isinstance(self.database_engine, Sqlite3Engine):
+            # If we're using SQLite then let's just always use
+            # `get_users_in_room` rather than funky SQL.
+            users = await self.get_users_in_room(room_id)
+            return {get_domain_from_id(u) for u in users}
+
+        # For PostgreSQL we can use a regex to pull out the domains from the
+        # joined users in `current_state_events` via regex.
+
+        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+            sql = """
+                SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
+                FROM current_state_events
+                WHERE
+                    type = 'm.room.member'
+                    AND membership = 'join'
+                    AND room_id = ?
+            """
+            txn.execute(sql, (room_id,))
+            return {d for d, in txn}
+
+        return await self.db_pool.runInteraction(
+            "get_current_hosts_in_room", get_current_hosts_in_room_txn
+        )
+
+    @cached(iterable=True, max_entries=10000)
+    async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
         """
         Get current hosts in room based on current state.

@@ -944,48 +971,33 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         longest is good because they're most likely to have anything we ask
         about.

-        Uses `m.room.member`s in the room state at the current forward extremities to
-        determine which hosts are in the room.
+        For SQLite the returned list is not ordered, as SQLite doesn't support
+        the appropriate SQL.

-        Will return inaccurate results for rooms with partial state, since the state for
-        the forward extremities of those rooms will exclude most members. We may also
-        calculate room state incorrectly for such rooms and believe that a host is or
-        is not in the room when the opposite is true.
+        Uses `m.room.member`s in the room state at the current forward
+        extremities to determine which hosts are in the room.
+
+        Will return inaccurate results for rooms with partial state, since the
+        state for the forward extremities of those rooms will exclude most
+        members. We may also calculate room state incorrectly for such rooms and
+        believe that a host is or is not in the room when the opposite is true.

         Returns:
             Returns a list of servers sorted by longest in the room first. (aka.
             sorted by join with the lowest depth first).
         """

-        # First we check if we already have `get_users_in_room` in the cache, as
-        # we can just calculate result from that
-        users = self.get_users_in_room.cache.get_immediate(
-            (room_id,), None, update_metrics=False
-        )
-        if users is None and isinstance(self.database_engine, Sqlite3Engine):
+        if isinstance(self.database_engine, Sqlite3Engine):
             # If we're using SQLite then let's just always use
             # `get_users_in_room` rather than funky SQL.
-            users = await self.get_users_in_room(room_id)
-
-        if users is not None:
-            # Because `users` is sorted from lowest -> highest depth, the list
-            # of domains will also be sorted that way.
-            domains: List[str] = []
-            # We use a `Set` just for fast lookups
-            domain_set: Set[str] = set()
-            for u in users:
-                if ":" not in u:
-                    continue
-                domain = get_domain_from_id(u)
-                if domain not in domain_set:
-                    domain_set.add(domain)
-                    domains.append(domain)
-            return domains
+            domains = await self.get_current_hosts_in_room(room_id)
+            return list(domains)

         # For PostgreSQL we can use a regex to pull out the domains from the
         # joined users in `current_state_events` via regex.

-        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]:
+        def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
             # Returns a list of servers currently joined in the room sorted by
             # longest in the room first (aka. with the lowest depth). The
             # heuristic of sorting by servers who have been in the room the
@@ -1013,7 +1025,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             return [d for d, in txn if d is not None]

         return await self.db_pool.runInteraction(
-            "get_current_hosts_in_room", get_current_hosts_in_room_txn
+            "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
         )

     async def get_joined_hosts(
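On PostgreSQL the hosts are pulled out in SQL with `substring(state_key FROM '@[^:]*:(.*)$')`; on SQLite the same extraction happens in Python via `get_domain_from_id`. A small illustration of that equivalence using a plain regex (not the real helper):

```python
import re

_USER_ID_RE = re.compile(r"^@[^:]*:(.*)$")


def domain_from_user_id(user_id: str) -> str:
    """Everything after the first colon of a Matrix user ID is the server name."""
    match = _USER_ID_RE.match(user_id)
    if match is None:
        raise ValueError(f"not a valid user ID: {user_id}")
    return match.group(1)


members = ["@alice:example.com", "@bob:example.com", "@carol:matrix.org"]
print({domain_from_user_id(u) for u in members})  # {'example.com', 'matrix.org'}
```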
@@ -0,0 +1,20 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column to the users table to track whether the user needs to be approved by an
+-- administrator.
+-- A NULL column means the user was created before this feature was supported by Synapse,
+-- and should be considered as TRUE.
+ALTER TABLE users ADD COLUMN approved BOOLEAN;
@@ -0,0 +1,28 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Stores remote device lists we have received for remote users while a partial
+-- join is in progress.
+--
+-- This allows us to replay any device list updates if it turns out the remote
+-- user was in the partially joined room
+CREATE TABLE device_lists_remote_pending(
+    stream_id BIGINT PRIMARY KEY,
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL
+);
+
+-- We only keep the most recent update for a given user/device pair.
+CREATE UNIQUE INDEX device_lists_remote_pending_user_device_id ON device_lists_remote_pending(user_id, device_id);
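The unique index on `(user_id, device_id)` is what lets the store keep only the newest pending update per remote device. A hedged sketch of the corresponding upsert in raw SQL (illustrative; the real code goes through Synapse's storage helpers):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE device_lists_remote_pending(
        stream_id BIGINT PRIMARY KEY,
        user_id TEXT NOT NULL,
        device_id TEXT NOT NULL
    );
    CREATE UNIQUE INDEX device_lists_remote_pending_user_device_id
        ON device_lists_remote_pending(user_id, device_id);
    """
)


def record_pending_update(stream_id: int, user_id: str, device_id: str) -> None:
    # A later update for the same device replaces the earlier row, so only the
    # most recent stream position gets replayed once the partial join resolves.
    conn.execute(
        """
        INSERT INTO device_lists_remote_pending (stream_id, user_id, device_id)
        VALUES (?, ?, ?)
        ON CONFLICT (user_id, device_id) DO UPDATE SET stream_id = excluded.stream_id
        """,
        (stream_id, user_id, device_id),
    )


record_pending_update(1, "@alice:remote", "DEVICE1")
record_pending_update(7, "@alice:remote", "DEVICE1")
print(conn.execute("SELECT stream_id FROM device_lists_remote_pending").fetchall())  # [(7,)]
```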
@@ -66,6 +66,21 @@ def _is_dev_dependency(req: Requirement) -> bool:
     )


+def _should_ignore_runtime_requirement(req: Requirement) -> bool:
+    # This is a build-time dependency. Irritatingly, `poetry build` ignores the
+    # requirements listed in the [build-system] section of pyproject.toml, so in order
+    # to support `poetry install --no-dev` we have to mark it as a runtime dependency.
+    # See discussion on https://github.com/python-poetry/poetry/issues/6154 (it sounds
+    # like the poetry authors don't consider this a bug?)
+    #
+    # In any case, workaround this by ignoring setuptools_rust here. (It might be
+    # slightly cleaner to put `setuptools_rust` in a `build` extra or similar, but for
+    # now let's do something quick and dirty.
+    if req.name == "setuptools_rust":
+        return True
+    return False
+
+
 class Dependency(NamedTuple):
     requirement: Requirement
     must_be_installed: bool
@@ -77,7 +92,7 @@ def _generic_dependencies() -> Iterable[Dependency]:
     assert requirements is not None
     for raw_requirement in requirements:
         req = Requirement(raw_requirement)
-        if _is_dev_dependency(req):
+        if _is_dev_dependency(req) or _should_ignore_runtime_requirement(req):
             continue

         # https://packaging.pypa.io/en/latest/markers.html#usage notes that
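The new filter simply short-circuits on the one build-only package. A rough sketch of the same idea with the `packaging` library (the requirement strings here are made up for illustration):

```python
from packaging.requirements import Requirement

BUILD_ONLY_PACKAGES = {"setuptools_rust"}


def should_check_at_runtime(raw_requirement: str) -> bool:
    req = Requirement(raw_requirement)
    # Mirrors _should_ignore_runtime_requirement: build-only packages are not
    # expected to be importable in a deployed environment.
    return req.name not in BUILD_ONLY_PACKAGES


print(should_check_at_runtime("setuptools_rust>=1.3"))  # False
print(should_check_at_runtime("attrs>=19.2.0"))         # True
```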
@@ -159,6 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):

         # Blow away caches (supported room versions can only change due to a restart).
         self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
+        self.store.get_rooms_for_user.invalidate_all()
         self.get_success(self.store._get_event_cache.clear())
         self.store._event_ref.clear()
@@ -19,16 +19,18 @@ import frozendict
 from twisted.test.proto_helpers import MemoryReactor

 import synapse.rest.admin
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.appservice import ApplicationService
 from synapse.events import FrozenEvent
-from synapse.push import push_rule_evaluator
-from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
+from synapse.push.bulk_push_rule_evaluator import _flatten_dict
+from synapse.push.httppusher import tweaks_for_actions
+from synapse.rest import admin
 from synapse.rest.client import login, register, room
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.synapse_rust.push import PushRuleEvaluator
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock

 from tests import unittest
@@ -41,7 +43,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         content: JsonDict,
         relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
         relations_match_enabled: bool = False,
-    ) -> PushRuleEvaluatorForEvent:
+    ) -> PushRuleEvaluator:
         event = FrozenEvent(
             {
                 "event_id": "$event_id",
@@ -56,12 +58,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         room_member_count = 0
         sender_power_level = 0
         power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
-        return PushRuleEvaluatorForEvent(
-            event,
+        return PushRuleEvaluator(
+            _flatten_dict(event),
             room_member_count,
             sender_power_level,
-            power_levels,
-            relations or set(),
+            power_levels.get("notifications", {}),
+            relations or {},
             relations_match_enabled,
         )

@@ -293,7 +295,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         ]

         self.assertEqual(
-            push_rule_evaluator.tweaks_for_actions(actions),
+            tweaks_for_actions(actions),
             {"sound": "default", "highlight": True},
         )

@@ -304,9 +306,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         evaluator = self._get_evaluator(
             {}, {"m.annotation": {("@user:test", "m.reaction")}}
         )
-        condition = {"kind": "relation_match"}
-        # Oddly, an unknown condition always matches.
-        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))

         # A push rule evaluator with the experimental rule enabled.
         evaluator = self._get_evaluator(
@@ -439,3 +438,80 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
         )

         self.assertEqual(len(users_with_push_actions), 0)
+
+
+class BulkPushRuleEvaluatorTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        self.main_store = homeserver.get_datastores().main
+
+        self.user_id1 = self.register_user("user1", "password")
+        self.tok1 = self.login(self.user_id1, "password")
+        self.user_id2 = self.register_user("user2", "password")
+        self.tok2 = self.login(self.user_id2, "password")
+
+        self.room_id = self.helper.create_room_as(tok=self.tok1)
+
+        # We want to test history visibility works correctly.
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.RoomHistoryVisibility,
+            {"history_visibility": HistoryVisibility.JOINED},
+            tok=self.tok1,
+        )
+
+    def get_notif_count(self, user_id: str) -> int:
+        return self.get_success(
+            self.main_store.db_pool.simple_select_one_onecol(
+                table="event_push_actions",
+                keyvalues={"user_id": user_id},
+                retcol="COALESCE(SUM(notif), 0)",
+                desc="get_staging_notif_count",
+            )
+        )
+
+    def test_plain_message(self) -> None:
+        """Test that sending a normal message in a room will trigger a
+        notification
+        """
+
+        # Have user2 join the room and cle
+        self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
+
+        # They start off with no notifications, but get them when messages are
+        # sent.
+        self.assertEqual(self.get_notif_count(self.user_id2), 0)
+
+        user1 = UserID.from_string(self.user_id1)
+        self.create_and_send_event(self.room_id, user1)
+
+        self.assertEqual(self.get_notif_count(self.user_id2), 1)
+
+    def test_delayed_message(self) -> None:
+        """Test that a delayed message that was from before a user joined
+        doesn't cause a notification for the joined user.
+        """
+        user1 = UserID.from_string(self.user_id1)
+
+        # Send a message before user2 joins
+        event_id1 = self.create_and_send_event(self.room_id, user1)
+
+        # Have user2 join the room
+        self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
+
+        # They start off with no notifications
+        self.assertEqual(self.get_notif_count(self.user_id2), 0)
+
+        # Send another message that references the event before the join to
+        # simulate a "delayed" event
+        self.create_and_send_event(self.room_id, user1, prev_event_ids=[event_id1])
+
+        # user2 should not be notified about it, because they can't see it.
+        self.assertEqual(self.get_notif_count(self.user_id2), 0)
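The tests now hand the Rust `PushRuleEvaluator` a flattened copy of the event rather than the event object itself. As an illustration of what dotted-key flattening means here (a simplified sketch, not Synapse's actual `_flatten_dict`):

```python
from typing import Any, Dict


def flatten_dict(d: Dict[str, Any], prefix: str = "") -> Dict[str, str]:
    """Flatten nested dicts into dotted keys, keeping only string leaves."""
    result: Dict[str, str] = {}
    for key, value in d.items():
        path = f"{prefix}{key}"
        if isinstance(value, str):
            result[path] = value
        elif isinstance(value, dict):
            result.update(flatten_dict(value, prefix=f"{path}."))
    return result


event_dict = {"type": "m.room.message", "content": {"msgtype": "m.text", "body": "hello"}}
print(flatten_dict(event_dict))
# {'type': 'm.room.message', 'content.msgtype': 'm.text', 'content.body': 'hello'}
```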
@@ -25,10 +25,10 @@ from parameterized import parameterized, parameterized_class
 from twisted.test.proto_helpers import MemoryReactor

 import synapse.rest.admin
-from synapse.api.constants import UserTypes
+from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
 from synapse.api.room_versions import RoomVersions
-from synapse.rest.client import devices, login, logout, profile, room, sync
+from synapse.rest.client import devices, login, logout, profile, register, room, sync
 from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.server import HomeServer
 from synapse.types import JsonDict, UserID
@@ -578,6 +578,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         _search_test(None, "foo", "user_id")
         _search_test(None, "bar", "user_id")

+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
     def test_invalid_parameter(self) -> None:
         """
         If parameters are invalid, an error is returned.
@@ -623,6 +633,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])

+        # invalid approved
+        channel = self.make_request(
+            "GET",
+            self.url + "?approved=not_bool",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
         # unkown order_by
         channel = self.make_request(
             "GET",
@@ -841,6 +861,69 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self._order_test([self.admin_user, user1, user2], "creation_ts", "f")
         self._order_test([user2, user1, self.admin_user], "creation_ts", "b")

+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
+    def test_filter_out_approved(self) -> None:
+        """Tests that the endpoint can filter out approved users."""
+        # Create our users.
+        self._create_users(2)
+
+        # Get the list of users.
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, channel.result)
+
+        # Exclude the admin, because we don't want to accidentally un-approve the admin.
+        non_admin_user_ids = [
+            user["name"]
+            for user in channel.json_body["users"]
+            if user["name"] != self.admin_user
+        ]
+
+        self.assertEqual(2, len(non_admin_user_ids), non_admin_user_ids)
+
+        # Select a user and un-approve them. We do this rather than the other way around
+        # because, since these users are created by an admin, we consider them already
+        # approved.
+        not_approved_user = non_admin_user_ids[0]
+
+        channel = self.make_request(
+            "PUT",
+            f"/_synapse/admin/v2/users/{not_approved_user}",
+            {"approved": False},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, channel.result)
+
+        # Now get the list of users again, this time filtering out approved users.
+        channel = self.make_request(
+            "GET",
+            self.url + "?approved=false",
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, channel.result)
+
+        non_admin_user_ids = [
+            user["name"]
+            for user in channel.json_body["users"]
+            if user["name"] != self.admin_user
+        ]
+
+        # We should only have our unapproved user now.
+        self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids)
+        self.assertEqual(not_approved_user, non_admin_user_ids[0])
+
     def _order_test(
         self,
         expected_user_list: List[str],
@@ -1272,6 +1355,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         synapse.rest.admin.register_servlets,
         login.register_servlets,
         sync.register_servlets,
+        register.register_servlets,
     ]

     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -2536,6 +2620,104 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Ensure they're still alive
         self.assertEqual(0, channel.json_body["deactivated"])

+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
+    def test_approve_account(self) -> None:
+        """Tests that approving an account correctly sets the approved flag for the user."""
+        url = self.url_prefix % "@bob:test"
+
+        # Create the user using the client-server API since otherwise the user will be
+        # marked as approved automatically.
+        channel = self.make_request(
+            "POST",
+            "register",
+            {
+                "username": "bob",
+                "password": "test",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(403, channel.code, channel.result)
+        self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+        self.assertEqual(
+            ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+        )
+
+        # Get user
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIs(False, channel.json_body["approved"])
+
+        # Approve user
+        channel = self.make_request(
+            "PUT",
+            url,
+            access_token=self.admin_user_tok,
+            content={"approved": True},
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIs(True, channel.json_body["approved"])
+
+        # Check that the user is now approved
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIs(True, channel.json_body["approved"])
+
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
+    def test_register_approved(self) -> None:
+        url = self.url_prefix % "@bob:test"
+
+        # Create user
+        channel = self.make_request(
+            "PUT",
+            url,
+            access_token=self.admin_user_tok,
+            content={"password": "abc123", "approved": True},
+        )
+
+        self.assertEqual(201, channel.code, msg=channel.json_body)
+        self.assertEqual("@bob:test", channel.json_body["name"])
+        self.assertEqual(1, channel.json_body["approved"])
+
+        # Get user
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("@bob:test", channel.json_body["name"])
+        self.assertEqual(1, channel.json_body["approved"])
+
     def _is_erased(self, user_id: str, expect: bool) -> None:
         """Assert that the user is erased or not"""
         d = self.store.is_user_erased(user_id)
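The admin tests above drive the `approved` field through Synapse's admin API; against a running homeserver the same approval call looks roughly like this (using `requests`; the base URL and token are placeholders):

```python
import requests

HOMESERVER = "https://synapse.example.com"  # placeholder
ADMIN_TOKEN = "<admin access token>"        # placeholder


def approve_user(user_id: str) -> None:
    # Same endpoint the test exercises: PUT /_synapse/admin/v2/users/<user_id>
    resp = requests.put(
        f"{HOMESERVER}/_synapse/admin/v2/users/{user_id}",
        headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
        json={"approved": True},
        timeout=10,
    )
    resp.raise_for_status()
    assert resp.json().get("approved") is True


approve_user("@bob:example.com")
```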
@@ -20,7 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource

 import synapse.rest.admin
-from synapse.api.constants import LoginType
+from synapse.api.constants import ApprovalNoticeMedium, LoginType
+from synapse.api.errors import Codes
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client import account, auth, devices, login, logout, register
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
@@ -567,6 +568,36 @@ class UIAuthTests(unittest.HomeserverTestCase):
             body={"auth": {"session": session_id}},
         )

+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config(
+        {
+            "oidc_config": TEST_OIDC_CONFIG,
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            },
+        }
+    )
+    def test_sso_not_approved(self) -> None:
+        """Tests that if we register a user via SSO while requiring approval for new
+        accounts, we still raise the correct error before logging the user in.
+        """
+        login_resp = self.helper.login_via_oidc("username", expected_status=403)
+
+        self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
+        self.assertEqual(
+            ApprovalNoticeMedium.NONE, login_resp["approval_notice_medium"]
+        )
+
+        # Check that we didn't register a device for the user during the login attempt.
+        devices = self.get_success(
+            self.hs.get_datastores().main.get_devices_by_user("@username:test")
+        )
+
+        self.assertEqual(len(devices), 0)
+

 class RefreshAuthTests(unittest.HomeserverTestCase):
     servlets = [
@@ -23,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource

 import synapse.rest.admin
+from synapse.api.constants import ApprovalNoticeMedium, LoginType
+from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client import devices, login, logout, register
 from synapse.rest.client.account import WhoamiRestServlet
@@ -94,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         logout.register_servlets,
         devices.register_servlets,
         lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
+        register.register_servlets,
     ]

     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -406,6 +409,44 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400)
         self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")

+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
+    def test_require_approval(self) -> None:
+        channel = self.make_request(
+            "POST",
+            "register",
+            {
+                "username": "kermit",
+                "password": "monkey",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(403, channel.code, channel.result)
+        self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+        self.assertEqual(
+            ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+        )
+
+        params = {
+            "type": LoginType.PASSWORD,
+            "identifier": {"type": "m.id.user", "user": "kermit"},
+            "password": "monkey",
+        }
+        channel = self.make_request("POST", LOGIN_URL, params)
+        self.assertEqual(403, channel.code, channel.result)
+        self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+        self.assertEqual(
+            ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+        )
+

 @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
 class MultiSSOTestCase(unittest.HomeserverTestCase):
@@ -22,6 +22,8 @@ from synapse.util import Clock
 from tests import unittest
 from tests.unittest import override_config
 
+endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
+
 
 class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
 
@@ -45,18 +47,18 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         self.password = "password"
 
     def test_disabled(self) -> None:
-        channel = self.make_request("POST", "/login/token", {}, access_token=None)
+        channel = self.make_request("POST", endpoint, {}, access_token=None)
        self.assertEqual(channel.code, 400)
 
         self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", "/login/token", {}, access_token=token)
+        channel = self.make_request("POST", endpoint, {}, access_token=token)
         self.assertEqual(channel.code, 400)
 
     @override_config({"experimental_features": {"msc3882_enabled": True}})
     def test_require_auth(self) -> None:
-        channel = self.make_request("POST", "/login/token", {}, access_token=None)
+        channel = self.make_request("POST", endpoint, {}, access_token=None)
         self.assertEqual(channel.code, 401)
 
     @override_config({"experimental_features": {"msc3882_enabled": True}})
@@ -64,7 +66,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         user_id = self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", "/login/token", {}, access_token=token)
+        channel = self.make_request("POST", endpoint, {}, access_token=token)
         self.assertEqual(channel.code, 401)
         self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
 
@@ -79,7 +81,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
             },
         }
 
-        channel = self.make_request("POST", "/login/token", uia, access_token=token)
+        channel = self.make_request("POST", endpoint, uia, access_token=token)
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["expires_in"], 300)
 
@@ -100,7 +102,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         user_id = self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", "/login/token", {}, access_token=token)
+        channel = self.make_request("POST", endpoint, {}, access_token=token)
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["expires_in"], 300)
 
@@ -127,6 +129,6 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", "/login/token", {}, access_token=token)
+        channel = self.make_request("POST", endpoint, {}, access_token=token)
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["expires_in"], 15)
@@ -22,7 +22,11 @@ import pkg_resources
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
-from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
+from synapse.api.constants import (
+    APP_SERVICE_REGISTRATION_TYPE,
+    ApprovalNoticeMedium,
+    LoginType,
+)
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client import account, account_validity, login, logout, register, sync
@@ -765,6 +769,32 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.json_body)
         self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)
 
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
+    def test_require_approval(self) -> None:
+        channel = self.make_request(
+            "POST",
+            "register",
+            {
+                "username": "kermit",
+                "password": "monkey",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(403, channel.code, channel.result)
+        self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+        self.assertEqual(
+            ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+        )
+
 
 class AccountValidityTestCase(unittest.HomeserverTestCase):
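Aside, not part of the diff: every MSC3866 test in this commit overrides the same
experimental-features block. As a quick reference, the shape passed to
@override_config above is reproduced here as a sketch (the constant name is
illustrative only; per MSC3866 the same nested structure is what would sit under
`experimental_features` in the homeserver configuration, see the Synapse config
documentation for the canonical form):

# Illustrative sketch, copied from the @override_config calls above.
MSC3866_REQUIRE_APPROVAL = {
    "experimental_features": {
        "msc3866": {
            "enabled": True,
            "require_approval_for_new_accounts": True,
        }
    }
}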
@@ -543,8 +543,12 @@ class RestHelper:
 
         return channel.json_body
 
-    def login_via_oidc(self, remote_user_id: str) -> JsonDict:
-        """Log in (as a new user) via OIDC
+    def login_via_oidc(
+        self,
+        remote_user_id: str,
+        expected_status: int = 200,
+    ) -> JsonDict:
+        """Log in via OIDC
 
         Returns the result of the final token login.
 
@@ -578,7 +582,9 @@ class RestHelper:
             "/login",
             content={"type": "m.login.token", "token": login_token},
         )
-        assert channel.code == HTTPStatus.OK
+        assert (
+            channel.code == expected_status
+        ), f"unexpected status in response: {channel.code}"
         return channel.json_body
 
     def auth_via_oidc(
@@ -754,18 +754,28 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def test_get_backfill_points_in_room(self):
         """
-        Test to make sure we get some backfill points
+        Test to make sure only backfill points that are older and come before
+        the `current_depth` are returned.
         """
         setup_info = self._setup_room_for_backfill_tests()
         room_id = setup_info.room_id
+        depth_map = setup_info.depth_map
+
+        # Try at "B"
         backfill_points = self.get_success(
-            self.store.get_backfill_points_in_room(room_id)
+            self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100)
         )
         backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
-        self.assertListEqual(
-            backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]
-        )
+        self.assertEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"])
+
+        # Try at "A"
+        backfill_points = self.get_success(
+            self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
+        )
+        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+        # Event "2" has a depth of 2 but is not included here because we only
+        # know the approximate depth of 5 from our event "3".
+        self.assertListEqual(backfill_event_ids, ["b3", "b2", "b1"])
 
     def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
         self,
@@ -776,6 +786,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         """
         setup_info = self._setup_room_for_backfill_tests()
         room_id = setup_info.room_id
+        depth_map = setup_info.depth_map
 
         # Record some attempts to backfill these events which will make
         # `get_backfill_points_in_room` exclude them because we
@@ -795,12 +806,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         # No time has passed since we attempted to backfill ^
 
+        # Try at "B"
         backfill_points = self.get_success(
-            self.store.get_backfill_points_in_room(room_id)
+            self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100)
         )
         backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
         # Only the backfill points that we didn't record earlier exist here.
-        self.assertListEqual(backfill_event_ids, ["b6", "2", "b1"])
+        self.assertEqual(backfill_event_ids, ["b6", "2", "b1"])
 
     def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
         self,
@@ -812,6 +824,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         """
         setup_info = self._setup_room_for_backfill_tests()
         room_id = setup_info.room_id
+        depth_map = setup_info.depth_map
 
         # Record some attempts to backfill these events which will make
         # `get_backfill_points_in_room` exclude them because we
@@ -839,27 +852,66 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         # visible regardless.
         self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
 
-        # Make sure that "b1" is not in the list because we've
+        # Try at "A" and make sure that "b1" is not in the list because we've
         # already attempted many times
         backfill_points = self.get_success(
-            self.store.get_backfill_points_in_room(room_id)
+            self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
         )
         backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
-        self.assertListEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2"])
+        self.assertEqual(backfill_event_ids, ["b3", "b2"])
 
         # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
         # see if we can now backfill it
         self.reactor.advance(datetime.timedelta(hours=20).total_seconds())
 
-        # Try again after we advanced enough time and we should see "b3" again
+        # Try at "A" again after we advanced enough time and we should see "b3" again
         backfill_points = self.get_success(
-            self.store.get_backfill_points_in_room(room_id)
+            self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
        )
         backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
-        self.assertListEqual(
-            backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]
-        )
+        self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"])
+
+    def test_get_backfill_points_in_room_works_after_many_failed_pull_attempts_that_could_naively_overflow(
+        self,
+    ) -> None:
+        """
+        A test that reproduces #13929 (Postgres only).
+
+        Test to make sure we can still get backfill points after many failed pull
+        attempts that cause us to backoff to the limit. Even if the backoff formula
+        would tell us to wait for more seconds than can be expressed in a 32 bit
+        signed int.
+        """
+        setup_info = self._setup_room_for_backfill_tests()
+        room_id = setup_info.room_id
+        depth_map = setup_info.depth_map
+
+        # Pretend that we have tried and failed 10 times to backfill event b1.
+        for _ in range(10):
+            self.get_success(
+                self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+            )
+
+        # If the backoff periods grow without limit:
+        # After the first failed attempt, we would have backed off for 1 << 1 = 2 hours.
+        # After the second failed attempt we would have backed off for 1 << 2 = 4 hours,
+        # so after the 10th failed attempt we should backoff for 1 << 10 == 1024 hours.
+        # Wait 1100 hours just so we have a nice round number.
+        self.reactor.advance(datetime.timedelta(hours=1100).total_seconds())
+
+        # 1024 hours in milliseconds is 1024 * 3600000, which exceeds the largest 32 bit
+        # signed integer. The bug we're reproducing is that this overflow causes an
+        # error in postgres preventing us from fetching a set of backwards extremities
+        # to retry fetching.
+        backfill_points = self.get_success(
+            self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
+        )
+
+        # We should aim to fetch all backoff points: b1's latest backoff period has
+        # expired, and we haven't tried the rest.
+        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+        self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"])
 
     def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo:
         """
         Sets up a room with various insertion event backward extremities to test
@@ -938,18 +990,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def test_get_insertion_event_backward_extremities_in_room(self):
         """
-        Test to make sure insertion event backward extremities are returned.
+        Test to make sure only insertion event backward extremities that are
+        older and come before the `current_depth` are returned.
         """
         setup_info = self._setup_room_for_insertion_backfill_tests()
         room_id = setup_info.room_id
+        depth_map = setup_info.depth_map
+
+        # Try at "insertion_eventB"
         backfill_points = self.get_success(
-            self.store.get_insertion_event_backward_extremities_in_room(room_id)
+            self.store.get_insertion_event_backward_extremities_in_room(
+                room_id, depth_map["insertion_eventB"], limit=100
+            )
         )
         backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
-        self.assertListEqual(
-            backfill_event_ids, ["insertion_eventB", "insertion_eventA"]
-        )
+        self.assertEqual(backfill_event_ids, ["insertion_eventB", "insertion_eventA"])
+
+        # Try at "insertion_eventA"
+        backfill_points = self.get_success(
+            self.store.get_insertion_event_backward_extremities_in_room(
+                room_id, depth_map["insertion_eventA"], limit=100
+            )
        )
+        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+        # Event "2" has a depth of 2 but is not included here because we only
+        # know the approximate depth of 5 from our event "3".
+        self.assertListEqual(backfill_event_ids, ["insertion_eventA"])
 
     def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
         self,
@@ -961,6 +1027,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         """
         setup_info = self._setup_room_for_insertion_backfill_tests()
         room_id = setup_info.room_id
+        depth_map = setup_info.depth_map
 
         # Record some attempts to backfill these events which will make
         # `get_insertion_event_backward_extremities_in_room` exclude them
@@ -973,12 +1040,15 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         # No time has passed since we attempted to backfill ^
 
+        # Try at "insertion_eventB"
         backfill_points = self.get_success(
-            self.store.get_insertion_event_backward_extremities_in_room(room_id)
+            self.store.get_insertion_event_backward_extremities_in_room(
+                room_id, depth_map["insertion_eventB"], limit=100
+            )
         )
         backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
         # Only the backfill points that we didn't record earlier exist here.
-        self.assertListEqual(backfill_event_ids, ["insertion_eventB"])
+        self.assertEqual(backfill_event_ids, ["insertion_eventB"])
 
     def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
         self,
@@ -991,6 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         """
         setup_info = self._setup_room_for_insertion_backfill_tests()
         room_id = setup_info.room_id
+        depth_map = setup_info.depth_map
 
         # Record some attempts to backfill these events which will make
         # `get_backfill_points_in_room` exclude them because we
@@ -1027,13 +1098,15 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         # because we haven't waited long enough for this many attempts.
         self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
 
-        # Make sure that "insertion_eventA" is not in the list because we've
-        # already attempted many times
+        # Try at "insertion_eventA" and make sure that "insertion_eventA" is not
+        # in the list because we've already attempted many times
         backfill_points = self.get_success(
-            self.store.get_insertion_event_backward_extremities_in_room(room_id)
+            self.store.get_insertion_event_backward_extremities_in_room(
+                room_id, depth_map["insertion_eventA"], limit=100
+            )
         )
         backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
-        self.assertListEqual(backfill_event_ids, ["insertion_eventB"])
+        self.assertEqual(backfill_event_ids, [])
 
         # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
         # see if we can now backfill it
@@ -1042,12 +1115,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         # Try at "insertion_eventA" again after we advanced enough time and we
         # should see "insertion_eventA" again
         backfill_points = self.get_success(
-            self.store.get_insertion_event_backward_extremities_in_room(room_id)
+            self.store.get_insertion_event_backward_extremities_in_room(
+                room_id, depth_map["insertion_eventA"], limit=100
+            )
         )
         backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
-        self.assertListEqual(
-            backfill_event_ids, ["insertion_eventB", "insertion_eventA"]
-        )
+        self.assertEqual(backfill_event_ids, ["insertion_eventA"])
 
 
 @attr.s
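Aside, not part of the diff: the comments in the new overflow test above describe the
arithmetic behind #13929. Spelled out as a minimal, illustrative sketch:

# Why an uncapped exponential backoff overflows a signed 32-bit value after
# ten failed pull attempts, mirroring the comments in the test above.
INT32_MAX = 2**31 - 1                          # 2_147_483_647

failed_attempts = 10
backoff_hours = 1 << failed_attempts           # 1024 hours
backoff_ms = backoff_hours * 60 * 60 * 1000    # 1024 * 3600000 = 3_686_400_000

assert backoff_ms > INT32_MAX                  # hence the Postgres error being reproduced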
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Collection, Optional
 
 from synapse.api.constants import ReceiptTypes
 from synapse.types import UserID, create_requester
@@ -84,6 +85,33 @@ class ReceiptTestCase(HomeserverTestCase):
             )
         )
 
+    def get_last_unthreaded_receipt(
+        self, receipt_types: Collection[str], room_id: Optional[str] = None
+    ) -> Optional[str]:
+        """
+        Fetch the event ID for the latest unthreaded receipt in the test room for the test user.
+
+        Args:
+            receipt_types: The receipt types to fetch.
+
+        Returns:
+            The latest receipt, if one exists.
+        """
+        result = self.get_success(
+            self.store.db_pool.runInteraction(
+                "get_last_receipt_event_id_for_user",
+                self.store.get_last_unthreaded_receipt_for_user_txn,
+                OUR_USER_ID,
+                room_id or self.room_id1,
+                receipt_types,
+            )
+        )
+        if not result:
+            return None
+
+        event_id, _ = result
+        return event_id
+
     def test_return_empty_with_no_data(self) -> None:
         res = self.get_success(
             self.store.get_receipts_for_user(
@@ -107,16 +135,10 @@ class ReceiptTestCase(HomeserverTestCase):
         )
         self.assertEqual(res, {})
 
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID,
-                self.room_id1,
-                [
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                ],
-            )
-        )
+        res = self.get_last_unthreaded_receipt(
+            [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+        )
 
         self.assertEqual(res, None)
 
     def test_get_receipts_for_user(self) -> None:
@@ -228,29 +250,17 @@ class ReceiptTestCase(HomeserverTestCase):
         )
 
         # Test we get the latest event when we want both private and public receipts
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID,
-                self.room_id1,
-                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
-            )
-        )
+        res = self.get_last_unthreaded_receipt(
+            [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+        )
         self.assertEqual(res, event1_2_id)
 
         # Test we get the older event when we want only public receipt
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
-            )
-        )
+        res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
         self.assertEqual(res, event1_1_id)
 
         # Test we get the latest event when we want only the private receipt
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
-            )
-        )
+        res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE])
         self.assertEqual(res, event1_2_id)
 
         # Test receipt updating
@@ -259,11 +269,7 @@ class ReceiptTestCase(HomeserverTestCase):
                 self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
             )
         )
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
-            )
-        )
+        res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
         self.assertEqual(res, event1_2_id)
 
         # Send some events into the second room
@@ -282,11 +288,7 @@ class ReceiptTestCase(HomeserverTestCase):
                 {},
             )
         )
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID,
-                self.room_id2,
-                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
-            )
-        )
+        res = self.get_last_unthreaded_receipt(
+            [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2
+        )
         self.assertEqual(res, event2_1_id)
@@ -16,9 +16,10 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.constants import UserTypes
 from synapse.api.errors import ThreepidValidationError
 from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 
 class RegistrationStoreTestCase(HomeserverTestCase):
@@ -48,6 +49,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
                 "user_type": None,
                 "deactivated": 0,
                 "shadow_banned": 0,
+                "approved": 1,
             },
             (self.get_success(self.store.get_user_by_id(self.user_id))),
         )
@@ -166,3 +168,101 @@ class RegistrationStoreTestCase(HomeserverTestCase):
             ThreepidValidationError,
         )
         self.assertEqual(e.value.msg, "Validation token not found or has expired", e)
+
+
+class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+
+        # If there's already some config for this feature in the default config, it
+        # means we're overriding it with @override_config. In this case we don't want
+        # to do anything more with it.
+        msc3866_config = config.get("experimental_features", {}).get("msc3866")
+        if msc3866_config is not None:
+            return config
+
+        # Require approval for all new accounts.
+        config["experimental_features"] = {
+            "msc3866": {
+                "enabled": True,
+                "require_approval_for_new_accounts": True,
+            }
+        }
+        return config
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.user_id = "@my-user:test"
+        self.pwhash = "{xx1}123456789"
+
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": False,
+                }
+            }
+        }
+    )
+    def test_approval_not_required(self) -> None:
+        """Tests that if we don't require approval for new accounts, newly created
+        accounts are automatically marked as approved.
+        """
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+        user = self.get_success(self.store.get_user_by_id(self.user_id))
+        assert user is not None
+        self.assertTrue(user["approved"])
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertTrue(approved)
+
+    def test_approval_required(self) -> None:
+        """Tests that if we require approval for new accounts, newly created accounts
+        are not automatically marked as approved.
+        """
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+        user = self.get_success(self.store.get_user_by_id(self.user_id))
+        assert user is not None
+        self.assertFalse(user["approved"])
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertFalse(approved)
+
+    def test_override(self) -> None:
+        """Tests that if we require approval for new accounts, but we explicitly say the
+        new user should be considered approved, they're marked as approved.
+        """
+        self.get_success(
+            self.store.register_user(
+                self.user_id,
+                self.pwhash,
+                approved=True,
+            )
+        )
+
+        user = self.get_success(self.store.get_user_by_id(self.user_id))
+        self.assertIsNotNone(user)
+        assert user is not None
+        self.assertEqual(user["approved"], 1)
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertTrue(approved)
+
+    def test_approve_user(self) -> None:
+        """Tests that approving the user updates their approval status."""
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertFalse(approved)
+
+        self.get_success(
+            self.store.update_user_approval_status(
+                UserID.from_string(self.user_id), True
+            )
+        )
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertTrue(approved)
@@ -40,7 +40,10 @@ class TestDependencyChecker(TestCase):
     def mock_installed_package(
         self, distribution: Optional[DummyDistribution]
     ) -> Generator[None, None, None]:
-        """Pretend that looking up any distribution yields the given `distribution`."""
+        """Pretend that looking up any package yields the given `distribution`.
+
+        If `distribution = None`, we pretend that the package is not installed.
+        """
 
         def mock_distribution(name: str):
             if distribution is None:
@@ -81,7 +84,7 @@ class TestDependencyChecker(TestCase):
         self.assertRaises(DependencyException, check_requirements)
 
     def test_checks_ignore_dev_dependencies(self) -> None:
-        """Bot generic and per-extra checks should ignore dev dependencies."""
+        """Both generic and per-extra checks should ignore dev dependencies."""
         with patch(
             "synapse.util.check_dependencies.metadata.requires",
             return_value=["dummypkg >= 1; extra == 'mypy'"],
@@ -142,3 +145,16 @@ class TestDependencyChecker(TestCase):
         with self.mock_installed_package(new_release_candidate):
             # should not raise
             check_requirements()
+
+    def test_setuptools_rust_ignored(self) -> None:
+        """Test a workaround for a `poetry build` problem. Reproduces #13926."""
+        with patch(
+            "synapse.util.check_dependencies.metadata.requires",
+            return_value=["setuptools_rust >= 1.3"],
+        ):
+            with self.mock_installed_package(None):
+                # should not raise, even if setuptools_rust is not installed
+                check_requirements()
+            with self.mock_installed_package(old):
+                # We also ignore old versions of setuptools_rust
+                check_requirements()