Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
commit
a2b6ee7b00
|
@ -0,0 +1 @@
|
|||
Allow server admins to require a manual approval process before new accounts can be used (using [MSC3866](https://github.com/matrix-org/matrix-spec-proposals/pull/3866)).
|
|
@ -0,0 +1 @@
|
|||
Send invite push notifications for invite over federation.
|
|
@ -0,0 +1 @@
|
|||
Optimise get rooms for user calls. Contributed by Nick @ Beeper (@fizzadar).
|
|
@ -0,0 +1 @@
|
|||
Port push rules to using Rust.
|
|
@ -0,0 +1 @@
|
|||
Fix unstable MSC3882 endpoint being incorrectly available on stable API versions.
|
|
@ -0,0 +1 @@
|
|||
Only pull relevant backfill points from the database based on the current depth and limit (instead of all) every time we want to `/backfill`.
|
|
@ -0,0 +1 @@
|
|||
Improve backfill robustness by trying more servers when we get a `4xx` error back.
|
|
@ -0,0 +1 @@
|
|||
Fix a bug introduced in 1.66 where some required fields in the pushrules sent to clients were not present anymore. Contributed by Nico.
|
|
@ -0,0 +1 @@
|
|||
Faster remote room joins: correctly handle remote device list updates during a partial join.
|
|
@ -0,0 +1 @@
|
|||
Update an innaccurate comment in Synapse's upsert database helper.
|
|
@ -0,0 +1 @@
|
|||
Add instruction to contributing guide for running unit tests in parallel. Contributed by @ashfame.
|
|
@ -0,0 +1 @@
|
|||
Update the man page for the `hash_password` script to correct the default number of bcrypt rounds performed.
|
|
@ -0,0 +1 @@
|
|||
Clarify that the `auto_join_rooms` config option can also be used with Space aliases.
|
|
@ -0,0 +1 @@
|
|||
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
|
|
@ -0,0 +1 @@
|
|||
Correctly handle sending local device list updates to remote servers during a partial join.
|
|
@ -0,0 +1 @@
|
|||
Exponentially backoff from backfilling the same event over and over.
|
|
@ -0,0 +1 @@
|
|||
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
|
|
@ -0,0 +1 @@
|
|||
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
|
|
@ -0,0 +1 @@
|
|||
Add cache invalidation across workers to module API.
|
|
@ -0,0 +1 @@
|
|||
Fix a bug introduced in v1.68.0 where Synapse would require `setuptools_rust` at runtime, even though the package is only required at build time.
|
|
@ -0,0 +1 @@
|
|||
Ask mail servers receiving emails from Synapse to not send automatic reply (e.g. out-of-office responses).
|
|
@ -0,0 +1 @@
|
|||
Fix a performance regression in the `get_users_in_room` database query. Introduced in v1.67.0.
|
|
@ -0,0 +1 @@
|
|||
Speed up calculating push actions in large rooms.
|
|
@ -0,0 +1 @@
|
|||
Add some cross references to worker documentation.
|
|
@ -10,7 +10,7 @@
|
|||
.P
|
||||
\fBhash_password\fR takes a password as an parameter either on the command line or the \fBSTDIN\fR if not supplied\.
|
||||
.P
|
||||
It accepts an YAML file which can be used to specify parameters like the number of rounds for bcrypt and password_config section having the pepper value used for the hashing\. By default \fBbcrypt_rounds\fR is set to \fB10\fR\.
|
||||
It accepts an YAML file which can be used to specify parameters like the number of rounds for bcrypt and password_config section having the pepper value used for the hashing\. By default \fBbcrypt_rounds\fR is set to \fB12\fR\.
|
||||
.P
|
||||
The hashed password is written on the \fBSTDOUT\fR\.
|
||||
.SH "FILES"
|
||||
|
|
|
@ -167,6 +167,12 @@ was broken. They are slower than the linters but will typically catch more error
|
|||
poetry run trial tests
|
||||
```
|
||||
|
||||
You can run unit tests in parallel by specifying `-jX` argument to `trial` where `X` is the number of parallel runners you want. To use 4 cpu cores, you would run them like:
|
||||
|
||||
```sh
|
||||
poetry run trial -j4 tests
|
||||
```
|
||||
|
||||
If you wish to only run *some* unit tests, you may specify
|
||||
another module instead of `tests` - or a test class or a method:
|
||||
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
worker_app: synapse.app.media_repository
|
||||
worker_name: media_worker
|
||||
|
||||
# The replication listener on the main synapse process.
|
||||
worker_replication_host: 127.0.0.1
|
||||
worker_replication_http_port: 9093
|
||||
|
||||
worker_listeners:
|
||||
- type: http
|
||||
port: 8085
|
||||
resources:
|
||||
- names: [media]
|
||||
|
||||
worker_log_config: /etc/matrix-synapse/media-worker-log.yaml
|
|
@ -88,6 +88,18 @@ process, for example:
|
|||
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||
```
|
||||
|
||||
# Upgrading to v1.69.0
|
||||
|
||||
## Changes to the receipts replication streams
|
||||
|
||||
Synapse now includes information indicating if a receipt applies to a thread when
|
||||
replicating it to other workers. This is a forwards- and backwards-incompatible
|
||||
change: v1.68 and workers cannot process receipts replicated by v1.69 workers, and
|
||||
vice versa.
|
||||
|
||||
Once all workers are upgraded to v1.69 (or downgraded to v1.68), receipts
|
||||
replication will resume as normal.
|
||||
|
||||
# Upgrading to v1.68.0
|
||||
|
||||
Two changes announced in the upgrade notes for v1.67.0 have now landed in v1.68.0.
|
||||
|
|
|
@ -2229,6 +2229,9 @@ homeserver. If the room already exists, make certain it is a publicly joinable
|
|||
room, i.e. the join rule of the room must be set to 'public'. You can find more options
|
||||
relating to auto-joining rooms below.
|
||||
|
||||
As Spaces are just rooms under the hood, Space aliases may also be
|
||||
used.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
auto_join_rooms:
|
||||
|
@ -2240,7 +2243,7 @@ auto_join_rooms:
|
|||
|
||||
Where `auto_join_rooms` are specified, setting this flag ensures that
|
||||
the rooms exist by creating them when the first user on the
|
||||
homeserver registers.
|
||||
homeserver registers. This option will not create Spaces.
|
||||
|
||||
By default the auto-created rooms are publicly joinable from any federated
|
||||
server. Use the `autocreate_auto_join_rooms_federated` and
|
||||
|
@ -2258,7 +2261,7 @@ autocreate_auto_join_rooms: false
|
|||
---
|
||||
### `autocreate_auto_join_rooms_federated`
|
||||
|
||||
Whether the rooms listen in `auto_join_rooms` that are auto-created are available
|
||||
Whether the rooms listed in `auto_join_rooms` that are auto-created are available
|
||||
via federation. Only has an effect if `autocreate_auto_join_rooms` is true.
|
||||
|
||||
Note that whether a room is federated cannot be modified after
|
||||
|
|
|
@ -93,7 +93,6 @@ listener" for the main process; and secondly, you need to enable redis-based
|
|||
replication. Optionally, a shared secret can be used to authenticate HTTP
|
||||
traffic between workers. For example:
|
||||
|
||||
|
||||
```yaml
|
||||
# extend the existing `listeners` section. This defines the ports that the
|
||||
# main process will listen on.
|
||||
|
@ -129,7 +128,8 @@ In the config file for each worker, you must specify:
|
|||
* The HTTP replication endpoint that it should talk to on the main synapse process
|
||||
(`worker_replication_host` and `worker_replication_http_port`)
|
||||
* If handling HTTP requests, a `worker_listeners` option with an `http`
|
||||
listener, in the same way as the `listeners` option in the shared config.
|
||||
listener, in the same way as the [`listeners`](usage/configuration/config_documentation.md#listeners)
|
||||
option in the shared config.
|
||||
* If handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for
|
||||
the main process (`worker_main_http_uri`).
|
||||
|
||||
|
@ -285,8 +285,9 @@ For multiple workers not handling the SSO endpoints properly, see
|
|||
[#7530](https://github.com/matrix-org/synapse/issues/7530) and
|
||||
[#9427](https://github.com/matrix-org/synapse/issues/9427).
|
||||
|
||||
Note that a HTTP listener with `client` and `federation` resources must be
|
||||
configured in the `worker_listeners` option in the worker config.
|
||||
Note that a [HTTP listener](usage/configuration/config_documentation.md#listeners)
|
||||
with `client` and `federation` `resources` must be configured in the `worker_listeners`
|
||||
option in the worker config.
|
||||
|
||||
#### Load balancing
|
||||
|
||||
|
@ -326,7 +327,8 @@ effects of bursts of events from that bridge on events sent by normal users.
|
|||
Additionally, the writing of specific streams (such as events) can be moved off
|
||||
of the main process to a particular worker.
|
||||
|
||||
To enable this, the worker must have a HTTP replication listener configured,
|
||||
To enable this, the worker must have a
|
||||
[HTTP `replication` listener](usage/configuration/config_documentation.md#listeners) configured,
|
||||
have a `worker_name` and be listed in the `instance_map` config. The same worker
|
||||
can handle multiple streams, but unless otherwise documented, each stream can only
|
||||
have a single writer.
|
||||
|
@ -410,7 +412,7 @@ the stream writer for the `presence` stream:
|
|||
There is also support for moving background tasks to a separate
|
||||
worker. Background tasks are run periodically or started via replication. Exactly
|
||||
which tasks are configured to run depends on your Synapse configuration (e.g. if
|
||||
stats is enabled).
|
||||
stats is enabled). This worker doesn't handle any REST endpoints itself.
|
||||
|
||||
To enable this, the worker must have a `worker_name` and can be configured to run
|
||||
background tasks. For example, to move background tasks to a dedicated worker,
|
||||
|
@ -457,8 +459,8 @@ worker application type.
|
|||
#### Notifying Application Services
|
||||
|
||||
You can designate one generic worker to send output traffic to Application Services.
|
||||
|
||||
Specify its name in the shared configuration as follows:
|
||||
Doesn't handle any REST endpoints itself, but you should specify its name in the
|
||||
shared configuration as follows:
|
||||
|
||||
```yaml
|
||||
notify_appservices_from_worker: worker_name
|
||||
|
@ -536,16 +538,12 @@ file to stop the main synapse running background jobs related to managing the
|
|||
media repository. Note that doing so will prevent the main process from being
|
||||
able to handle the above endpoints.
|
||||
|
||||
In the `media_repository` worker configuration file, configure the http listener to
|
||||
In the `media_repository` worker configuration file, configure the
|
||||
[HTTP listener](usage/configuration/config_documentation.md#listeners) to
|
||||
expose the `media` resource. For example:
|
||||
|
||||
```yaml
|
||||
worker_listeners:
|
||||
- type: http
|
||||
port: 8085
|
||||
resources:
|
||||
- names:
|
||||
- media
|
||||
{{#include systemd-with-workers/workers/media_worker.yaml}}
|
||||
```
|
||||
|
||||
Note that if running multiple media repositories they must be on the same server
|
||||
|
|
|
@ -11,7 +11,9 @@ rust-version = "1.58.1"
|
|||
|
||||
[lib]
|
||||
name = "synapse"
|
||||
crate-type = ["cdylib"]
|
||||
# We generate a `cdylib` for Python and a standard `lib` for running
|
||||
# tests/benchmarks.
|
||||
crate-type = ["lib", "cdylib"]
|
||||
|
||||
[package.metadata.maturin]
|
||||
# This is where we tell maturin where to place the built library.
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![feature(test)]
|
||||
use synapse::push::{
|
||||
evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules,
|
||||
};
|
||||
use test::Bencher;
|
||||
|
||||
extern crate test;
|
||||
|
||||
#[bench]
|
||||
fn bench_match_exact(b: &mut Bencher) {
|
||||
let flattened_keys = [
|
||||
("type".to_string(), "m.text".to_string()),
|
||||
("room_id".to_string(), "!room:server".to_string()),
|
||||
("content.body".to_string(), "test message".to_string()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let eval = PushRuleEvaluator::py_new(
|
||||
flattened_keys,
|
||||
10,
|
||||
0,
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let condition = Condition::Known(synapse::push::KnownCondition::EventMatch(
|
||||
EventMatchCondition {
|
||||
key: "room_id".into(),
|
||||
pattern: Some("!room:server".into()),
|
||||
pattern_type: None,
|
||||
},
|
||||
));
|
||||
|
||||
let matched = eval.match_condition(&condition, None, None).unwrap();
|
||||
assert!(matched, "Didn't match");
|
||||
|
||||
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_match_word(b: &mut Bencher) {
|
||||
let flattened_keys = [
|
||||
("type".to_string(), "m.text".to_string()),
|
||||
("room_id".to_string(), "!room:server".to_string()),
|
||||
("content.body".to_string(), "test message".to_string()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let eval = PushRuleEvaluator::py_new(
|
||||
flattened_keys,
|
||||
10,
|
||||
0,
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let condition = Condition::Known(synapse::push::KnownCondition::EventMatch(
|
||||
EventMatchCondition {
|
||||
key: "content.body".into(),
|
||||
pattern: Some("test".into()),
|
||||
pattern_type: None,
|
||||
},
|
||||
));
|
||||
|
||||
let matched = eval.match_condition(&condition, None, None).unwrap();
|
||||
assert!(matched, "Didn't match");
|
||||
|
||||
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_match_word_miss(b: &mut Bencher) {
|
||||
let flattened_keys = [
|
||||
("type".to_string(), "m.text".to_string()),
|
||||
("room_id".to_string(), "!room:server".to_string()),
|
||||
("content.body".to_string(), "test message".to_string()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let eval = PushRuleEvaluator::py_new(
|
||||
flattened_keys,
|
||||
10,
|
||||
0,
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let condition = Condition::Known(synapse::push::KnownCondition::EventMatch(
|
||||
EventMatchCondition {
|
||||
key: "content.body".into(),
|
||||
pattern: Some("foobar".into()),
|
||||
pattern_type: None,
|
||||
},
|
||||
));
|
||||
|
||||
let matched = eval.match_condition(&condition, None, None).unwrap();
|
||||
assert!(!matched, "Didn't match");
|
||||
|
||||
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_eval_message(b: &mut Bencher) {
|
||||
let flattened_keys = [
|
||||
("type".to_string(), "m.text".to_string()),
|
||||
("room_id".to_string(), "!room:server".to_string()),
|
||||
("content.body".to_string(), "test message".to_string()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let eval = PushRuleEvaluator::py_new(
|
||||
flattened_keys,
|
||||
10,
|
||||
0,
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let rules =
|
||||
FilteredPushRules::py_new(PushRules::new(Vec::new()), Default::default(), false, false);
|
||||
|
||||
b.iter(|| eval.run(&rules, Some("bob"), Some("person")));
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![feature(test)]
|
||||
|
||||
use synapse::push::utils::{glob_to_regex, GlobMatchType};
|
||||
use test::Bencher;
|
||||
|
||||
extern crate test;
|
||||
|
||||
#[bench]
|
||||
fn bench_whole(b: &mut Bencher) {
|
||||
b.iter(|| glob_to_regex("test", GlobMatchType::Whole));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_word(b: &mut Bencher) {
|
||||
b.iter(|| glob_to_regex("test", GlobMatchType::Word));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_whole_wildcard_run(b: &mut Bencher) {
|
||||
b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Whole));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_word_wildcard_run(b: &mut Bencher) {
|
||||
b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Whole));
|
||||
}
|
|
@ -22,7 +22,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
|
||||
for entry in entries {
|
||||
if entry.is_dir() {
|
||||
dirs.push(entry)
|
||||
dirs.push(entry);
|
||||
} else {
|
||||
paths.push(entry.to_str().expect("valid rust paths").to_string());
|
||||
}
|
||||
|
|
|
@ -262,6 +262,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
|
|||
priority_class: 1,
|
||||
conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch {
|
||||
rel_type: Cow::Borrowed("m.thread"),
|
||||
event_type_pattern: None,
|
||||
sender: None,
|
||||
sender_type: Some(Cow::Borrowed("user_id")),
|
||||
})]),
|
||||
|
|
|
@ -0,0 +1,374 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Error};
|
||||
use lazy_static::lazy_static;
|
||||
use log::warn;
|
||||
use pyo3::prelude::*;
|
||||
use regex::Regex;
|
||||
|
||||
use super::{
|
||||
utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType},
|
||||
Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition,
|
||||
};
|
||||
|
||||
lazy_static! {
|
||||
/// Used to parse the `is` clause in the room member count condition.
|
||||
static ref INEQUALITY_EXPR: Regex = Regex::new(r"^([=<>]*)([0-9]+)$").expect("valid regex");
|
||||
}
|
||||
|
||||
/// Allows running a set of push rules against a particular event.
|
||||
#[pyclass]
|
||||
pub struct PushRuleEvaluator {
|
||||
/// A mapping of "flattened" keys to string values in the event, e.g.
|
||||
/// includes things like "type" and "content.msgtype".
|
||||
flattened_keys: BTreeMap<String, String>,
|
||||
|
||||
/// The "content.body", if any.
|
||||
body: String,
|
||||
|
||||
/// The number of users in the room.
|
||||
room_member_count: u64,
|
||||
|
||||
/// The `notifications` section of the current power levels in the room.
|
||||
notification_power_levels: BTreeMap<String, i64>,
|
||||
|
||||
/// The relations related to the event as a mapping from relation type to
|
||||
/// set of sender/event type 2-tuples.
|
||||
relations: BTreeMap<String, BTreeSet<(String, String)>>,
|
||||
|
||||
/// Is running "relation" conditions enabled?
|
||||
relation_match_enabled: bool,
|
||||
|
||||
/// The power level of the sender of the event, or None if event is an
|
||||
/// outlier.
|
||||
sender_power_level: Option<i64>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PushRuleEvaluator {
|
||||
/// Create a new `PushRuleEvaluator`. See struct docstring for details.
|
||||
#[new]
|
||||
pub fn py_new(
|
||||
flattened_keys: BTreeMap<String, String>,
|
||||
room_member_count: u64,
|
||||
sender_power_level: Option<i64>,
|
||||
notification_power_levels: BTreeMap<String, i64>,
|
||||
relations: BTreeMap<String, BTreeSet<(String, String)>>,
|
||||
relation_match_enabled: bool,
|
||||
) -> Result<Self, Error> {
|
||||
let body = flattened_keys
|
||||
.get("content.body")
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(PushRuleEvaluator {
|
||||
flattened_keys,
|
||||
body,
|
||||
room_member_count,
|
||||
notification_power_levels,
|
||||
relations,
|
||||
relation_match_enabled,
|
||||
sender_power_level,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run the evaluator with the given push rules, for the given user ID and
|
||||
/// display name of the user.
|
||||
///
|
||||
/// Passing in None will skip evaluating rules matching user ID and display
|
||||
/// name.
|
||||
///
|
||||
/// Returns the set of actions, if any, that match (filtering out any
|
||||
/// `dont_notify` actions).
|
||||
pub fn run(
|
||||
&self,
|
||||
push_rules: &FilteredPushRules,
|
||||
user_id: Option<&str>,
|
||||
display_name: Option<&str>,
|
||||
) -> Vec<Action> {
|
||||
'outer: for (push_rule, enabled) in push_rules.iter() {
|
||||
if !enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
for condition in push_rule.conditions.iter() {
|
||||
match self.match_condition(condition, user_id, display_name) {
|
||||
Ok(true) => {}
|
||||
Ok(false) => continue 'outer,
|
||||
Err(err) => {
|
||||
warn!("Condition match failed {err}");
|
||||
continue 'outer;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let actions = push_rule
|
||||
.actions
|
||||
.iter()
|
||||
// Filter out "dont_notify" actions, as we don't store them.
|
||||
.filter(|a| **a != Action::DontNotify)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
return actions;
|
||||
}
|
||||
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
/// Check if the given condition matches.
|
||||
fn matches(
|
||||
&self,
|
||||
condition: Condition,
|
||||
user_id: Option<&str>,
|
||||
display_name: Option<&str>,
|
||||
) -> bool {
|
||||
match self.match_condition(&condition, user_id, display_name) {
|
||||
Ok(true) => true,
|
||||
Ok(false) => false,
|
||||
Err(err) => {
|
||||
warn!("Condition match failed {err}");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PushRuleEvaluator {
|
||||
/// Match a given `Condition` for a push rule.
|
||||
pub fn match_condition(
|
||||
&self,
|
||||
condition: &Condition,
|
||||
user_id: Option<&str>,
|
||||
display_name: Option<&str>,
|
||||
) -> Result<bool, Error> {
|
||||
let known_condition = match condition {
|
||||
Condition::Known(known) => known,
|
||||
Condition::Unknown(_) => {
|
||||
return Ok(false);
|
||||
}
|
||||
};
|
||||
|
||||
let result = match known_condition {
|
||||
KnownCondition::EventMatch(event_match) => {
|
||||
self.match_event_match(event_match, user_id)?
|
||||
}
|
||||
KnownCondition::ContainsDisplayName => {
|
||||
if let Some(dn) = display_name {
|
||||
if !dn.is_empty() {
|
||||
get_glob_matcher(dn, GlobMatchType::Word)?.is_match(&self.body)?
|
||||
} else {
|
||||
// We specifically ignore empty display names, as otherwise
|
||||
// they would always match.
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
KnownCondition::RoomMemberCount { is } => {
|
||||
if let Some(is) = is {
|
||||
self.match_member_count(is)?
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
KnownCondition::SenderNotificationPermission { key } => {
|
||||
if let Some(sender_power_level) = &self.sender_power_level {
|
||||
let required_level = self
|
||||
.notification_power_levels
|
||||
.get(key.as_ref())
|
||||
.copied()
|
||||
.unwrap_or(50);
|
||||
|
||||
*sender_power_level >= required_level
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
KnownCondition::RelationMatch {
|
||||
rel_type,
|
||||
event_type_pattern,
|
||||
sender,
|
||||
sender_type,
|
||||
} => {
|
||||
self.match_relations(rel_type, sender, sender_type, user_id, event_type_pattern)?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Evaluates a relation condition.
|
||||
fn match_relations(
|
||||
&self,
|
||||
rel_type: &str,
|
||||
sender: &Option<Cow<str>>,
|
||||
sender_type: &Option<Cow<str>>,
|
||||
user_id: Option<&str>,
|
||||
event_type_pattern: &Option<Cow<str>>,
|
||||
) -> Result<bool, Error> {
|
||||
// First check if relation matching is enabled...
|
||||
if !self.relation_match_enabled {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// ... and if there are any relations to match against.
|
||||
let relations = if let Some(relations) = self.relations.get(rel_type) {
|
||||
relations
|
||||
} else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
// Extract the sender pattern from the condition
|
||||
let sender_pattern = if let Some(sender) = sender {
|
||||
Some(sender.as_ref())
|
||||
} else if let Some(sender_type) = sender_type {
|
||||
if sender_type == "user_id" {
|
||||
if let Some(user_id) = user_id {
|
||||
Some(user_id)
|
||||
} else {
|
||||
return Ok(false);
|
||||
}
|
||||
} else {
|
||||
warn!("Unrecognized sender_type: {sender_type}");
|
||||
return Ok(false);
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut sender_compiled_pattern = if let Some(pattern) = sender_pattern {
|
||||
Some(get_glob_matcher(pattern, GlobMatchType::Whole)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut type_compiled_pattern = if let Some(pattern) = event_type_pattern {
|
||||
Some(get_glob_matcher(pattern, GlobMatchType::Whole)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
for (relation_sender, event_type) in relations {
|
||||
if let Some(pattern) = &mut sender_compiled_pattern {
|
||||
if !pattern.is_match(relation_sender)? {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(pattern) = &mut type_compiled_pattern {
|
||||
if !pattern.is_match(event_type)? {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Evaluates a `event_match` condition.
|
||||
fn match_event_match(
|
||||
&self,
|
||||
event_match: &EventMatchCondition,
|
||||
user_id: Option<&str>,
|
||||
) -> Result<bool, Error> {
|
||||
let pattern = if let Some(pattern) = &event_match.pattern {
|
||||
pattern
|
||||
} else if let Some(pattern_type) = &event_match.pattern_type {
|
||||
// The `pattern_type` can either be "user_id" or "user_localpart",
|
||||
// either way if we don't have a `user_id` then the condition can't
|
||||
// match.
|
||||
let user_id = if let Some(user_id) = user_id {
|
||||
user_id
|
||||
} else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
match &**pattern_type {
|
||||
"user_id" => user_id,
|
||||
"user_localpart" => get_localpart_from_id(user_id)?,
|
||||
_ => return Ok(false),
|
||||
}
|
||||
} else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
let haystack = if let Some(haystack) = self.flattened_keys.get(&*event_match.key) {
|
||||
haystack
|
||||
} else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
// For the content.body we match against "words", but for everything
|
||||
// else we match against the entire value.
|
||||
let match_type = if event_match.key == "content.body" {
|
||||
GlobMatchType::Word
|
||||
} else {
|
||||
GlobMatchType::Whole
|
||||
};
|
||||
|
||||
let mut compiled_pattern = get_glob_matcher(pattern, match_type)?;
|
||||
compiled_pattern.is_match(haystack)
|
||||
}
|
||||
|
||||
/// Match the member count against an 'is' condition
|
||||
/// The `is` condition can be things like '>2', '==3' or even just '4'.
|
||||
fn match_member_count(&self, is: &str) -> Result<bool, Error> {
|
||||
let captures = INEQUALITY_EXPR.captures(is).context("bad 'is' clause")?;
|
||||
let ineq = captures.get(1).map_or("==", |m| m.as_str());
|
||||
let rhs: u64 = captures
|
||||
.get(2)
|
||||
.context("missing number")?
|
||||
.as_str()
|
||||
.parse()?;
|
||||
|
||||
let matches = match ineq {
|
||||
"" | "==" => self.room_member_count == rhs,
|
||||
"<" => self.room_member_count < rhs,
|
||||
">" => self.room_member_count > rhs,
|
||||
">=" => self.room_member_count >= rhs,
|
||||
"<=" => self.room_member_count <= rhs,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
Ok(matches)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_rule_evaluator() {
|
||||
let mut flattened_keys = BTreeMap::new();
|
||||
flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
|
||||
let evaluator = PushRuleEvaluator::py_new(
|
||||
flattened_keys,
|
||||
10,
|
||||
Some(0),
|
||||
BTreeMap::new(),
|
||||
BTreeMap::new(),
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
|
||||
assert_eq!(result.len(), 3);
|
||||
}
|
|
@ -42,7 +42,6 @@
|
|||
//!
|
||||
//! The set of "base rules" are the list of rules that every user has by default. A
|
||||
//! user can modify their copy of the push rules in one of three ways:
|
||||
//!
|
||||
//! 1. Adding a new push rule of a certain kind
|
||||
//! 2. Changing the actions of a base rule
|
||||
//! 3. Enabling/disabling a base rule.
|
||||
|
@ -58,12 +57,16 @@ use std::collections::{BTreeMap, HashMap, HashSet};
|
|||
use anyhow::{Context, Error};
|
||||
use log::warn;
|
||||
use pyo3::prelude::*;
|
||||
use pythonize::pythonize;
|
||||
use pythonize::{depythonize, pythonize};
|
||||
use serde::de::Error as _;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use self::evaluator::PushRuleEvaluator;
|
||||
|
||||
mod base_rules;
|
||||
pub mod evaluator;
|
||||
pub mod utils;
|
||||
|
||||
/// Called when registering modules with python.
|
||||
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
|
@ -71,6 +74,7 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
|||
child_module.add_class::<PushRule>()?;
|
||||
child_module.add_class::<PushRules>()?;
|
||||
child_module.add_class::<FilteredPushRules>()?;
|
||||
child_module.add_class::<PushRuleEvaluator>()?;
|
||||
child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?;
|
||||
|
||||
m.add_submodule(child_module)?;
|
||||
|
@ -274,6 +278,8 @@ pub enum KnownCondition {
|
|||
#[serde(rename = "org.matrix.msc3772.relation_match")]
|
||||
RelationMatch {
|
||||
rel_type: Cow<'static, str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename = "type")]
|
||||
event_type_pattern: Option<Cow<'static, str>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
sender: Option<Cow<'static, str>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
@ -287,20 +293,26 @@ impl IntoPy<PyObject> for Condition {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'source> FromPyObject<'source> for Condition {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
Ok(depythonize(ob)?)
|
||||
}
|
||||
}
|
||||
|
||||
/// The body of a [`Condition::EventMatch`]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct EventMatchCondition {
|
||||
key: Cow<'static, str>,
|
||||
pub key: Cow<'static, str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pattern: Option<Cow<'static, str>>,
|
||||
pub pattern: Option<Cow<'static, str>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pattern_type: Option<Cow<'static, str>>,
|
||||
pub pattern_type: Option<Cow<'static, str>>,
|
||||
}
|
||||
|
||||
/// The collection of push rules for a user.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[pyclass(frozen)]
|
||||
struct PushRules {
|
||||
pub struct PushRules {
|
||||
/// Custom push rules that override a base rule.
|
||||
overridden_base_rules: HashMap<Cow<'static, str>, PushRule>,
|
||||
|
||||
|
@ -319,7 +331,7 @@ struct PushRules {
|
|||
#[pymethods]
|
||||
impl PushRules {
|
||||
#[new]
|
||||
fn new(rules: Vec<PushRule>) -> PushRules {
|
||||
pub fn new(rules: Vec<PushRule>) -> PushRules {
|
||||
let mut push_rules: PushRules = Default::default();
|
||||
|
||||
for rule in rules {
|
||||
|
@ -396,7 +408,7 @@ pub struct FilteredPushRules {
|
|||
#[pymethods]
|
||||
impl FilteredPushRules {
|
||||
#[new]
|
||||
fn py_new(
|
||||
pub fn py_new(
|
||||
push_rules: PushRules,
|
||||
enabled_map: BTreeMap<String, bool>,
|
||||
msc3786_enabled: bool,
|
||||
|
|
|
@ -0,0 +1,215 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::bail;
|
||||
use anyhow::Context;
|
||||
use anyhow::Error;
|
||||
use lazy_static::lazy_static;
|
||||
use regex;
|
||||
use regex::Regex;
|
||||
use regex::RegexBuilder;
|
||||
|
||||
lazy_static! {
|
||||
/// Matches runs of non-wildcard characters followed by wildcard characters.
|
||||
static ref WILDCARD_RUN: Regex = Regex::new(r"([^\?\*]*)([\?\*]*)").expect("valid regex");
|
||||
}
|
||||
|
||||
/// Extract the localpart from a Matrix style ID
|
||||
pub(crate) fn get_localpart_from_id(id: &str) -> Result<&str, Error> {
|
||||
let (localpart, _) = id
|
||||
.split_once(':')
|
||||
.with_context(|| format!("ID does not contain colon: {id}"))?;
|
||||
|
||||
// We need to strip off the first character, which is the ID type.
|
||||
if localpart.is_empty() {
|
||||
bail!("Invalid ID {id}");
|
||||
}
|
||||
|
||||
Ok(&localpart[1..])
|
||||
}
|
||||
|
||||
/// Used by `glob_to_regex` to specify what to match the regex against.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum GlobMatchType {
|
||||
/// The generated regex will match against the entire input.
|
||||
Whole,
|
||||
/// The generated regex will match against words.
|
||||
Word,
|
||||
}
|
||||
|
||||
/// Convert a "glob" style expression to a regex, anchoring either to the entire
|
||||
/// input or to individual words.
|
||||
pub fn glob_to_regex(glob: &str, match_type: GlobMatchType) -> Result<Regex, Error> {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
// Patterns with wildcards must be simplified to avoid performance cliffs
|
||||
// - The glob `?**?**?` is equivalent to the glob `???*`
|
||||
// - The glob `???*` is equivalent to the regex `.{3,}`
|
||||
for captures in WILDCARD_RUN.captures_iter(glob) {
|
||||
if let Some(chunk) = captures.get(1) {
|
||||
chunks.push(regex::escape(chunk.as_str()));
|
||||
}
|
||||
|
||||
if let Some(wildcards) = captures.get(2) {
|
||||
if wildcards.as_str() == "" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let question_marks = wildcards.as_str().chars().filter(|c| *c == '?').count();
|
||||
|
||||
if wildcards.as_str().contains('*') {
|
||||
chunks.push(format!(".{{{question_marks},}}"));
|
||||
} else {
|
||||
chunks.push(format!(".{{{question_marks}}}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let joined = chunks.join("");
|
||||
|
||||
let regex_str = match match_type {
|
||||
GlobMatchType::Whole => format!(r"\A{joined}\z"),
|
||||
|
||||
// `^|\W` and `\W|$` handle the case where `pattern` starts or ends with a non-word
|
||||
// character.
|
||||
GlobMatchType::Word => format!(r"(?:^|\b|\W){joined}(?:\b|\W|$)"),
|
||||
};
|
||||
|
||||
Ok(RegexBuilder::new(®ex_str)
|
||||
.case_insensitive(true)
|
||||
.build()?)
|
||||
}
|
||||
|
||||
/// Compiles the glob into a `Matcher`.
|
||||
pub fn get_glob_matcher(glob: &str, match_type: GlobMatchType) -> Result<Matcher, Error> {
|
||||
// There are a number of shortcuts we can make if the glob doesn't contain a
|
||||
// wild card.
|
||||
let matcher = if glob.contains(['*', '?']) {
|
||||
let regex = glob_to_regex(glob, match_type)?;
|
||||
Matcher::Regex(regex)
|
||||
} else if match_type == GlobMatchType::Whole {
|
||||
// If there aren't any wildcards and we're matching the whole thing,
|
||||
// then we simply can do a case-insensitive string match.
|
||||
Matcher::Whole(glob.to_lowercase())
|
||||
} else {
|
||||
// Otherwise, if we're matching against words then can first check
|
||||
// if the haystack contains the glob at all.
|
||||
Matcher::Word {
|
||||
word: glob.to_lowercase(),
|
||||
regex: None,
|
||||
}
|
||||
};
|
||||
|
||||
Ok(matcher)
|
||||
}
|
||||
|
||||
/// Matches against a glob
|
||||
pub enum Matcher {
|
||||
/// Plain regex matching.
|
||||
Regex(Regex),
|
||||
|
||||
/// Case-insensitive equality.
|
||||
Whole(String),
|
||||
|
||||
/// Word matching. `regex` is a cache of calling [`glob_to_regex`] on word.
|
||||
Word { word: String, regex: Option<Regex> },
|
||||
}
|
||||
|
||||
impl Matcher {
|
||||
/// Checks if the glob matches the given haystack.
|
||||
pub fn is_match(&mut self, haystack: &str) -> Result<bool, Error> {
|
||||
// We want to to do case-insensitive matching, so we convert to
|
||||
// lowercase first.
|
||||
let haystack = haystack.to_lowercase();
|
||||
|
||||
match self {
|
||||
Matcher::Regex(regex) => Ok(regex.is_match(&haystack)),
|
||||
Matcher::Whole(whole) => Ok(whole == &haystack),
|
||||
Matcher::Word { word, regex } => {
|
||||
// If we're looking for a literal word, then we first check if
|
||||
// the haystack contains the word as a substring.
|
||||
if !haystack.contains(&*word) {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// If it does contain the word as a substring, then we need to
|
||||
// check if it is an actual word by testing it against the regex.
|
||||
let regex = if let Some(regex) = regex {
|
||||
regex
|
||||
} else {
|
||||
let compiled_regex = glob_to_regex(word, GlobMatchType::Word)?;
|
||||
regex.insert(compiled_regex)
|
||||
};
|
||||
|
||||
Ok(regex.is_match(&haystack))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_domain_from_id() {
|
||||
get_localpart_from_id("").unwrap_err();
|
||||
get_localpart_from_id(":").unwrap_err();
|
||||
get_localpart_from_id(":asd").unwrap_err();
|
||||
get_localpart_from_id("::as::asad").unwrap_err();
|
||||
|
||||
assert_eq!(get_localpart_from_id("@test:foo").unwrap(), "test");
|
||||
assert_eq!(get_localpart_from_id("@:").unwrap(), "");
|
||||
assert_eq!(get_localpart_from_id("@test:foo:907").unwrap(), "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tset_glob() -> Result<(), Error> {
|
||||
assert_eq!(
|
||||
glob_to_regex("simple", GlobMatchType::Whole)?.as_str(),
|
||||
r"\Asimple\z"
|
||||
);
|
||||
assert_eq!(
|
||||
glob_to_regex("simple*", GlobMatchType::Whole)?.as_str(),
|
||||
r"\Asimple.{0,}\z"
|
||||
);
|
||||
assert_eq!(
|
||||
glob_to_regex("simple?", GlobMatchType::Whole)?.as_str(),
|
||||
r"\Asimple.{1}\z"
|
||||
);
|
||||
assert_eq!(
|
||||
glob_to_regex("simple?*?*", GlobMatchType::Whole)?.as_str(),
|
||||
r"\Asimple.{2,}\z"
|
||||
);
|
||||
assert_eq!(
|
||||
glob_to_regex("simple???", GlobMatchType::Whole)?.as_str(),
|
||||
r"\Asimple.{3}\z"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
glob_to_regex("escape.", GlobMatchType::Whole)?.as_str(),
|
||||
r"\Aescape\.\z"
|
||||
);
|
||||
|
||||
assert!(glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simple"));
|
||||
assert!(!glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simples"));
|
||||
assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simples"));
|
||||
assert!(glob_to_regex("simple?", GlobMatchType::Whole)?.is_match("simples"));
|
||||
assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simple"));
|
||||
|
||||
assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("some simple."));
|
||||
assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("simple"));
|
||||
assert!(!glob_to_regex("simple", GlobMatchType::Word)?.is_match("simples"));
|
||||
|
||||
assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("Some @user:foo test"));
|
||||
assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("@user:foo"));
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Collection, Dict, Mapping, Sequence, Tuple, Union
|
||||
from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
from synapse.types import JsonDict
|
||||
|
||||
|
@ -35,3 +35,20 @@ class FilteredPushRules:
|
|||
def rules(self) -> Collection[Tuple[PushRule, bool]]: ...
|
||||
|
||||
def get_base_rule_ids() -> Collection[str]: ...
|
||||
|
||||
class PushRuleEvaluator:
|
||||
def __init__(
|
||||
self,
|
||||
flattened_keys: Mapping[str, str],
|
||||
room_member_count: int,
|
||||
sender_power_level: Optional[int],
|
||||
notification_power_levels: Mapping[str, int],
|
||||
relations: Mapping[str, Set[Tuple[str, str]]],
|
||||
relation_match_enabled: bool,
|
||||
): ...
|
||||
def run(
|
||||
self,
|
||||
push_rules: FilteredPushRules,
|
||||
user_id: Optional[str],
|
||||
display_name: Optional[str],
|
||||
) -> Collection[dict]: ...
|
||||
|
|
|
@ -107,7 +107,7 @@ BOOLEAN_COLUMNS = {
|
|||
"redactions": ["have_censored"],
|
||||
"room_stats_state": ["is_federatable"],
|
||||
"local_media_repository": ["safe_from_quarantine"],
|
||||
"users": ["shadow_banned"],
|
||||
"users": ["shadow_banned", "approved"],
|
||||
"e2e_fallback_keys_json": ["used"],
|
||||
"access_tokens": ["used"],
|
||||
"device_lists_changes_in_room": ["converted_to_destinations"],
|
||||
|
|
|
@ -269,3 +269,14 @@ class PublicRoomsFilterFields:
|
|||
|
||||
GENERIC_SEARCH_TERM: Final = "generic_search_term"
|
||||
ROOM_TYPES: Final = "room_types"
|
||||
|
||||
|
||||
class ApprovalNoticeMedium:
|
||||
"""Identifier for the medium this server will use to serve notice of approval for a
|
||||
specific user's registration.
|
||||
|
||||
As defined in https://github.com/matrix-org/matrix-spec-proposals/blob/babolivier/m_not_approved/proposals/3866-user-not-approved-error.md
|
||||
"""
|
||||
|
||||
NONE = "org.matrix.msc3866.none"
|
||||
EMAIL = "org.matrix.msc3866.email"
|
||||
|
|
|
@ -106,6 +106,8 @@ class Codes(str, Enum):
|
|||
# Part of MSC3895.
|
||||
UNABLE_DUE_TO_PARTIAL_STATE = "ORG.MATRIX.MSC3895_UNABLE_DUE_TO_PARTIAL_STATE"
|
||||
|
||||
USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
"""An exception with integer code and message string attributes.
|
||||
|
@ -566,6 +568,20 @@ class UnredactedContentDeletedError(SynapseError):
|
|||
return cs_error(self.msg, self.errcode, **extra)
|
||||
|
||||
|
||||
class NotApprovedError(SynapseError):
|
||||
def __init__(
|
||||
self,
|
||||
msg: str,
|
||||
approval_notice_medium: str,
|
||||
):
|
||||
super().__init__(
|
||||
code=403,
|
||||
msg=msg,
|
||||
errcode=Codes.USER_AWAITING_APPROVAL,
|
||||
additional_fields={"approval_notice_medium": approval_notice_medium},
|
||||
)
|
||||
|
||||
|
||||
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict":
|
||||
"""Utility method for constructing an error response for client-server
|
||||
interactions.
|
||||
|
|
|
@ -14,10 +14,25 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.config._base import Config
|
||||
from synapse.types import JsonDict
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||
class MSC3866Config:
|
||||
"""Configuration for MSC3866 (mandating approval for new users)"""
|
||||
|
||||
# Whether the base support for the approval process is enabled. This includes the
|
||||
# ability for administrators to check and update the approval of users, even if no
|
||||
# approval is currently required.
|
||||
enabled: bool = False
|
||||
# Whether to require that new users are approved by an admin before their account
|
||||
# can be used. Note that this setting is ignored if 'enabled' is false.
|
||||
require_approval_for_new_accounts: bool = False
|
||||
|
||||
|
||||
class ExperimentalConfig(Config):
|
||||
"""Config section for enabling experimental features"""
|
||||
|
||||
|
@ -97,6 +112,10 @@ class ExperimentalConfig(Config):
|
|||
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
|
||||
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
|
||||
|
||||
# MSC3866: M_USER_AWAITING_APPROVAL error code
|
||||
raw_msc3866_config = experimental.get("msc3866", {})
|
||||
self.msc3866 = MSC3866Config(**raw_msc3866_config)
|
||||
|
||||
# MSC3881: Remotely toggle push notifications for another client
|
||||
self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)
|
||||
|
||||
|
|
|
@ -289,6 +289,10 @@ class _EventInternalMetadata:
|
|||
"""
|
||||
return self._dict.get("historical", False)
|
||||
|
||||
def is_notifiable(self) -> bool:
|
||||
"""Whether this event can trigger a push notification"""
|
||||
return not self.is_outlier() or self.is_out_of_band_membership()
|
||||
|
||||
|
||||
class EventBase(metaclass=abc.ABCMeta):
|
||||
@property
|
||||
|
|
|
@ -32,6 +32,7 @@ class AdminHandler:
|
|||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
|
||||
|
||||
async def get_whois(self, user: UserID) -> JsonDict:
|
||||
connections = []
|
||||
|
@ -75,6 +76,10 @@ class AdminHandler:
|
|||
"is_guest",
|
||||
}
|
||||
|
||||
if self._msc3866_enabled:
|
||||
# Only include the approved flag if support for MSC3866 is enabled.
|
||||
user_info_to_return.add("approved")
|
||||
|
||||
# Restrict returned keys to a known set.
|
||||
user_info_dict = {
|
||||
key: value
|
||||
|
|
|
@ -1009,6 +1009,17 @@ class AuthHandler:
|
|||
return res[0]
|
||||
return None
|
||||
|
||||
async def is_user_approved(self, user_id: str) -> bool:
|
||||
"""Checks if a user is approved and therefore can be allowed to log in.
|
||||
|
||||
Args:
|
||||
user_id: the user to check the approval status of.
|
||||
|
||||
Returns:
|
||||
A boolean that is True if the user is approved, False otherwise.
|
||||
"""
|
||||
return await self.store.is_user_approved(user_id)
|
||||
|
||||
async def _find_user_id_and_pwd_hash(
|
||||
self, user_id: str
|
||||
) -> Optional[Tuple[str, str]]:
|
||||
|
|
|
@ -273,11 +273,9 @@ class DeviceWorkerHandler:
|
|||
possibly_left = possibly_changed | possibly_left
|
||||
|
||||
# Double check if we still share rooms with the given user.
|
||||
users_rooms = await self.store.get_rooms_for_users_with_stream_ordering(
|
||||
possibly_left
|
||||
)
|
||||
users_rooms = await self.store.get_rooms_for_users(possibly_left)
|
||||
for changed_user_id, entries in users_rooms.items():
|
||||
if any(e.room_id in room_ids for e in entries):
|
||||
if any(rid in room_ids for rid in entries):
|
||||
possibly_left.discard(changed_user_id)
|
||||
else:
|
||||
possibly_joined.discard(changed_user_id)
|
||||
|
@ -309,6 +307,17 @@ class DeviceWorkerHandler:
|
|||
"self_signing_key": self_signing_key,
|
||||
}
|
||||
|
||||
async def handle_room_un_partial_stated(self, room_id: str) -> None:
|
||||
"""Handles sending appropriate device list updates in a room that has
|
||||
gone from partial to full state.
|
||||
"""
|
||||
|
||||
# TODO(faster_joins): worker mode support
|
||||
# https://github.com/matrix-org/synapse/issues/12994
|
||||
logger.error(
|
||||
"Trying handling device list state for partial join: not supported on workers."
|
||||
)
|
||||
|
||||
|
||||
class DeviceHandler(DeviceWorkerHandler):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
|
@ -746,6 +755,95 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
finally:
|
||||
self._handle_new_device_update_is_processing = False
|
||||
|
||||
async def handle_room_un_partial_stated(self, room_id: str) -> None:
|
||||
"""Handles sending appropriate device list updates in a room that has
|
||||
gone from partial to full state.
|
||||
"""
|
||||
|
||||
# We defer to the device list updater to handle pending remote device
|
||||
# list updates.
|
||||
await self.device_list_updater.handle_room_un_partial_stated(room_id)
|
||||
|
||||
# Replay local updates.
|
||||
(
|
||||
join_event_id,
|
||||
device_lists_stream_id,
|
||||
) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
|
||||
room_id
|
||||
)
|
||||
|
||||
# Get the local device list changes that have happened in the room since
|
||||
# we started joining. If there are no updates there's nothing left to do.
|
||||
changes = await self.store.get_device_list_changes_in_room(
|
||||
room_id, device_lists_stream_id
|
||||
)
|
||||
local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
|
||||
if not local_changes:
|
||||
return
|
||||
|
||||
# Note: We have persisted the full state at this point, we just haven't
|
||||
# cleared the `partial_room` flag.
|
||||
join_state_ids = await self._state_storage.get_state_ids_for_event(
|
||||
join_event_id, await_full_state=False
|
||||
)
|
||||
current_state_ids = await self.store.get_partial_current_state_ids(room_id)
|
||||
|
||||
# Now we need to work out all servers that might have been in the room
|
||||
# at any point during our join.
|
||||
|
||||
# First we look for any membership states that have changed between the
|
||||
# initial join and now...
|
||||
all_keys = set(join_state_ids)
|
||||
all_keys.update(current_state_ids)
|
||||
|
||||
potentially_changed_hosts = set()
|
||||
for etype, state_key in all_keys:
|
||||
if etype != EventTypes.Member:
|
||||
continue
|
||||
|
||||
prev = join_state_ids.get((etype, state_key))
|
||||
current = current_state_ids.get((etype, state_key))
|
||||
|
||||
if prev != current:
|
||||
potentially_changed_hosts.add(get_domain_from_id(state_key))
|
||||
|
||||
# ... then we add all the hosts that are currently joined to the room...
|
||||
current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
|
||||
potentially_changed_hosts.update(current_hosts_in_room)
|
||||
|
||||
# ... and finally we remove any hosts that we were told about, as we
|
||||
# will have sent device list updates to those hosts when they happened.
|
||||
known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
|
||||
room_id
|
||||
)
|
||||
potentially_changed_hosts.difference_update(known_hosts_at_join)
|
||||
|
||||
potentially_changed_hosts.discard(self.server_name)
|
||||
|
||||
if not potentially_changed_hosts:
|
||||
# Nothing to do.
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Found %d changed hosts to send device list updates to",
|
||||
len(potentially_changed_hosts),
|
||||
)
|
||||
|
||||
for user_id, device_id in local_changes:
|
||||
await self.store.add_device_list_outbound_pokes(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
room_id=room_id,
|
||||
stream_id=None,
|
||||
hosts=potentially_changed_hosts,
|
||||
context=None,
|
||||
)
|
||||
|
||||
# Notify things that device lists need to be sent out.
|
||||
self.notifier.notify_replication()
|
||||
for host in potentially_changed_hosts:
|
||||
self.federation_sender.send_device_messages(host, immediate=False)
|
||||
|
||||
|
||||
def _update_device_from_client_ips(
|
||||
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
|
||||
|
@ -836,6 +934,16 @@ class DeviceListUpdater:
|
|||
)
|
||||
return
|
||||
|
||||
# Check if we are partially joining any rooms. If so we need to store
|
||||
# all device list updates so that we can handle them correctly once we
|
||||
# know who is in the room.
|
||||
partial_rooms = await self.store.get_partial_state_rooms_and_servers()
|
||||
if partial_rooms:
|
||||
await self.store.add_remote_device_list_to_pending(
|
||||
user_id,
|
||||
device_id,
|
||||
)
|
||||
|
||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||
if not room_ids:
|
||||
# We don't share any rooms with this user. Ignore update, as we
|
||||
|
@ -1175,3 +1283,35 @@ class DeviceListUpdater:
|
|||
device_ids.append(verify_key.version)
|
||||
|
||||
return device_ids
|
||||
|
||||
async def handle_room_un_partial_stated(self, room_id: str) -> None:
|
||||
"""Handles sending appropriate device list updates in a room that has
|
||||
gone from partial to full state.
|
||||
"""
|
||||
|
||||
pending_updates = (
|
||||
await self.store.get_pending_remote_device_list_updates_for_room(room_id)
|
||||
)
|
||||
|
||||
for user_id, device_id in pending_updates:
|
||||
logger.info(
|
||||
"Got pending device list update in room %s: %s / %s",
|
||||
room_id,
|
||||
user_id,
|
||||
device_id,
|
||||
)
|
||||
position = await self.store.add_device_change_to_streams(
|
||||
user_id,
|
||||
[device_id],
|
||||
room_ids=[room_id],
|
||||
)
|
||||
|
||||
if not position:
|
||||
# This should only happen if there are no updates, which
|
||||
# shouldn't happen when we've passed in a non-empty set of
|
||||
# device IDs.
|
||||
continue
|
||||
|
||||
self.device_handler.notifier.on_new_event(
|
||||
StreamKeyType.DEVICE_LIST, position, rooms=[room_id]
|
||||
)
|
||||
|
|
|
@ -38,7 +38,7 @@ from signedjson.sign import verify_signed_json
|
|||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
CodeMessageException,
|
||||
|
@ -149,6 +149,8 @@ class FederationHandler:
|
|||
self.http_client = hs.get_proxied_blacklisted_http_client()
|
||||
self._replication = hs.get_replication_data_handler()
|
||||
self._federation_event_handler = hs.get_federation_event_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
|
||||
|
||||
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
|
||||
hs
|
||||
|
@ -209,7 +211,7 @@ class FederationHandler:
|
|||
current_depth: int,
|
||||
limit: int,
|
||||
*,
|
||||
processing_start_time: int,
|
||||
processing_start_time: Optional[int],
|
||||
) -> bool:
|
||||
"""
|
||||
Checks whether the `current_depth` is at or approaching any backfill
|
||||
|
@ -221,12 +223,23 @@ class FederationHandler:
|
|||
room_id: The room to backfill in.
|
||||
current_depth: The depth to check at for any upcoming backfill points.
|
||||
limit: The max number of events to request from the remote federated server.
|
||||
processing_start_time: The time when `maybe_backfill` started
|
||||
processing. Only used for timing.
|
||||
processing_start_time: The time when `maybe_backfill` started processing.
|
||||
Only used for timing. If `None`, no timing observation will be made.
|
||||
"""
|
||||
backwards_extremities = [
|
||||
_BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
|
||||
for event_id, depth in await self.store.get_backfill_points_in_room(room_id)
|
||||
for event_id, depth in await self.store.get_backfill_points_in_room(
|
||||
room_id=room_id,
|
||||
current_depth=current_depth,
|
||||
# We only need to end up with 5 extremities combined with the
|
||||
# insertion event extremities to make the `/backfill` request
|
||||
# but fetch an order of magnitude more to make sure there is
|
||||
# enough even after we filter them by whether visible in the
|
||||
# history. This isn't fool-proof as all backfill points within
|
||||
# our limit could be filtered out but seems like a good amount
|
||||
# to try with at least.
|
||||
limit=50,
|
||||
)
|
||||
]
|
||||
|
||||
insertion_events_to_be_backfilled: List[_BackfillPoint] = []
|
||||
|
@ -234,7 +247,12 @@ class FederationHandler:
|
|||
insertion_events_to_be_backfilled = [
|
||||
_BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT)
|
||||
for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room(
|
||||
room_id
|
||||
room_id=room_id,
|
||||
current_depth=current_depth,
|
||||
# We only need to end up with 5 extremities combined with
|
||||
# the backfill points to make the `/backfill` request ...
|
||||
# (see the other comment above for more context).
|
||||
limit=50,
|
||||
)
|
||||
]
|
||||
logger.debug(
|
||||
|
@ -243,10 +261,6 @@ class FederationHandler:
|
|||
insertion_events_to_be_backfilled,
|
||||
)
|
||||
|
||||
if not backwards_extremities and not insertion_events_to_be_backfilled:
|
||||
logger.debug("Not backfilling as no extremeties found.")
|
||||
return False
|
||||
|
||||
# we now have a list of potential places to backpaginate from. We prefer to
|
||||
# start with the most recent (ie, max depth), so let's sort the list.
|
||||
sorted_backfill_points: List[_BackfillPoint] = sorted(
|
||||
|
@ -267,6 +281,33 @@ class FederationHandler:
|
|||
sorted_backfill_points,
|
||||
)
|
||||
|
||||
# If we have no backfill points lower than the `current_depth` then
|
||||
# either we can a) bail or b) still attempt to backfill. We opt to try
|
||||
# backfilling anyway just in case we do get relevant events.
|
||||
if not sorted_backfill_points and current_depth != MAX_DEPTH:
|
||||
logger.debug(
|
||||
"_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points."
|
||||
)
|
||||
return await self._maybe_backfill_inner(
|
||||
room_id=room_id,
|
||||
# We use `MAX_DEPTH` so that we find all backfill points next
|
||||
# time (all events are below the `MAX_DEPTH`)
|
||||
current_depth=MAX_DEPTH,
|
||||
limit=limit,
|
||||
# We don't want to start another timing observation from this
|
||||
# nested recursive call. The top-most call can record the time
|
||||
# overall otherwise the smaller one will throw off the results.
|
||||
processing_start_time=None,
|
||||
)
|
||||
|
||||
# Even after recursing with `MAX_DEPTH`, we didn't find any
|
||||
# backward extremities to backfill from.
|
||||
if not sorted_backfill_points:
|
||||
logger.debug(
|
||||
"_maybe_backfill_inner: Not backfilling as no backward extremeties found."
|
||||
)
|
||||
return False
|
||||
|
||||
# If we're approaching an extremity we trigger a backfill, otherwise we
|
||||
# no-op.
|
||||
#
|
||||
|
@ -276,47 +317,16 @@ class FederationHandler:
|
|||
# chose more than one times the limit in case of failure, but choosing a
|
||||
# much larger factor will result in triggering a backfill request much
|
||||
# earlier than necessary.
|
||||
#
|
||||
# XXX: shouldn't we do this *after* the filter by depth below? Again, we don't
|
||||
# care about events that have happened after our current position.
|
||||
#
|
||||
max_depth = sorted_backfill_points[0].depth
|
||||
if current_depth - 2 * limit > max_depth:
|
||||
max_depth_of_backfill_points = sorted_backfill_points[0].depth
|
||||
if current_depth - 2 * limit > max_depth_of_backfill_points:
|
||||
logger.debug(
|
||||
"Not backfilling as we don't need to. %d < %d - 2 * %d",
|
||||
max_depth,
|
||||
max_depth_of_backfill_points,
|
||||
current_depth,
|
||||
limit,
|
||||
)
|
||||
return False
|
||||
|
||||
# We ignore extremities that have a greater depth than our current depth
|
||||
# as:
|
||||
# 1. we don't really care about getting events that have happened
|
||||
# after our current position; and
|
||||
# 2. we have likely previously tried and failed to backfill from that
|
||||
# extremity, so to avoid getting "stuck" requesting the same
|
||||
# backfill repeatedly we drop those extremities.
|
||||
#
|
||||
# However, we need to check that the filtered extremities are non-empty.
|
||||
# If they are empty then either we can a) bail or b) still attempt to
|
||||
# backfill. We opt to try backfilling anyway just in case we do get
|
||||
# relevant events.
|
||||
#
|
||||
filtered_sorted_backfill_points = [
|
||||
t for t in sorted_backfill_points if t.depth <= current_depth
|
||||
]
|
||||
if filtered_sorted_backfill_points:
|
||||
logger.debug(
|
||||
"_maybe_backfill_inner: backfill points before current depth: %s",
|
||||
filtered_sorted_backfill_points,
|
||||
)
|
||||
sorted_backfill_points = filtered_sorted_backfill_points
|
||||
else:
|
||||
logger.debug(
|
||||
"_maybe_backfill_inner: all backfill points are *after* current depth. Backfilling anyway."
|
||||
)
|
||||
|
||||
# For performance's sake, we only want to paginate from a particular extremity
|
||||
# if we can actually see the events we'll get. Otherwise, we'd just spend a lot
|
||||
# of resources to get redacted events. We check each extremity in turn and
|
||||
|
@ -402,11 +412,22 @@ class FederationHandler:
|
|||
# First we try hosts that are already in the room.
|
||||
# TODO: HEURISTIC ALERT.
|
||||
likely_domains = (
|
||||
await self._storage_controllers.state.get_current_hosts_in_room(room_id)
|
||||
await self._storage_controllers.state.get_current_hosts_in_room_ordered(
|
||||
room_id
|
||||
)
|
||||
)
|
||||
|
||||
async def try_backfill(domains: Collection[str]) -> bool:
|
||||
# TODO: Should we try multiple of these at a time?
|
||||
|
||||
# Number of contacted remote homeservers that have denied our backfill
|
||||
# request with a 4xx code.
|
||||
denied_count = 0
|
||||
|
||||
# Maximum number of contacted remote homeservers that can deny our
|
||||
# backfill request with 4xx codes before we give up.
|
||||
max_denied_count = 5
|
||||
|
||||
for dom in domains:
|
||||
# We don't want to ask our own server for information we don't have
|
||||
if dom == self.server_name:
|
||||
|
@ -425,13 +446,33 @@ class FederationHandler:
|
|||
continue
|
||||
except HttpResponseException as e:
|
||||
if 400 <= e.code < 500:
|
||||
raise e.to_synapse_error()
|
||||
logger.warning(
|
||||
"Backfill denied from %s because %s [%d/%d]",
|
||||
dom,
|
||||
e,
|
||||
denied_count,
|
||||
max_denied_count,
|
||||
)
|
||||
denied_count += 1
|
||||
if denied_count >= max_denied_count:
|
||||
return False
|
||||
continue
|
||||
|
||||
logger.info("Failed to backfill from %s because %s", dom, e)
|
||||
continue
|
||||
except CodeMessageException as e:
|
||||
if 400 <= e.code < 500:
|
||||
raise
|
||||
logger.warning(
|
||||
"Backfill denied from %s because %s [%d/%d]",
|
||||
dom,
|
||||
e,
|
||||
denied_count,
|
||||
max_denied_count,
|
||||
)
|
||||
denied_count += 1
|
||||
if denied_count >= max_denied_count:
|
||||
return False
|
||||
continue
|
||||
|
||||
logger.info("Failed to backfill from %s because %s", dom, e)
|
||||
continue
|
||||
|
@ -450,10 +491,15 @@ class FederationHandler:
|
|||
|
||||
return False
|
||||
|
||||
processing_end_time = self.clock.time_msec()
|
||||
backfill_processing_before_timer.observe(
|
||||
(processing_end_time - processing_start_time) / 1000
|
||||
)
|
||||
# If we have the `processing_start_time`, then we can make an
|
||||
# observation. We wouldn't have the `processing_start_time` in the case
|
||||
# where `_maybe_backfill_inner` is recursively called to find any
|
||||
# backfill points regardless of `current_depth`.
|
||||
if processing_start_time is not None:
|
||||
processing_end_time = self.clock.time_msec()
|
||||
backfill_processing_before_timer.observe(
|
||||
(processing_end_time - processing_start_time) / 1000
|
||||
)
|
||||
|
||||
success = await try_backfill(likely_domains)
|
||||
if success:
|
||||
|
@ -956,9 +1002,15 @@ class FederationHandler:
|
|||
)
|
||||
|
||||
context = EventContext.for_outlier(self._storage_controllers)
|
||||
await self._federation_event_handler.persist_events_and_notify(
|
||||
event.room_id, [(event, context)]
|
||||
)
|
||||
|
||||
await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context)
|
||||
try:
|
||||
await self._federation_event_handler.persist_events_and_notify(
|
||||
event.room_id, [(event, context)]
|
||||
)
|
||||
except Exception:
|
||||
await self.store.remove_push_actions_from_staging(event.event_id)
|
||||
raise
|
||||
|
||||
return event
|
||||
|
||||
|
@ -1624,6 +1676,9 @@ class FederationHandler:
|
|||
# https://github.com/matrix-org/synapse/issues/12994
|
||||
await self.state_handler.update_current_state(room_id)
|
||||
|
||||
logger.info("Handling any pending device list updates")
|
||||
await self._device_handler.handle_room_un_partial_stated(room_id)
|
||||
|
||||
logger.info("Clearing partial-state flag for %s", room_id)
|
||||
success = await self.store.clear_partial_state_room(room_id)
|
||||
if success:
|
||||
|
|
|
@ -2170,6 +2170,7 @@ class FederationEventHandler:
|
|||
if instance != self._instance_name:
|
||||
# Limit the number of events sent over replication. We choose 200
|
||||
# here as that is what we default to in `max_request_body_size(..)`
|
||||
result = {}
|
||||
try:
|
||||
for batch in batch_iter(event_and_contexts, 200):
|
||||
result = await self._send_events(
|
||||
|
|
|
@ -220,6 +220,7 @@ class RegistrationHandler:
|
|||
by_admin: bool = False,
|
||||
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
|
||||
auth_provider_id: Optional[str] = None,
|
||||
approved: bool = False,
|
||||
) -> str:
|
||||
"""Registers a new client on the server.
|
||||
|
||||
|
@ -246,6 +247,8 @@ class RegistrationHandler:
|
|||
user_agent_ips: Tuples of user-agents and IP addresses used
|
||||
during the registration process.
|
||||
auth_provider_id: The SSO IdP the user used, if any.
|
||||
approved: True if the new user should be considered already
|
||||
approved by an administrator.
|
||||
Returns:
|
||||
The registered user_id.
|
||||
Raises:
|
||||
|
@ -307,6 +310,7 @@ class RegistrationHandler:
|
|||
user_type=user_type,
|
||||
address=address,
|
||||
shadow_banned=shadow_banned,
|
||||
approved=approved,
|
||||
)
|
||||
|
||||
profile = await self.store.get_profileinfo(localpart)
|
||||
|
@ -695,6 +699,7 @@ class RegistrationHandler:
|
|||
user_type: Optional[str] = None,
|
||||
address: Optional[str] = None,
|
||||
shadow_banned: bool = False,
|
||||
approved: bool = False,
|
||||
) -> None:
|
||||
"""Register user in the datastore.
|
||||
|
||||
|
@ -713,6 +718,7 @@ class RegistrationHandler:
|
|||
api.constants.UserTypes, or None for a normal user.
|
||||
address: the IP address used to perform the registration.
|
||||
shadow_banned: Whether to shadow-ban the user
|
||||
approved: Whether to mark the user as approved by an administrator
|
||||
"""
|
||||
if self.hs.config.worker.worker_app:
|
||||
await self._register_client(
|
||||
|
@ -726,6 +732,7 @@ class RegistrationHandler:
|
|||
user_type=user_type,
|
||||
address=address,
|
||||
shadow_banned=shadow_banned,
|
||||
approved=approved,
|
||||
)
|
||||
else:
|
||||
await self.store.register_user(
|
||||
|
@ -738,6 +745,7 @@ class RegistrationHandler:
|
|||
admin=admin,
|
||||
user_type=user_type,
|
||||
shadow_banned=shadow_banned,
|
||||
approved=approved,
|
||||
)
|
||||
|
||||
# Only call the account validity module(s) on the main process, to avoid
|
||||
|
|
|
@ -1540,7 +1540,9 @@ class TimestampLookupHandler:
|
|||
)
|
||||
|
||||
likely_domains = (
|
||||
await self._storage_controllers.state.get_current_hosts_in_room(room_id)
|
||||
await self._storage_controllers.state.get_current_hosts_in_room_ordered(
|
||||
room_id
|
||||
)
|
||||
)
|
||||
|
||||
# Loop through each homeserver candidate until we get a succesful response
|
||||
|
|
|
@ -187,6 +187,19 @@ class SendEmailHandler:
|
|||
multipart_msg["To"] = email_address
|
||||
multipart_msg["Date"] = email.utils.formatdate()
|
||||
multipart_msg["Message-ID"] = email.utils.make_msgid()
|
||||
# Discourage automatic responses to Synapse's emails.
|
||||
# Per RFC 3834, automatic responses should not be sent if the "Auto-Submitted"
|
||||
# header is present with any value other than "no". See
|
||||
# https://www.rfc-editor.org/rfc/rfc3834.html#section-5.1
|
||||
multipart_msg["Auto-Submitted"] = "auto-generated"
|
||||
# Also include a Microsoft-Exchange specific header:
|
||||
# https://learn.microsoft.com/en-us/openspecs/exchange_server_protocols/ms-oxcmail/ced68690-498a-4567-9d14-5c01f974d8b1
|
||||
# which suggests it can take the value "All" to "suppress all auto-replies",
|
||||
# or a comma separated list of auto-reply classes to suppress.
|
||||
# The following stack overflow question has a little more context:
|
||||
# https://stackoverflow.com/a/25324691/5252017
|
||||
# https://stackoverflow.com/a/61646381/5252017
|
||||
multipart_msg["X-Auto-Response-Suppress"] = "All"
|
||||
multipart_msg.attach(text_part)
|
||||
multipart_msg.attach(html_part)
|
||||
|
||||
|
|
|
@ -1490,16 +1490,14 @@ class SyncHandler:
|
|||
since_token.device_list_key
|
||||
)
|
||||
if changed_users is not None:
|
||||
result = await self.store.get_rooms_for_users_with_stream_ordering(
|
||||
changed_users
|
||||
)
|
||||
result = await self.store.get_rooms_for_users(changed_users)
|
||||
|
||||
for changed_user_id, entries in result.items():
|
||||
# Check if the changed user shares any rooms with the user,
|
||||
# or if the changed user is the syncing user (as we always
|
||||
# want to include device list updates of their own devices).
|
||||
if user_id == changed_user_id or any(
|
||||
e.room_id in joined_rooms for e in entries
|
||||
rid in joined_rooms for rid in entries
|
||||
):
|
||||
users_that_have_changed.add(changed_user_id)
|
||||
else:
|
||||
|
@ -1533,13 +1531,9 @@ class SyncHandler:
|
|||
newly_left_users.update(left_users)
|
||||
|
||||
# Remove any users that we still share a room with.
|
||||
left_users_rooms = (
|
||||
await self.store.get_rooms_for_users_with_stream_ordering(
|
||||
newly_left_users
|
||||
)
|
||||
)
|
||||
left_users_rooms = await self.store.get_rooms_for_users(newly_left_users)
|
||||
for user_id, entries in left_users_rooms.items():
|
||||
if any(e.room_id in joined_rooms for e in entries):
|
||||
if any(rid in joined_rooms for rid in entries):
|
||||
newly_left_users.discard(user_id)
|
||||
|
||||
return DeviceListUpdates(
|
||||
|
|
|
@ -842,6 +842,8 @@ class ModuleApi:
|
|||
however invalidation that needs to go to other workers needs to call `invalidate_cache`
|
||||
on the module API instead.
|
||||
|
||||
Added in Synapse v1.69.0.
|
||||
|
||||
Args:
|
||||
cached_function: The cached function that will be registered to receive invalidation
|
||||
locally and from other workers.
|
||||
|
@ -856,6 +858,8 @@ class ModuleApi:
|
|||
"""Invalidate a cache entry of a cached function across workers. The cached function
|
||||
needs to be registered on all workers first with `register_cached_function`.
|
||||
|
||||
Added in Synapse v1.69.0.
|
||||
|
||||
Args:
|
||||
cached_function: The cached function that needs an invalidation
|
||||
keys: keys of the entry to invalidate, usually matching the arguments of the
|
||||
|
|
|
@ -17,6 +17,7 @@ import itertools
|
|||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
|
@ -37,13 +38,11 @@ from synapse.events.snapshot import EventContext
|
|||
from synapse.state import POWER_KEY
|
||||
from synapse.storage.databases.main.roommember import EventIdMembership
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.synapse_rust.push import FilteredPushRules, PushRule
|
||||
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator
|
||||
from synapse.util.caches import register_cache
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.visibility import filter_event_for_clients_with_state
|
||||
|
||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
@ -173,7 +172,11 @@ class BulkPushRuleEvaluator:
|
|||
|
||||
async def _get_power_levels_and_sender_level(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Tuple[dict, int]:
|
||||
) -> Tuple[dict, Optional[int]]:
|
||||
# There are no power levels and sender levels possible to get from outlier
|
||||
if event.internal_metadata.is_outlier():
|
||||
return {}, None
|
||||
|
||||
event_types = auth_types_for_event(event.room_version, event)
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
StateFilter.from_types(event_types)
|
||||
|
@ -250,8 +253,8 @@ class BulkPushRuleEvaluator:
|
|||
should increment the unread count, and insert the results into the
|
||||
event_push_actions_staging table.
|
||||
"""
|
||||
if event.internal_metadata.is_outlier():
|
||||
# This can happen due to out of band memberships
|
||||
if not event.internal_metadata.is_notifiable():
|
||||
# Push rules for events that aren't notifiable can't be processed by this
|
||||
return
|
||||
|
||||
# Disable counting as unread unless the experimental configuration is
|
||||
|
@ -286,11 +289,11 @@ class BulkPushRuleEvaluator:
|
|||
if relation.rel_type == RelationTypes.THREAD:
|
||||
thread_id = relation.parent_id
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(
|
||||
event,
|
||||
evaluator = PushRuleEvaluator(
|
||||
_flatten_dict(event),
|
||||
room_member_count,
|
||||
sender_power_level,
|
||||
power_levels,
|
||||
power_levels.get("notifications", {}),
|
||||
relations,
|
||||
self._relations_match_enabled,
|
||||
)
|
||||
|
@ -300,20 +303,10 @@ class BulkPushRuleEvaluator:
|
|||
event.room_id, users
|
||||
)
|
||||
|
||||
# This is a check for the case where user joins a room without being
|
||||
# allowed to see history, and then the server receives a delayed event
|
||||
# from before the user joined, which they should not be pushed for
|
||||
uids_with_visibility = await filter_event_for_clients_with_state(
|
||||
self.store, users, event, context
|
||||
)
|
||||
|
||||
for uid, rules in rules_by_user.items():
|
||||
if event.sender == uid:
|
||||
continue
|
||||
|
||||
if uid not in uids_with_visibility:
|
||||
continue
|
||||
|
||||
display_name = None
|
||||
profile = profiles.get(uid)
|
||||
if profile:
|
||||
|
@ -334,17 +327,25 @@ class BulkPushRuleEvaluator:
|
|||
# current user, it'll be added to the dict later.
|
||||
actions_by_user[uid] = []
|
||||
|
||||
for rule, enabled in rules.rules():
|
||||
if not enabled:
|
||||
continue
|
||||
actions = evaluator.run(rules, uid, display_name)
|
||||
if "notify" in actions:
|
||||
# Push rules say we should notify the user of this event
|
||||
actions_by_user[uid] = actions
|
||||
|
||||
matches = evaluator.check_conditions(rule.conditions, uid, display_name)
|
||||
if matches:
|
||||
actions = [x for x in rule.actions if x != "dont_notify"]
|
||||
if actions and "notify" in actions:
|
||||
# Push rules say we should notify the user of this event
|
||||
actions_by_user[uid] = actions
|
||||
break
|
||||
# This is a check for the case where user joins a room without being
|
||||
# allowed to see history, and then the server receives a delayed event
|
||||
# from before the user joined, which they should not be pushed for
|
||||
#
|
||||
# We do this *after* calculating the push actions as a) its unlikely
|
||||
# that we'll filter anyone out and b) for large rooms its likely that
|
||||
# most users will have push disabled and so the set of users to check is
|
||||
# much smaller.
|
||||
uids_with_visibility = await filter_event_for_clients_with_state(
|
||||
self.store, actions_by_user.keys(), event, context
|
||||
)
|
||||
|
||||
for user_id in set(actions_by_user).difference(uids_with_visibility):
|
||||
actions_by_user.pop(user_id, None)
|
||||
|
||||
# Mark in the DB staging area the push actions for users who should be
|
||||
# notified for this event. (This will then get handled when we persist
|
||||
|
@ -361,3 +362,21 @@ MemberMap = Dict[str, Optional[EventIdMembership]]
|
|||
Rule = Dict[str, dict]
|
||||
RulesByUser = Dict[str, List[Rule]]
|
||||
StateGroup = Union[object, int]
|
||||
|
||||
|
||||
def _flatten_dict(
|
||||
d: Union[EventBase, Mapping[str, Any]],
|
||||
prefix: Optional[List[str]] = None,
|
||||
result: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, str]:
|
||||
if prefix is None:
|
||||
prefix = []
|
||||
if result is None:
|
||||
result = {}
|
||||
for key, value in d.items():
|
||||
if isinstance(value, str):
|
||||
result[".".join(prefix + [key])] = value.lower()
|
||||
elif isinstance(value, Mapping):
|
||||
_flatten_dict(value, prefix=(prefix + [key]), result=result)
|
||||
|
||||
return result
|
||||
|
|
|
@ -102,10 +102,8 @@ def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]:
|
|||
# with PRIORITY_CLASS_INVERSE_MAP.
|
||||
raise ValueError("Unexpected template_name: %s" % (template_name,))
|
||||
|
||||
if unscoped_rule_id:
|
||||
templaterule["rule_id"] = unscoped_rule_id
|
||||
if rule.default:
|
||||
templaterule["default"] = True
|
||||
templaterule["rule_id"] = unscoped_rule_id
|
||||
templaterule["default"] = rule.default
|
||||
return templaterule
|
||||
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
|
@ -28,7 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||
from synapse.push import Pusher, PusherConfig, PusherConfigException
|
||||
from synapse.storage.databases.main.event_push_actions import HttpPushAction
|
||||
|
||||
from . import push_rule_evaluator, push_tools
|
||||
from . import push_tools
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -56,6 +56,39 @@ http_badges_failed_counter = Counter(
|
|||
)
|
||||
|
||||
|
||||
def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Converts a list of actions into a `tweaks` dict (which can then be passed to
|
||||
the push gateway).
|
||||
|
||||
This function ignores all actions other than `set_tweak` actions, and treats
|
||||
absent `value`s as `True`, which agrees with the only spec-defined treatment
|
||||
of absent `value`s (namely, for `highlight` tweaks).
|
||||
|
||||
Args:
|
||||
actions: list of actions
|
||||
e.g. [
|
||||
{"set_tweak": "a", "value": "AAA"},
|
||||
{"set_tweak": "b", "value": "BBB"},
|
||||
{"set_tweak": "highlight"},
|
||||
"notify"
|
||||
]
|
||||
|
||||
Returns:
|
||||
dictionary of tweaks for those actions
|
||||
e.g. {"a": "AAA", "b": "BBB", "highlight": True}
|
||||
"""
|
||||
tweaks = {}
|
||||
for a in actions:
|
||||
if not isinstance(a, dict):
|
||||
continue
|
||||
if "set_tweak" in a:
|
||||
# value is allowed to be absent in which case the value assumed
|
||||
# should be True.
|
||||
tweaks[a["set_tweak"]] = a.get("value", True)
|
||||
return tweaks
|
||||
|
||||
|
||||
class HttpPusher(Pusher):
|
||||
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
|
||||
MAX_BACKOFF_SEC = 60 * 60
|
||||
|
@ -286,7 +319,7 @@ class HttpPusher(Pusher):
|
|||
if "notify" not in push_action.actions:
|
||||
return True
|
||||
|
||||
tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions)
|
||||
tweaks = tweaks_for_actions(push_action.actions)
|
||||
badge = await push_tools.get_badge_count(
|
||||
self.hs.get_datastores().main,
|
||||
self.user_id,
|
||||
|
|
|
@ -1,361 +0,0 @@
|
|||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Pattern,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from matrix_common.regex import glob_to_regex, to_word_pattern
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import UserID
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]")
|
||||
IS_GLOB = re.compile(r"[\?\*\[\]]")
|
||||
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
||||
|
||||
|
||||
def _room_member_count(
|
||||
ev: EventBase, condition: Mapping[str, Any], room_member_count: int
|
||||
) -> bool:
|
||||
return _test_ineq_condition(condition, room_member_count)
|
||||
|
||||
|
||||
def _sender_notification_permission(
|
||||
ev: EventBase,
|
||||
condition: Mapping[str, Any],
|
||||
sender_power_level: int,
|
||||
power_levels: Dict[str, Union[int, Dict[str, int]]],
|
||||
) -> bool:
|
||||
notif_level_key = condition.get("key")
|
||||
if notif_level_key is None:
|
||||
return False
|
||||
|
||||
notif_levels = power_levels.get("notifications", {})
|
||||
assert isinstance(notif_levels, dict)
|
||||
room_notif_level = notif_levels.get(notif_level_key, 50)
|
||||
|
||||
return sender_power_level >= room_notif_level
|
||||
|
||||
|
||||
def _test_ineq_condition(condition: Mapping[str, Any], number: int) -> bool:
|
||||
if "is" not in condition:
|
||||
return False
|
||||
m = INEQUALITY_EXPR.match(condition["is"])
|
||||
if not m:
|
||||
return False
|
||||
ineq = m.group(1)
|
||||
rhs = m.group(2)
|
||||
if not rhs.isdigit():
|
||||
return False
|
||||
rhs_int = int(rhs)
|
||||
|
||||
if ineq == "" or ineq == "==":
|
||||
return number == rhs_int
|
||||
elif ineq == "<":
|
||||
return number < rhs_int
|
||||
elif ineq == ">":
|
||||
return number > rhs_int
|
||||
elif ineq == ">=":
|
||||
return number >= rhs_int
|
||||
elif ineq == "<=":
|
||||
return number <= rhs_int
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Converts a list of actions into a `tweaks` dict (which can then be passed to
|
||||
the push gateway).
|
||||
|
||||
This function ignores all actions other than `set_tweak` actions, and treats
|
||||
absent `value`s as `True`, which agrees with the only spec-defined treatment
|
||||
of absent `value`s (namely, for `highlight` tweaks).
|
||||
|
||||
Args:
|
||||
actions: list of actions
|
||||
e.g. [
|
||||
{"set_tweak": "a", "value": "AAA"},
|
||||
{"set_tweak": "b", "value": "BBB"},
|
||||
{"set_tweak": "highlight"},
|
||||
"notify"
|
||||
]
|
||||
|
||||
Returns:
|
||||
dictionary of tweaks for those actions
|
||||
e.g. {"a": "AAA", "b": "BBB", "highlight": True}
|
||||
"""
|
||||
tweaks = {}
|
||||
for a in actions:
|
||||
if not isinstance(a, dict):
|
||||
continue
|
||||
if "set_tweak" in a:
|
||||
# value is allowed to be absent in which case the value assumed
|
||||
# should be True.
|
||||
tweaks[a["set_tweak"]] = a.get("value", True)
|
||||
return tweaks
|
||||
|
||||
|
||||
class PushRuleEvaluatorForEvent:
|
||||
def __init__(
|
||||
self,
|
||||
event: EventBase,
|
||||
room_member_count: int,
|
||||
sender_power_level: int,
|
||||
power_levels: Dict[str, Union[int, Dict[str, int]]],
|
||||
relations: Dict[str, Set[Tuple[str, str]]],
|
||||
relations_match_enabled: bool,
|
||||
):
|
||||
self._event = event
|
||||
self._room_member_count = room_member_count
|
||||
self._sender_power_level = sender_power_level
|
||||
self._power_levels = power_levels
|
||||
self._relations = relations
|
||||
self._relations_match_enabled = relations_match_enabled
|
||||
|
||||
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
||||
self._value_cache = _flatten_dict(event)
|
||||
|
||||
# Maps cache keys to final values.
|
||||
self._condition_cache: Dict[str, bool] = {}
|
||||
|
||||
def check_conditions(
|
||||
self, conditions: Sequence[Mapping], uid: str, display_name: Optional[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Returns true if a user's conditions/user ID/display name match the event.
|
||||
|
||||
Args:
|
||||
conditions: The user's conditions to match.
|
||||
uid: The user's MXID.
|
||||
display_name: The display name.
|
||||
|
||||
Returns:
|
||||
True if all conditions match the event, False otherwise.
|
||||
"""
|
||||
for cond in conditions:
|
||||
_cache_key = cond.get("_cache_key", None)
|
||||
if _cache_key:
|
||||
res = self._condition_cache.get(_cache_key, None)
|
||||
if res is False:
|
||||
return False
|
||||
elif res is True:
|
||||
continue
|
||||
|
||||
res = self.matches(cond, uid, display_name)
|
||||
if _cache_key:
|
||||
self._condition_cache[_cache_key] = bool(res)
|
||||
|
||||
if not res:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def matches(
|
||||
self, condition: Mapping[str, Any], user_id: str, display_name: Optional[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Returns true if a user's condition/user ID/display name match the event.
|
||||
|
||||
Args:
|
||||
condition: The user's condition to match.
|
||||
uid: The user's MXID.
|
||||
display_name: The display name, or None if there is not one.
|
||||
|
||||
Returns:
|
||||
True if the condition matches the event, False otherwise.
|
||||
"""
|
||||
if condition["kind"] == "event_match":
|
||||
return self._event_match(condition, user_id)
|
||||
elif condition["kind"] == "contains_display_name":
|
||||
return self._contains_display_name(display_name)
|
||||
elif condition["kind"] == "room_member_count":
|
||||
return _room_member_count(self._event, condition, self._room_member_count)
|
||||
elif condition["kind"] == "sender_notification_permission":
|
||||
return _sender_notification_permission(
|
||||
self._event, condition, self._sender_power_level, self._power_levels
|
||||
)
|
||||
elif (
|
||||
condition["kind"] == "org.matrix.msc3772.relation_match"
|
||||
and self._relations_match_enabled
|
||||
):
|
||||
return self._relation_match(condition, user_id)
|
||||
else:
|
||||
# XXX This looks incorrect -- we have reached an unknown condition
|
||||
# kind and are unconditionally returning that it matches. Note
|
||||
# that it seems possible to provide a condition to the /pushrules
|
||||
# endpoint with an unknown kind, see _rule_tuple_from_request_object.
|
||||
return True
|
||||
|
||||
def _event_match(self, condition: Mapping, user_id: str) -> bool:
|
||||
"""
|
||||
Check an "event_match" push rule condition.
|
||||
|
||||
Args:
|
||||
condition: The "event_match" push rule condition to match.
|
||||
user_id: The user's MXID.
|
||||
|
||||
Returns:
|
||||
True if the condition matches the event, False otherwise.
|
||||
"""
|
||||
pattern = condition.get("pattern", None)
|
||||
|
||||
if not pattern:
|
||||
pattern_type = condition.get("pattern_type", None)
|
||||
if pattern_type == "user_id":
|
||||
pattern = user_id
|
||||
elif pattern_type == "user_localpart":
|
||||
pattern = UserID.from_string(user_id).localpart
|
||||
|
||||
if not pattern:
|
||||
logger.warning("event_match condition with no pattern")
|
||||
return False
|
||||
|
||||
# XXX: optimisation: cache our pattern regexps
|
||||
if condition["key"] == "content.body":
|
||||
body = self._event.content.get("body", None)
|
||||
if not body or not isinstance(body, str):
|
||||
return False
|
||||
|
||||
return _glob_matches(pattern, body, word_boundary=True)
|
||||
else:
|
||||
haystack = self._value_cache.get(condition["key"], None)
|
||||
if haystack is None:
|
||||
return False
|
||||
|
||||
return _glob_matches(pattern, haystack)
|
||||
|
||||
def _contains_display_name(self, display_name: Optional[str]) -> bool:
|
||||
"""
|
||||
Check an "event_match" push rule condition.
|
||||
|
||||
Args:
|
||||
display_name: The display name, or None if there is not one.
|
||||
|
||||
Returns:
|
||||
True if the display name is found in the event body, False otherwise.
|
||||
"""
|
||||
if not display_name:
|
||||
return False
|
||||
|
||||
body = self._event.content.get("body", None)
|
||||
if not body or not isinstance(body, str):
|
||||
return False
|
||||
|
||||
# Similar to _glob_matches, but do not treat display_name as a glob.
|
||||
r = regex_cache.get((display_name, False, True), None)
|
||||
if not r:
|
||||
r1 = re.escape(display_name)
|
||||
r1 = to_word_pattern(r1)
|
||||
r = re.compile(r1, flags=re.IGNORECASE)
|
||||
regex_cache[(display_name, False, True)] = r
|
||||
|
||||
return bool(r.search(body))
|
||||
|
||||
def _relation_match(self, condition: Mapping, user_id: str) -> bool:
|
||||
"""
|
||||
Check an "relation_match" push rule condition.
|
||||
|
||||
Args:
|
||||
condition: The "event_match" push rule condition to match.
|
||||
user_id: The user's MXID.
|
||||
|
||||
Returns:
|
||||
True if the condition matches the event, False otherwise.
|
||||
"""
|
||||
rel_type = condition.get("rel_type")
|
||||
if not rel_type:
|
||||
logger.warning("relation_match condition missing rel_type")
|
||||
return False
|
||||
|
||||
sender_pattern = condition.get("sender")
|
||||
if sender_pattern is None:
|
||||
sender_type = condition.get("sender_type")
|
||||
if sender_type == "user_id":
|
||||
sender_pattern = user_id
|
||||
type_pattern = condition.get("type")
|
||||
|
||||
# If any other relations matches, return True.
|
||||
for sender, event_type in self._relations.get(rel_type, ()):
|
||||
if sender_pattern and not _glob_matches(sender_pattern, sender):
|
||||
continue
|
||||
if type_pattern and not _glob_matches(type_pattern, event_type):
|
||||
continue
|
||||
# All values must have matched.
|
||||
return True
|
||||
|
||||
# No relations matched.
|
||||
return False
|
||||
|
||||
|
||||
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
|
||||
regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
|
||||
50000, "regex_push_cache"
|
||||
)
|
||||
|
||||
|
||||
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
|
||||
"""Tests if value matches glob.
|
||||
|
||||
Args:
|
||||
glob
|
||||
value: String to test against glob.
|
||||
word_boundary: Whether to match against word boundaries or entire
|
||||
string. Defaults to False.
|
||||
"""
|
||||
|
||||
try:
|
||||
r = regex_cache.get((glob, True, word_boundary), None)
|
||||
if not r:
|
||||
r = glob_to_regex(glob, word_boundary=word_boundary)
|
||||
regex_cache[(glob, True, word_boundary)] = r
|
||||
return bool(r.search(value))
|
||||
except re.error:
|
||||
logger.warning("Failed to parse glob to regex: %r", glob)
|
||||
return False
|
||||
|
||||
|
||||
def _flatten_dict(
|
||||
d: Union[EventBase, Mapping[str, Any]],
|
||||
prefix: Optional[List[str]] = None,
|
||||
result: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, str]:
|
||||
if prefix is None:
|
||||
prefix = []
|
||||
if result is None:
|
||||
result = {}
|
||||
for key, value in d.items():
|
||||
if isinstance(value, str):
|
||||
result[".".join(prefix + [key])] = value.lower()
|
||||
elif isinstance(value, Mapping):
|
||||
_flatten_dict(value, prefix=(prefix + [key]), result=result)
|
||||
|
||||
return result
|
|
@ -51,6 +51,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||
user_type: Optional[str],
|
||||
address: Optional[str],
|
||||
shadow_banned: bool,
|
||||
approved: bool,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
|
@ -68,6 +69,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||
or None for a normal user.
|
||||
address: the IP address used to perform the regitration.
|
||||
shadow_banned: Whether to shadow-ban the user
|
||||
approved: Whether the user should be considered already approved by an
|
||||
administrator.
|
||||
"""
|
||||
return {
|
||||
"password_hash": password_hash,
|
||||
|
@ -79,6 +82,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||
"user_type": user_type,
|
||||
"address": address,
|
||||
"shadow_banned": shadow_banned,
|
||||
"approved": approved,
|
||||
}
|
||||
|
||||
async def _handle_request( # type: ignore[override]
|
||||
|
@ -99,6 +103,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||
user_type=content["user_type"],
|
||||
address=content["address"],
|
||||
shadow_banned=content["shadow_banned"],
|
||||
approved=content["approved"],
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
|
|
@ -69,6 +69,7 @@ class UsersRestServletV2(RestServlet):
|
|||
self.store = hs.get_datastores().main
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
@ -95,6 +96,13 @@ class UsersRestServletV2(RestServlet):
|
|||
guests = parse_boolean(request, "guests", default=True)
|
||||
deactivated = parse_boolean(request, "deactivated", default=False)
|
||||
|
||||
# If support for MSC3866 is not enabled, apply no filtering based on the
|
||||
# `approved` column.
|
||||
if self._msc3866_enabled:
|
||||
approved = parse_boolean(request, "approved", default=True)
|
||||
else:
|
||||
approved = True
|
||||
|
||||
order_by = parse_string(
|
||||
request,
|
||||
"order_by",
|
||||
|
@ -115,8 +123,22 @@ class UsersRestServletV2(RestServlet):
|
|||
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
|
||||
|
||||
users, total = await self.store.get_users_paginate(
|
||||
start, limit, user_id, name, guests, deactivated, order_by, direction
|
||||
start,
|
||||
limit,
|
||||
user_id,
|
||||
name,
|
||||
guests,
|
||||
deactivated,
|
||||
order_by,
|
||||
direction,
|
||||
approved,
|
||||
)
|
||||
|
||||
# If support for MSC3866 is not enabled, don't show the approval flag.
|
||||
if not self._msc3866_enabled:
|
||||
for user in users:
|
||||
del user["approved"]
|
||||
|
||||
ret = {"users": users, "total": total}
|
||||
if (start + limit) < total:
|
||||
ret["next_token"] = str(start + len(users))
|
||||
|
@ -163,6 +185,7 @@ class UserRestServletV2(RestServlet):
|
|||
self.deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
self.pusher_pool = hs.get_pusherpool()
|
||||
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
|
@ -239,6 +262,15 @@ class UserRestServletV2(RestServlet):
|
|||
HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
|
||||
)
|
||||
|
||||
approved: Optional[bool] = None
|
||||
if "approved" in body and self._msc3866_enabled:
|
||||
approved = body["approved"]
|
||||
if not isinstance(approved, bool):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"'approved' parameter is not of type boolean",
|
||||
)
|
||||
|
||||
# convert List[Dict[str, str]] into List[Tuple[str, str]]
|
||||
if external_ids is not None:
|
||||
new_external_ids = [
|
||||
|
@ -343,6 +375,9 @@ class UserRestServletV2(RestServlet):
|
|||
if "user_type" in body:
|
||||
await self.store.set_user_type(target_user, user_type)
|
||||
|
||||
if approved is not None:
|
||||
await self.store.update_user_approval_status(target_user, approved)
|
||||
|
||||
user = await self.admin_handler.get_user(target_user)
|
||||
assert user is not None
|
||||
|
||||
|
@ -355,6 +390,10 @@ class UserRestServletV2(RestServlet):
|
|||
if password is not None:
|
||||
password_hash = await self.auth_handler.hash(password)
|
||||
|
||||
new_user_approved = True
|
||||
if self._msc3866_enabled and approved is not None:
|
||||
new_user_approved = approved
|
||||
|
||||
user_id = await self.registration_handler.register_user(
|
||||
localpart=target_user.localpart,
|
||||
password_hash=password_hash,
|
||||
|
@ -362,6 +401,7 @@ class UserRestServletV2(RestServlet):
|
|||
default_display_name=displayname,
|
||||
user_type=user_type,
|
||||
by_admin=True,
|
||||
approved=new_user_approved,
|
||||
)
|
||||
|
||||
if threepids is not None:
|
||||
|
@ -550,6 +590,7 @@ class UserRegisterServlet(RestServlet):
|
|||
user_type=user_type,
|
||||
default_display_name=displayname,
|
||||
by_admin=True,
|
||||
approved=True,
|
||||
)
|
||||
|
||||
result = await register._create_registration_details(user_id, body)
|
||||
|
|
|
@ -28,7 +28,14 @@ from typing import (
|
|||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError
|
||||
from synapse.api.constants import ApprovalNoticeMedium
|
||||
from synapse.api.errors import (
|
||||
Codes,
|
||||
InvalidClientTokenError,
|
||||
LoginError,
|
||||
NotApprovedError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.api.urls import CLIENT_API_PREFIX
|
||||
from synapse.appservice import ApplicationService
|
||||
|
@ -55,11 +62,11 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class LoginResponse(TypedDict, total=False):
|
||||
user_id: str
|
||||
access_token: str
|
||||
access_token: Optional[str]
|
||||
home_server: str
|
||||
expires_in_ms: Optional[int]
|
||||
refresh_token: Optional[str]
|
||||
device_id: str
|
||||
device_id: Optional[str]
|
||||
well_known: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
|
@ -92,6 +99,12 @@ class LoginRestServlet(RestServlet):
|
|||
hs.config.registration.refreshable_access_token_lifetime is not None
|
||||
)
|
||||
|
||||
# Whether we need to check if the user has been approved or not.
|
||||
self._require_approval = (
|
||||
hs.config.experimental.msc3866.enabled
|
||||
and hs.config.experimental.msc3866.require_approval_for_new_accounts
|
||||
)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
@ -220,6 +233,14 @@ class LoginRestServlet(RestServlet):
|
|||
except KeyError:
|
||||
raise SynapseError(400, "Missing JSON keys.")
|
||||
|
||||
if self._require_approval:
|
||||
approved = await self.auth_handler.is_user_approved(result["user_id"])
|
||||
if not approved:
|
||||
raise NotApprovedError(
|
||||
msg="This account is pending approval by a server administrator.",
|
||||
approval_notice_medium=ApprovalNoticeMedium.NONE,
|
||||
)
|
||||
|
||||
well_known_data = self._well_known_builder.get_well_known()
|
||||
if well_known_data:
|
||||
result["well_known"] = well_known_data
|
||||
|
@ -356,6 +377,16 @@ class LoginRestServlet(RestServlet):
|
|||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
if self._require_approval:
|
||||
approved = await self.auth_handler.is_user_approved(user_id)
|
||||
if not approved:
|
||||
# If the user isn't approved (and needs to be) we won't allow them to
|
||||
# actually log in, so we don't want to create a device/access token.
|
||||
return LoginResponse(
|
||||
user_id=user_id,
|
||||
home_server=self.hs.hostname,
|
||||
)
|
||||
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
(
|
||||
device_id,
|
||||
|
|
|
@ -47,7 +47,9 @@ class LoginTokenRequestServlet(RestServlet):
|
|||
}
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/login/token$")
|
||||
PATTERNS = client_patterns(
|
||||
"/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
|
|
|
@ -21,10 +21,15 @@ from twisted.web.server import Request
|
|||
import synapse
|
||||
import synapse.api.auth
|
||||
import synapse.types
|
||||
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
|
||||
from synapse.api.constants import (
|
||||
APP_SERVICE_REGISTRATION_TYPE,
|
||||
ApprovalNoticeMedium,
|
||||
LoginType,
|
||||
)
|
||||
from synapse.api.errors import (
|
||||
Codes,
|
||||
InteractiveAuthIncompleteError,
|
||||
NotApprovedError,
|
||||
SynapseError,
|
||||
ThreepidValidationError,
|
||||
UnrecognizedRequestError,
|
||||
|
@ -414,6 +419,11 @@ class RegisterRestServlet(RestServlet):
|
|||
hs.config.registration.inhibit_user_in_use_error
|
||||
)
|
||||
|
||||
self._require_approval = (
|
||||
hs.config.experimental.msc3866.enabled
|
||||
and hs.config.experimental.msc3866.require_approval_for_new_accounts
|
||||
)
|
||||
|
||||
self._registration_flows = _calculate_registration_flows(
|
||||
hs.config, self.auth_handler
|
||||
)
|
||||
|
@ -734,6 +744,12 @@ class RegisterRestServlet(RestServlet):
|
|||
access_token=return_dict.get("access_token"),
|
||||
)
|
||||
|
||||
if self._require_approval:
|
||||
raise NotApprovedError(
|
||||
msg="This account needs to be approved by an administrator before it can be used.",
|
||||
approval_notice_medium=ApprovalNoticeMedium.NONE,
|
||||
)
|
||||
|
||||
return 200, return_dict
|
||||
|
||||
async def _do_appservice_registration(
|
||||
|
@ -778,7 +794,9 @@ class RegisterRestServlet(RestServlet):
|
|||
"user_id": user_id,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
if not params.get("inhibit_login", False):
|
||||
# We don't want to log the user in if we're going to deny them access because
|
||||
# they need to be approved first.
|
||||
if not params.get("inhibit_login", False) and not self._require_approval:
|
||||
device_id = params.get("device_id")
|
||||
initial_display_name = params.get("initial_device_display_name")
|
||||
(
|
||||
|
|
|
@ -94,6 +94,7 @@ class SQLBaseStore(metaclass=ABCMeta):
|
|||
self._attempt_to_invalidate_cache(
|
||||
"get_rooms_for_user_with_stream_ordering", (user_id,)
|
||||
)
|
||||
self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,))
|
||||
|
||||
# Purge other caches based on room state.
|
||||
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
|
||||
|
|
|
@ -423,16 +423,18 @@ class EventsPersistenceStorageController:
|
|||
for d in ret_vals:
|
||||
replaced_events.update(d)
|
||||
|
||||
events = []
|
||||
persisted_events = []
|
||||
for event, _ in events_and_contexts:
|
||||
existing_event_id = replaced_events.get(event.event_id)
|
||||
if existing_event_id:
|
||||
events.append(await self.main_store.get_event(existing_event_id))
|
||||
persisted_events.append(
|
||||
await self.main_store.get_event(existing_event_id)
|
||||
)
|
||||
else:
|
||||
events.append(event)
|
||||
persisted_events.append(event)
|
||||
|
||||
return (
|
||||
events,
|
||||
persisted_events,
|
||||
self.main_store.get_room_max_token(),
|
||||
)
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from typing import (
|
|||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
|
@ -529,7 +529,18 @@ class StateStorageController:
|
|||
)
|
||||
return state_map.get(key)
|
||||
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||
"""Get current hosts in room based on current state.
|
||||
|
||||
Blocks until we have full state for the given room. This only happens for rooms
|
||||
with partial state.
|
||||
"""
|
||||
|
||||
await self._partial_state_room_tracker.await_full_state(room_id)
|
||||
|
||||
return await self.stores.main.get_current_hosts_in_room(room_id)
|
||||
|
||||
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
|
||||
"""Get current hosts in room based on current state.
|
||||
|
||||
Blocks until we have full state for the given room. This only happens for rooms
|
||||
|
@ -542,11 +553,11 @@ class StateStorageController:
|
|||
|
||||
await self._partial_state_room_tracker.await_full_state(room_id)
|
||||
|
||||
return await self.stores.main.get_current_hosts_in_room(room_id)
|
||||
return await self.stores.main.get_current_hosts_in_room_ordered(room_id)
|
||||
|
||||
async def get_current_hosts_in_room_or_partial_state_approximation(
|
||||
self, room_id: str
|
||||
) -> Sequence[str]:
|
||||
) -> Collection[str]:
|
||||
"""Get approximation of current hosts in room based on current state.
|
||||
|
||||
For rooms with full state, this is equivalent to `get_current_hosts_in_room`,
|
||||
|
@ -566,14 +577,9 @@ class StateStorageController:
|
|||
)
|
||||
|
||||
hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
|
||||
hosts_from_state_set = set(hosts_from_state)
|
||||
|
||||
# First take the list of hosts based on the current state.
|
||||
# For rooms with partial state, this will be missing most hosts.
|
||||
hosts = list(hosts_from_state)
|
||||
# Then add in the list of hosts in the room at the time we joined.
|
||||
# This will be an empty list for rooms with full state.
|
||||
hosts.extend(host for host in hosts_at_join if host not in hosts_from_state_set)
|
||||
hosts = set(hosts_at_join)
|
||||
hosts.update(hosts_from_state)
|
||||
|
||||
return hosts
|
||||
|
||||
|
|
|
@ -1141,17 +1141,57 @@ class DatabasePool:
|
|||
desc: str = "simple_upsert",
|
||||
lock: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
"""Insert a row with values + insertion_values; on conflict, update with values.
|
||||
|
||||
`lock` should generally be set to True (the default), but can be set
|
||||
to False if either of the following are true:
|
||||
1. there is a UNIQUE INDEX on the key columns. In this case a conflict
|
||||
will cause an IntegrityError in which case this function will retry
|
||||
the update.
|
||||
2. we somehow know that we are the only thread which will be updating
|
||||
this table.
|
||||
As an additional note, this parameter only matters for old SQLite versions
|
||||
because we will use native upserts otherwise.
|
||||
All of our supported databases accept the nonstandard "upsert" statement in
|
||||
their dialect of SQL. We call this a "native upsert". The syntax looks roughly
|
||||
like:
|
||||
|
||||
INSERT INTO table VALUES (values + insertion_values)
|
||||
ON CONFLICT (keyvalues)
|
||||
DO UPDATE SET (values); -- overwrite `values` columns only
|
||||
|
||||
If (values) is empty, the resulting query is slighlty simpler:
|
||||
|
||||
INSERT INTO table VALUES (insertion_values)
|
||||
ON CONFLICT (keyvalues)
|
||||
DO NOTHING; -- do not overwrite any columns
|
||||
|
||||
This function is a helper to build such queries.
|
||||
|
||||
In order for upserts to make sense, the database must be able to determine when
|
||||
an upsert CONFLICTs with an existing row. Postgres and SQLite ensure this by
|
||||
requiring that a unique index exist on the column names used to detect a
|
||||
conflict (i.e. `keyvalues.keys()`).
|
||||
|
||||
If there is no such index, we can "emulate" an upsert with a SELECT followed
|
||||
by either an INSERT or an UPDATE. This is unsafe: we cannot make the same
|
||||
atomicity guarantees that a native upsert can and are very vulnerable to races
|
||||
and crashes. Therefore if we wish to upsert without an appropriate unique index,
|
||||
we must either:
|
||||
|
||||
1. Acquire a table-level lock before the emulated upsert (`lock=True`), or
|
||||
2. VERY CAREFULLY ensure that we are the only thread and worker which will be
|
||||
writing to this table, in which case we can proceed without a lock
|
||||
(`lock=False`).
|
||||
|
||||
Generally speaking, you should use `lock=True`. If the table in question has a
|
||||
unique index[*], this class will use a native upsert (which is atomic and so can
|
||||
ignore the `lock` argument). Otherwise this class will use an emulated upsert,
|
||||
in which case we want the safer option unless we been VERY CAREFUL.
|
||||
|
||||
[*]: Some tables have unique indices added to them in the background. Those
|
||||
tables `T` are keys in the dictionary UNIQUE_INDEX_BACKGROUND_UPDATES,
|
||||
where `T` maps to the background update that adds a unique index to `T`.
|
||||
This dictionary is maintained by hand.
|
||||
|
||||
At runtime, we constantly check to see if each of these background updates
|
||||
has run. If so, we deem the coresponding table safe to upsert into, because
|
||||
we can now use a native insert to do so. If not, we deem the table unsafe
|
||||
to upsert into and require an emulated upsert.
|
||||
|
||||
Tables that do not appear in this dictionary are assumed to have an
|
||||
appropriate unique index and therefore be safe to upsert into.
|
||||
|
||||
Args:
|
||||
table: The table to upsert into
|
||||
|
|
|
@ -203,6 +203,7 @@ class DataStore(
|
|||
deactivated: bool = False,
|
||||
order_by: str = UserSortOrder.USER_ID.value,
|
||||
direction: str = "f",
|
||||
approved: bool = True,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
"""Function to retrieve a paginated list of users from
|
||||
users list. This will return a json list of users and the
|
||||
|
@ -217,6 +218,7 @@ class DataStore(
|
|||
deactivated: whether to include deactivated users
|
||||
order_by: the sort order of the returned list
|
||||
direction: sort ascending or descending
|
||||
approved: whether to include approved users
|
||||
Returns:
|
||||
A tuple of a list of mappings from user to information and a count of total users.
|
||||
"""
|
||||
|
@ -249,6 +251,11 @@ class DataStore(
|
|||
if not deactivated:
|
||||
filters.append("deactivated = 0")
|
||||
|
||||
if not approved:
|
||||
# We ignore NULL values for the approved flag because these should only
|
||||
# be already existing users that we consider as already approved.
|
||||
filters.append("approved IS FALSE")
|
||||
|
||||
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
|
||||
|
||||
sql_base = f"""
|
||||
|
@ -262,7 +269,7 @@ class DataStore(
|
|||
|
||||
sql = f"""
|
||||
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
|
||||
displayname, avatar_url, creation_ts * 1000 as creation_ts
|
||||
displayname, avatar_url, creation_ts * 1000 as creation_ts, approved
|
||||
{sql_base}
|
||||
ORDER BY {order_by_column} {order}, u.name ASC
|
||||
LIMIT ? OFFSET ?
|
||||
|
|
|
@ -205,6 +205,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
|||
self.get_rooms_for_user_with_stream_ordering.invalidate(
|
||||
(data.state_key,)
|
||||
)
|
||||
self.get_rooms_for_user.invalidate((data.state_key,))
|
||||
else:
|
||||
raise Exception("Unknown events stream row type %s" % (row.type,))
|
||||
|
||||
|
|
|
@ -1307,6 +1307,33 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
|
||||
return changes
|
||||
|
||||
async def get_device_list_changes_in_room(
|
||||
self, room_id: str, min_stream_id: int
|
||||
) -> Collection[Tuple[str, str]]:
|
||||
"""Get all device list changes that happened in the room since the given
|
||||
stream ID.
|
||||
|
||||
Returns:
|
||||
Collection of user ID/device ID tuples of all devices that have
|
||||
changed
|
||||
"""
|
||||
|
||||
sql = """
|
||||
SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
|
||||
WHERE room_id = ? AND stream_id > ?
|
||||
"""
|
||||
|
||||
def get_device_list_changes_in_room_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Collection[Tuple[str, str]]:
|
||||
txn.execute(sql, (room_id, min_stream_id))
|
||||
return cast(Collection[Tuple[str, str]], txn.fetchall())
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_device_list_changes_in_room",
|
||||
get_device_list_changes_in_room_txn,
|
||||
)
|
||||
|
||||
|
||||
class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||
def __init__(
|
||||
|
@ -1946,14 +1973,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
user_id: str,
|
||||
device_id: str,
|
||||
room_id: str,
|
||||
stream_id: int,
|
||||
stream_id: Optional[int],
|
||||
hosts: Collection[str],
|
||||
context: Optional[Dict[str, str]],
|
||||
) -> None:
|
||||
"""Queue the device update to be sent to the given set of hosts,
|
||||
calculated from the room ID.
|
||||
|
||||
Marks the associated row in `device_lists_changes_in_room` as handled.
|
||||
Marks the associated row in `device_lists_changes_in_room` as handled,
|
||||
if `stream_id` is provided.
|
||||
"""
|
||||
|
||||
def add_device_list_outbound_pokes_txn(
|
||||
|
@ -1969,17 +1997,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
context=context,
|
||||
)
|
||||
|
||||
self.db_pool.simple_update_txn(
|
||||
txn,
|
||||
table="device_lists_changes_in_room",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"stream_id": stream_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
updatevalues={"converted_to_destinations": True},
|
||||
)
|
||||
if stream_id:
|
||||
self.db_pool.simple_update_txn(
|
||||
txn,
|
||||
table="device_lists_changes_in_room",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"stream_id": stream_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
updatevalues={"converted_to_destinations": True},
|
||||
)
|
||||
|
||||
if not hosts:
|
||||
# If there are no hosts then we don't try and generate stream IDs.
|
||||
|
@ -1995,3 +2024,58 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
add_device_list_outbound_pokes_txn,
|
||||
stream_ids,
|
||||
)
|
||||
|
||||
async def add_remote_device_list_to_pending(
|
||||
self, user_id: str, device_id: str
|
||||
) -> None:
|
||||
"""Add a device list update to the table tracking remote device list
|
||||
updates during partial joins.
|
||||
"""
|
||||
|
||||
async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
|
||||
await self.db_pool.simple_upsert(
|
||||
table="device_lists_remote_pending",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
values={"stream_id": stream_id},
|
||||
desc="add_remote_device_list_to_pending",
|
||||
)
|
||||
|
||||
async def get_pending_remote_device_list_updates_for_room(
|
||||
self, room_id: str
|
||||
) -> Collection[Tuple[str, str]]:
|
||||
"""Get the set of remote device list updates from the pending table for
|
||||
the room.
|
||||
"""
|
||||
|
||||
min_device_stream_id = await self.db_pool.simple_select_one_onecol(
|
||||
table="partial_state_rooms",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
},
|
||||
retcol="device_lists_stream_id",
|
||||
desc="get_pending_remote_device_list_updates_for_room_device",
|
||||
)
|
||||
|
||||
sql = """
|
||||
SELECT user_id, device_id FROM device_lists_remote_pending AS d
|
||||
INNER JOIN current_state_events AS c ON
|
||||
type = 'm.room.member'
|
||||
AND state_key = user_id
|
||||
AND membership = 'join'
|
||||
WHERE
|
||||
room_id = ? AND stream_id > ?
|
||||
"""
|
||||
|
||||
def get_pending_remote_device_list_updates_for_room_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Collection[Tuple[str, str]]:
|
||||
txn.execute(sql, (room_id, min_device_stream_id))
|
||||
return cast(Collection[Tuple[str, str]], txn.fetchall())
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_pending_remote_device_list_updates_for_room",
|
||||
get_pending_remote_device_list_updates_for_room_txn,
|
||||
)
|
||||
|
|
|
@ -73,13 +73,30 @@ pdus_pruned_from_federation_queue = Counter(
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS: int = int(
|
||||
datetime.timedelta(days=7).total_seconds()
|
||||
)
|
||||
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS: int = int(
|
||||
datetime.timedelta(hours=1).total_seconds()
|
||||
# Parameters controlling exponential backoff between backfill failures.
|
||||
# After the first failure to backfill, we wait 2 hours before trying again. If the
|
||||
# second attempt fails, we wait 4 hours before trying again. If the third attempt fails,
|
||||
# we wait 8 hours before trying again, ... and so on.
|
||||
#
|
||||
# Each successive backoff period is twice as long as the last. However we cap this
|
||||
# period at a maximum of 2^8 = 256 hours: a little over 10 days. (This is the smallest
|
||||
# power of 2 which yields a maximum backoff period of at least 7 days---which was the
|
||||
# original maximum backoff period.) Even when we hit this cap, we will continue to
|
||||
# make backfill attempts once every 10 days.
|
||||
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS = 8
|
||||
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS = int(
|
||||
datetime.timedelta(hours=1).total_seconds() * 1000
|
||||
)
|
||||
|
||||
# We need a cap on the power of 2 or else the backoff period
|
||||
# 2^N * (milliseconds per hour)
|
||||
# will overflow when calcuated within the database. We ensure overflow does not occur
|
||||
# by checking that the largest backoff period fits in a 32-bit signed integer.
|
||||
_LONGEST_BACKOFF_PERIOD_MILLISECONDS = (
|
||||
2**BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS
|
||||
) * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS
|
||||
assert 0 < _LONGEST_BACKOFF_PERIOD_MILLISECONDS <= ((2**31) - 1)
|
||||
|
||||
|
||||
# All the info we need while iterating the DAG while backfilling
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
|
@ -726,17 +743,35 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
async def get_backfill_points_in_room(
|
||||
self,
|
||||
room_id: str,
|
||||
current_depth: int,
|
||||
limit: int,
|
||||
) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
Gets the oldest events(backwards extremities) in the room along with the
|
||||
approximate depth. Sorted by depth, highest to lowest (descending).
|
||||
Get the backward extremities to backfill from in the room along with the
|
||||
approximate depth.
|
||||
|
||||
Only returns events that are at a depth lower than or
|
||||
equal to the `current_depth`. Sorted by depth, highest to lowest (descending)
|
||||
so the closest events to the `current_depth` are first in the list.
|
||||
|
||||
We ignore extremities that are newer than the user's current scroll position
|
||||
(ie, those with depth greater than `current_depth`) as:
|
||||
1. we don't really care about getting events that have happened
|
||||
after our current position; and
|
||||
2. by the nature of paginating and scrolling back, we have likely
|
||||
previously tried and failed to backfill from that extremity, so
|
||||
to avoid getting "stuck" requesting the same backfill repeatedly
|
||||
we drop those extremities.
|
||||
|
||||
Args:
|
||||
room_id: Room where we want to find the oldest events
|
||||
current_depth: The depth at the user's current scrollback position
|
||||
limit: The max number of backfill points to return
|
||||
|
||||
Returns:
|
||||
List of (event_id, depth) tuples. Sorted by depth, highest to lowest
|
||||
(descending)
|
||||
(descending) so the closest events to the `current_depth` are first
|
||||
in the list.
|
||||
"""
|
||||
|
||||
def get_backfill_points_in_room_txn(
|
||||
|
@ -749,7 +784,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
# persisted in our database yet (meaning we don't know their depth
|
||||
# specifically). So we need to look for the approximate depth from
|
||||
# the events connected to the current backwards extremeties.
|
||||
sql = """
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
least_function = "LEAST"
|
||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||
least_function = "MIN"
|
||||
else:
|
||||
raise RuntimeError("Unknown database engine")
|
||||
|
||||
sql = f"""
|
||||
SELECT backward_extrem.event_id, event.depth FROM events AS event
|
||||
/**
|
||||
* Get the edge connections from the event_edges table
|
||||
|
@ -784,6 +827,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
* necessarily safe to assume that it will have been completed.
|
||||
*/
|
||||
AND edge.is_state is ? /* False */
|
||||
/**
|
||||
* We only want backwards extremities that are older than or at
|
||||
* the same position of the given `current_depth` (where older
|
||||
* means less than the given depth) because we're looking backwards
|
||||
* from the `current_depth` when backfilling.
|
||||
*
|
||||
* current_depth (ignore events that come after this, ignore 2-4)
|
||||
* |
|
||||
* ▼
|
||||
* <oldest-in-time> [0]<--[1]<--[2]<--[3]<--[4] <newest-in-time>
|
||||
*/
|
||||
AND event.depth <= ? /* current_depth */
|
||||
/**
|
||||
* Exponential back-off (up to the upper bound) so we don't retry the
|
||||
* same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
|
||||
|
@ -795,31 +850,31 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
*/
|
||||
AND (
|
||||
failed_backfill_attempt_info.event_id IS NULL
|
||||
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */)
|
||||
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + (
|
||||
(1 << {least_function}(failed_backfill_attempt_info.num_attempts, ? /* max doubling steps */))
|
||||
* ? /* step */
|
||||
)
|
||||
)
|
||||
/**
|
||||
* Sort from highest to the lowest depth. Then tie-break on
|
||||
* alphabetical order of the event_ids so we get a consistent
|
||||
* ordering which is nice when asserting things in tests.
|
||||
* Sort from highest (closest to the `current_depth`) to the lowest depth
|
||||
* because the closest are most relevant to backfill from first.
|
||||
* Then tie-break on alphabetical order of the event_ids so we get a
|
||||
* consistent ordering which is nice when asserting things in tests.
|
||||
*/
|
||||
ORDER BY event.depth DESC, backward_extrem.event_id DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
least_function = "least"
|
||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||
least_function = "min"
|
||||
else:
|
||||
raise RuntimeError("Unknown database engine")
|
||||
|
||||
txn.execute(
|
||||
sql % (least_function,),
|
||||
sql,
|
||||
(
|
||||
room_id,
|
||||
False,
|
||||
current_depth,
|
||||
self._clock.time_msec(),
|
||||
1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS,
|
||||
1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS,
|
||||
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
|
||||
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS,
|
||||
limit,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -835,24 +890,47 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
async def get_insertion_event_backward_extremities_in_room(
|
||||
self,
|
||||
room_id: str,
|
||||
current_depth: int,
|
||||
limit: int,
|
||||
) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
Get the insertion events we know about that we haven't backfilled yet
|
||||
along with the approximate depth. Sorted by depth, highest to lowest
|
||||
(descending).
|
||||
along with the approximate depth. Only returns insertion events that are
|
||||
at a depth lower than or equal to the `current_depth`. Sorted by depth,
|
||||
highest to lowest (descending) so the closest events to the
|
||||
`current_depth` are first in the list.
|
||||
|
||||
We ignore insertion events that are newer than the user's current scroll
|
||||
position (ie, those with depth greater than `current_depth`) as:
|
||||
1. we don't really care about getting events that have happened
|
||||
after our current position; and
|
||||
2. by the nature of paginating and scrolling back, we have likely
|
||||
previously tried and failed to backfill from that insertion event, so
|
||||
to avoid getting "stuck" requesting the same backfill repeatedly
|
||||
we drop those insertion event.
|
||||
|
||||
Args:
|
||||
room_id: Room where we want to find the oldest events
|
||||
current_depth: The depth at the user's current scrollback position
|
||||
limit: The max number of insertion event extremities to return
|
||||
|
||||
Returns:
|
||||
List of (event_id, depth) tuples. Sorted by depth, highest to lowest
|
||||
(descending)
|
||||
(descending) so the closest events to the `current_depth` are first
|
||||
in the list.
|
||||
"""
|
||||
|
||||
def get_insertion_event_backward_extremities_in_room_txn(
|
||||
txn: LoggingTransaction, room_id: str
|
||||
) -> List[Tuple[str, int]]:
|
||||
sql = """
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
least_function = "LEAST"
|
||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||
least_function = "MIN"
|
||||
else:
|
||||
raise RuntimeError("Unknown database engine")
|
||||
|
||||
sql = f"""
|
||||
SELECT
|
||||
insertion_event_extremity.event_id, event.depth
|
||||
/* We only want insertion events that are also marked as backwards extremities */
|
||||
|
@ -869,6 +947,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
AND failed_backfill_attempt_info.event_id = insertion_event_extremity.event_id
|
||||
WHERE
|
||||
insertion_event_extremity.room_id = ?
|
||||
/**
|
||||
* We only want extremities that are older than or at
|
||||
* the same position of the given `current_depth` (where older
|
||||
* means less than the given depth) because we're looking backwards
|
||||
* from the `current_depth` when backfilling.
|
||||
*
|
||||
* current_depth (ignore events that come after this, ignore 2-4)
|
||||
* |
|
||||
* ▼
|
||||
* <oldest-in-time> [0]<--[1]<--[2]<--[3]<--[4] <newest-in-time>
|
||||
*/
|
||||
AND event.depth <= ? /* current_depth */
|
||||
/**
|
||||
* Exponential back-off (up to the upper bound) so we don't retry the
|
||||
* same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc
|
||||
|
@ -880,30 +970,30 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
*/
|
||||
AND (
|
||||
failed_backfill_attempt_info.event_id IS NULL
|
||||
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */)
|
||||
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + (
|
||||
(1 << {least_function}(failed_backfill_attempt_info.num_attempts, ? /* max doubling steps */))
|
||||
* ? /* step */
|
||||
)
|
||||
)
|
||||
/**
|
||||
* Sort from highest to the lowest depth. Then tie-break on
|
||||
* alphabetical order of the event_ids so we get a consistent
|
||||
* ordering which is nice when asserting things in tests.
|
||||
* Sort from highest (closest to the `current_depth`) to the lowest depth
|
||||
* because the closest are most relevant to backfill from first.
|
||||
* Then tie-break on alphabetical order of the event_ids so we get a
|
||||
* consistent ordering which is nice when asserting things in tests.
|
||||
*/
|
||||
ORDER BY event.depth DESC, insertion_event_extremity.event_id DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
least_function = "least"
|
||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||
least_function = "min"
|
||||
else:
|
||||
raise RuntimeError("Unknown database engine")
|
||||
|
||||
txn.execute(
|
||||
sql % (least_function,),
|
||||
sql,
|
||||
(
|
||||
room_id,
|
||||
current_depth,
|
||||
self._clock.time_msec(),
|
||||
1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS,
|
||||
1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS,
|
||||
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
|
||||
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS,
|
||||
limit,
|
||||
),
|
||||
)
|
||||
return cast(List[Tuple[str, int]], txn.fetchall())
|
||||
|
|
|
@ -366,14 +366,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
|||
user_id: str,
|
||||
) -> NotifCounts:
|
||||
# Get the stream ordering of the user's latest receipt in the room.
|
||||
result = self.get_last_receipt_for_user_txn(
|
||||
result = self.get_last_unthreaded_receipt_for_user_txn(
|
||||
txn,
|
||||
user_id,
|
||||
room_id,
|
||||
receipt_types=(
|
||||
ReceiptTypes.READ,
|
||||
ReceiptTypes.READ_PRIVATE,
|
||||
),
|
||||
receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
|
||||
)
|
||||
|
||||
if result:
|
||||
|
@ -574,10 +571,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
|||
receipt_types_clause, args = make_in_list_sql_clause(
|
||||
self.database_engine,
|
||||
"receipt_type",
|
||||
(
|
||||
ReceiptTypes.READ,
|
||||
ReceiptTypes.READ_PRIVATE,
|
||||
),
|
||||
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
|
||||
)
|
||||
|
||||
sql = f"""
|
||||
|
@ -1074,7 +1068,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
|||
limit,
|
||||
),
|
||||
)
|
||||
rows = txn.fetchall()
|
||||
rows = cast(List[Tuple[int, str, str, int]], txn.fetchall())
|
||||
|
||||
# For each new read receipt we delete push actions from before it and
|
||||
# recalculate the summary.
|
||||
|
@ -1119,18 +1113,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
|||
# We always update `event_push_summary_last_receipt_stream_id` to
|
||||
# ensure that we don't rescan the same receipts for remote users.
|
||||
|
||||
upper_limit = max_receipts_stream_id
|
||||
receipts_last_processed_stream_id = max_receipts_stream_id
|
||||
if len(rows) >= limit:
|
||||
# If we pulled out a limited number of rows we only update the
|
||||
# position to the last receipt we processed, so we continue
|
||||
# processing the rest next iteration.
|
||||
upper_limit = rows[-1][0]
|
||||
receipts_last_processed_stream_id = rows[-1][0]
|
||||
|
||||
self.db_pool.simple_update_txn(
|
||||
txn,
|
||||
table="event_push_summary_last_receipt_stream_id",
|
||||
keyvalues={},
|
||||
updatevalues={"stream_id": upper_limit},
|
||||
updatevalues={"stream_id": receipts_last_processed_stream_id},
|
||||
)
|
||||
|
||||
return len(rows) < limit
|
||||
|
|
|
@ -2134,13 +2134,13 @@ class PersistEventsStore:
|
|||
appear in events_and_context.
|
||||
"""
|
||||
|
||||
# Only non outlier events will have push actions associated with them,
|
||||
# Only notifiable events will have push actions associated with them,
|
||||
# so let's filter them out. (This makes joining large rooms faster, as
|
||||
# these queries took seconds to process all the state events).
|
||||
non_outlier_events = [
|
||||
notifiable_events = [
|
||||
event
|
||||
for event, _ in events_and_contexts
|
||||
if not event.internal_metadata.is_outlier()
|
||||
if event.internal_metadata.is_notifiable()
|
||||
]
|
||||
|
||||
sql = """
|
||||
|
@ -2153,7 +2153,7 @@ class PersistEventsStore:
|
|||
WHERE event_id = ?
|
||||
"""
|
||||
|
||||
if non_outlier_events:
|
||||
if notifiable_events:
|
||||
txn.execute_batch(
|
||||
sql,
|
||||
(
|
||||
|
@ -2163,7 +2163,7 @@ class PersistEventsStore:
|
|||
event.depth,
|
||||
event.event_id,
|
||||
)
|
||||
for event in non_outlier_events
|
||||
for event in notifiable_events
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
@ -135,34 +135,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
"""Get the current max stream ID for receipts stream"""
|
||||
return self._receipts_id_gen.get_current_token()
|
||||
|
||||
async def get_last_receipt_event_id_for_user(
|
||||
self, user_id: str, room_id: str, receipt_types: Collection[str]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Fetch the event ID for the latest receipt in a room with one of the given receipt types.
|
||||
|
||||
Args:
|
||||
user_id: The user to fetch receipts for.
|
||||
room_id: The room ID to fetch the receipt for.
|
||||
receipt_type: The receipt types to fetch.
|
||||
|
||||
Returns:
|
||||
The latest receipt, if one exists.
|
||||
"""
|
||||
result = await self.db_pool.runInteraction(
|
||||
"get_last_receipt_event_id_for_user",
|
||||
self.get_last_receipt_for_user_txn,
|
||||
user_id,
|
||||
room_id,
|
||||
receipt_types,
|
||||
)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
event_id, _ = result
|
||||
return event_id
|
||||
|
||||
def get_last_receipt_for_user_txn(
|
||||
def get_last_unthreaded_receipt_for_user_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
|
@ -170,13 +143,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
receipt_types: Collection[str],
|
||||
) -> Optional[Tuple[str, int]]:
|
||||
"""
|
||||
Fetch the event ID and stream_ordering for the latest receipt in a room
|
||||
with one of the given receipt types.
|
||||
Fetch the event ID and stream_ordering for the latest unthreaded receipt
|
||||
in a room with one of the given receipt types.
|
||||
|
||||
Args:
|
||||
user_id: The user to fetch receipts for.
|
||||
room_id: The room ID to fetch the receipt for.
|
||||
receipt_type: The receipt types to fetch.
|
||||
receipt_types: The receipt types to fetch.
|
||||
|
||||
Returns:
|
||||
The event ID and stream ordering of the latest receipt, if one exists.
|
||||
|
@ -193,6 +166,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
WHERE {clause}
|
||||
AND user_id = ?
|
||||
AND room_id = ?
|
||||
AND thread_id IS NULL
|
||||
ORDER BY stream_ordering DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
|
|
@ -166,27 +166,49 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
@cached()
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Deprecated: use get_userinfo_by_id instead"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
retcols=[
|
||||
"name",
|
||||
"password_hash",
|
||||
"is_guest",
|
||||
"admin",
|
||||
"consent_version",
|
||||
"consent_ts",
|
||||
"consent_server_notice_sent",
|
||||
"appservice_id",
|
||||
"creation_ts",
|
||||
"user_type",
|
||||
"deactivated",
|
||||
"shadow_banned",
|
||||
],
|
||||
allow_none=True,
|
||||
|
||||
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
|
||||
# We could technically use simple_select_one here, but it would not perform
|
||||
# the COALESCEs (unless hacked into the column names), which could yield
|
||||
# confusing results.
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT
|
||||
name, password_hash, is_guest, admin, consent_version, consent_ts,
|
||||
consent_server_notice_sent, appservice_id, creation_ts, user_type,
|
||||
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
|
||||
COALESCE(approved, TRUE) AS approved
|
||||
FROM users
|
||||
WHERE name = ?
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
if len(rows) == 0:
|
||||
return None
|
||||
|
||||
return rows[0]
|
||||
|
||||
row = await self.db_pool.runInteraction(
|
||||
desc="get_user_by_id",
|
||||
func=get_user_by_id_txn,
|
||||
)
|
||||
|
||||
if row is not None:
|
||||
# If we're using SQLite our boolean values will be integers. Because we
|
||||
# present some of this data as is to e.g. server admins via REST APIs, we
|
||||
# want to make sure we're returning the right type of data.
|
||||
# Note: when adding a column name to this list, be wary of NULLable columns,
|
||||
# since NULL values will be turned into False.
|
||||
boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
|
||||
for column in boolean_columns:
|
||||
if not isinstance(row[column], bool):
|
||||
row[column] = bool(row[column])
|
||||
|
||||
return row
|
||||
|
||||
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
|
||||
"""Get a UserInfo object for a user by user ID.
|
||||
|
||||
|
@ -1779,6 +1801,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
return res if res else False
|
||||
|
||||
@cached()
|
||||
async def is_user_approved(self, user_id: str) -> bool:
|
||||
"""Checks if a user is approved and therefore can be allowed to log in.
|
||||
|
||||
If the user's 'approved' column is NULL, we consider it as true given it means
|
||||
the user was registered when support for an approval flow was either disabled
|
||||
or nonexistent.
|
||||
|
||||
Args:
|
||||
user_id: the user to check the approval status of.
|
||||
|
||||
Returns:
|
||||
A boolean that is True if the user is approved, False otherwise.
|
||||
"""
|
||||
|
||||
def is_user_approved_txn(txn: LoggingTransaction) -> bool:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT COALESCE(approved, TRUE) AS approved FROM users WHERE name = ?
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
# We cast to bool because the value returned by the database engine might
|
||||
# be an integer if we're using SQLite.
|
||||
return bool(rows[0]["approved"])
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
desc="is_user_pending_approval",
|
||||
func=is_user_approved_txn,
|
||||
)
|
||||
|
||||
|
||||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||
def __init__(
|
||||
|
@ -1916,6 +1972,29 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||
|
||||
def update_user_approval_status_txn(
|
||||
self, txn: LoggingTransaction, user_id: str, approved: bool
|
||||
) -> None:
|
||||
"""Set the user's 'approved' flag to the given value.
|
||||
|
||||
The boolean is turned into an int because the column is a smallint.
|
||||
|
||||
Args:
|
||||
txn: the current database transaction.
|
||||
user_id: the user to update the flag for.
|
||||
approved: the value to set the flag to.
|
||||
"""
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn=txn,
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
updatevalues={"approved": approved},
|
||||
)
|
||||
|
||||
# Invalidate the caches of methods that read the value of the 'approved' flag.
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
|
||||
|
||||
|
||||
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||
def __init__(
|
||||
|
@ -1933,6 +2012,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
||||
|
||||
# If support for MSC3866 is enabled and configured to require approval for new
|
||||
# account, we will create new users with an 'approved' flag set to false.
|
||||
self._require_approval = (
|
||||
hs.config.experimental.msc3866.enabled
|
||||
and hs.config.experimental.msc3866.require_approval_for_new_accounts
|
||||
)
|
||||
|
||||
async def add_access_token_to_user(
|
||||
self,
|
||||
user_id: str,
|
||||
|
@ -2065,6 +2151,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
admin: bool = False,
|
||||
user_type: Optional[str] = None,
|
||||
shadow_banned: bool = False,
|
||||
approved: bool = False,
|
||||
) -> None:
|
||||
"""Attempts to register an account.
|
||||
|
||||
|
@ -2083,6 +2170,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
or None for a normal user.
|
||||
shadow_banned: Whether the user is shadow-banned, i.e. they may be
|
||||
told their requests succeeded but we ignore them.
|
||||
approved: Whether to consider the user has already been approved by an
|
||||
administrator.
|
||||
|
||||
Raises:
|
||||
StoreError if the user_id could not be registered.
|
||||
|
@ -2099,6 +2188,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
admin,
|
||||
user_type,
|
||||
shadow_banned,
|
||||
approved,
|
||||
)
|
||||
|
||||
def _register_user(
|
||||
|
@ -2113,11 +2203,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
admin: bool,
|
||||
user_type: Optional[str],
|
||||
shadow_banned: bool,
|
||||
approved: bool,
|
||||
) -> None:
|
||||
user_id_obj = UserID.from_string(user_id)
|
||||
|
||||
now = int(self._clock.time())
|
||||
|
||||
user_approved = approved or not self._require_approval
|
||||
|
||||
try:
|
||||
if was_guest:
|
||||
# Ensure that the guest user actually exists
|
||||
|
@ -2143,6 +2236,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
"admin": 1 if admin else 0,
|
||||
"user_type": user_type,
|
||||
"shadow_banned": shadow_banned,
|
||||
"approved": user_approved,
|
||||
},
|
||||
)
|
||||
else:
|
||||
|
@ -2158,6 +2252,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
"admin": 1 if admin else 0,
|
||||
"user_type": user_type,
|
||||
"shadow_banned": shadow_banned,
|
||||
"approved": user_approved,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -2503,6 +2598,25 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
start_or_continue_validation_session_txn,
|
||||
)
|
||||
|
||||
async def update_user_approval_status(
|
||||
self, user_id: UserID, approved: bool
|
||||
) -> None:
|
||||
"""Set the user's 'approved' flag to the given value.
|
||||
|
||||
The boolean will be turned into an int (in update_user_approval_status_txn)
|
||||
because the column is a smallint.
|
||||
|
||||
Args:
|
||||
user_id: the user to update the flag for.
|
||||
approved: the value to set the flag to.
|
||||
"""
|
||||
await self.db_pool.runInteraction(
|
||||
"update_user_approval_status",
|
||||
self.update_user_approval_status_txn,
|
||||
user_id.to_string(),
|
||||
approved,
|
||||
)
|
||||
|
||||
|
||||
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
|
||||
"""
|
||||
|
|
|
@ -1217,6 +1217,26 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
)
|
||||
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
|
||||
|
||||
# We now delete anything from `device_lists_remote_pending` with a
|
||||
# stream ID less than the minimum
|
||||
# `partial_state_rooms.device_lists_stream_id`, as we no longer need them.
|
||||
device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="partial_state_rooms",
|
||||
keyvalues={},
|
||||
retcol="MIN(device_lists_stream_id)",
|
||||
allow_none=True,
|
||||
)
|
||||
if device_lists_stream_id is None:
|
||||
# There are no rooms being currently partially joined, so we delete everything.
|
||||
txn.execute("DELETE FROM device_lists_remote_pending")
|
||||
else:
|
||||
sql = """
|
||||
DELETE FROM device_lists_remote_pending
|
||||
WHERE stream_id <= ?
|
||||
"""
|
||||
txn.execute(sql, (device_lists_stream_id,))
|
||||
|
||||
@cached()
|
||||
async def is_partial_state_room(self, room_id: str) -> bool:
|
||||
"""Checks if this room has partial state.
|
||||
|
@ -1236,6 +1256,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
return entry is not None
|
||||
|
||||
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
|
||||
self, room_id: str
|
||||
) -> Tuple[str, int]:
|
||||
"""Get the event ID of the initial join that started the partial
|
||||
join, and the device list stream ID at the point we started the partial
|
||||
join.
|
||||
"""
|
||||
|
||||
result = await self.db_pool.simple_select_one(
|
||||
table="partial_state_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("join_event_id", "device_lists_stream_id"),
|
||||
desc="get_join_event_id_for_partial_state",
|
||||
)
|
||||
return result["join_event_id"], result["device_lists_stream_id"]
|
||||
|
||||
|
||||
class _BackgroundUpdates:
|
||||
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
|
@ -52,7 +51,6 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
|
|||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
||||
from synapse.util.cancellation import cancellable
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
|
@ -148,42 +146,37 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
@cached(max_entries=100000, iterable=True)
|
||||
async def get_users_in_room(self, room_id: str) -> List[str]:
|
||||
"""
|
||||
Returns a list of users in the room sorted by longest in the room first
|
||||
(aka. with the lowest depth). This is done to match the sort in
|
||||
`get_current_hosts_in_room()` and so we can re-use the cache but it's
|
||||
not horrible to have here either.
|
||||
|
||||
Uses `m.room.member`s in the room state at the current forward extremities to
|
||||
determine which users are in the room.
|
||||
"""Returns a list of users in the room.
|
||||
|
||||
Will return inaccurate results for rooms with partial state, since the state for
|
||||
the forward extremities of those rooms will exclude most members. We may also
|
||||
calculate room state incorrectly for such rooms and believe that a member is or
|
||||
is not in the room when the opposite is true.
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_in_room", self.get_users_in_room_txn, room_id
|
||||
return await self.db_pool.simple_select_onecol(
|
||||
table="current_state_events",
|
||||
keyvalues={
|
||||
"type": EventTypes.Member,
|
||||
"room_id": room_id,
|
||||
"membership": Membership.JOIN,
|
||||
},
|
||||
retcol="state_key",
|
||||
desc="get_users_in_room",
|
||||
)
|
||||
|
||||
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
|
||||
"""
|
||||
Returns a list of users in the room sorted by longest in the room first
|
||||
(aka. with the lowest depth). This is done to match the sort in
|
||||
`get_current_hosts_in_room()` and so we can re-use the cache but it's
|
||||
not horrible to have here either.
|
||||
"""
|
||||
sql = """
|
||||
SELECT c.state_key FROM current_state_events as c
|
||||
/* Get the depth of the event from the events table */
|
||||
INNER JOIN events AS e USING (event_id)
|
||||
WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ?
|
||||
/* Sorted by lowest depth first */
|
||||
ORDER BY e.depth ASC;
|
||||
"""
|
||||
"""Returns a list of users in the room."""
|
||||
|
||||
txn.execute(sql, (room_id, Membership.JOIN))
|
||||
return [r[0] for r in txn]
|
||||
return self.db_pool.simple_select_onecol_txn(
|
||||
txn,
|
||||
table="current_state_events",
|
||||
keyvalues={
|
||||
"type": EventTypes.Member,
|
||||
"room_id": room_id,
|
||||
"membership": Membership.JOIN,
|
||||
},
|
||||
retcol="state_key",
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_user_in_room_with_profile(
|
||||
|
@ -600,58 +593,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
for room_id, instance, stream_id in txn
|
||||
)
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_rooms_for_user_with_stream_ordering",
|
||||
list_name="user_ids",
|
||||
)
|
||||
async def get_rooms_for_users_with_stream_ordering(
|
||||
self, user_ids: Collection[str]
|
||||
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
|
||||
"""A batched version of `get_rooms_for_user_with_stream_ordering`.
|
||||
|
||||
Returns:
|
||||
Map from user_id to set of rooms that is currently in.
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_rooms_for_users_with_stream_ordering",
|
||||
self._get_rooms_for_users_with_stream_ordering_txn,
|
||||
user_ids,
|
||||
)
|
||||
|
||||
def _get_rooms_for_users_with_stream_ordering_txn(
|
||||
self, txn: LoggingTransaction, user_ids: Collection[str]
|
||||
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
|
||||
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine,
|
||||
"c.state_key",
|
||||
user_ids,
|
||||
)
|
||||
|
||||
sql = f"""
|
||||
SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
|
||||
FROM current_state_events AS c
|
||||
INNER JOIN events AS e USING (room_id, event_id)
|
||||
WHERE
|
||||
c.type = 'm.room.member'
|
||||
AND c.membership = ?
|
||||
AND {clause}
|
||||
"""
|
||||
|
||||
txn.execute(sql, [Membership.JOIN] + args)
|
||||
|
||||
result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
|
||||
user_id: set() for user_id in user_ids
|
||||
}
|
||||
for user_id, room_id, instance, stream_id in txn:
|
||||
result[user_id].add(
|
||||
GetRoomsForUserWithStreamOrdering(
|
||||
room_id, PersistedEventPosition(instance, stream_id)
|
||||
)
|
||||
)
|
||||
|
||||
return {user_id: frozenset(v) for user_id, v in result.items()}
|
||||
|
||||
async def get_users_server_still_shares_room_with(
|
||||
self, user_ids: Collection[str]
|
||||
) -> Set[str]:
|
||||
|
@ -693,19 +634,68 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
return {row[0] for row in txn}
|
||||
|
||||
@cancellable
|
||||
async def get_rooms_for_user(
|
||||
self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
|
||||
) -> FrozenSet[str]:
|
||||
@cached(max_entries=500000, iterable=True)
|
||||
async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]:
|
||||
"""Returns a set of room_ids the user is currently joined to.
|
||||
|
||||
If a remote user only returns rooms this server is currently
|
||||
participating in.
|
||||
"""
|
||||
rooms = await self.get_rooms_for_user_with_stream_ordering(
|
||||
user_id, on_invalidate=on_invalidate
|
||||
rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate(
|
||||
(user_id,),
|
||||
None,
|
||||
update_metrics=False,
|
||||
)
|
||||
return frozenset(r.room_id for r in rooms)
|
||||
if rooms:
|
||||
return frozenset(r.room_id for r in rooms)
|
||||
|
||||
room_ids = await self.db_pool.simple_select_onecol(
|
||||
table="current_state_events",
|
||||
keyvalues={
|
||||
"type": EventTypes.Member,
|
||||
"membership": Membership.JOIN,
|
||||
"state_key": user_id,
|
||||
},
|
||||
retcol="room_id",
|
||||
desc="get_rooms_for_user",
|
||||
)
|
||||
|
||||
return frozenset(room_ids)
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_rooms_for_user",
|
||||
list_name="user_ids",
|
||||
)
|
||||
async def get_rooms_for_users(
|
||||
self, user_ids: Collection[str]
|
||||
) -> Dict[str, FrozenSet[str]]:
|
||||
"""A batched version of `get_rooms_for_user`.
|
||||
|
||||
Returns:
|
||||
Map from user_id to set of rooms that is currently in.
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="current_state_events",
|
||||
column="state_key",
|
||||
iterable=user_ids,
|
||||
retcols=(
|
||||
"state_key",
|
||||
"room_id",
|
||||
),
|
||||
keyvalues={
|
||||
"type": EventTypes.Member,
|
||||
"membership": Membership.JOIN,
|
||||
},
|
||||
desc="get_rooms_for_users",
|
||||
)
|
||||
|
||||
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
|
||||
|
||||
for row in rows:
|
||||
user_rooms[row["state_key"]].add(row["room_id"])
|
||||
|
||||
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
|
||||
|
||||
@cached(max_entries=10000)
|
||||
async def does_pair_of_users_share_a_room(
|
||||
|
@ -936,7 +926,44 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
return True
|
||||
|
||||
@cached(iterable=True, max_entries=10000)
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||
"""Get current hosts in room based on current state."""
|
||||
|
||||
# First we check if we already have `get_users_in_room` in the cache, as
|
||||
# we can just calculate result from that
|
||||
users = self.get_users_in_room.cache.get_immediate(
|
||||
(room_id,), None, update_metrics=False
|
||||
)
|
||||
if users is not None:
|
||||
return {get_domain_from_id(u) for u in users}
|
||||
|
||||
if isinstance(self.database_engine, Sqlite3Engine):
|
||||
# If we're using SQLite then let's just always use
|
||||
# `get_users_in_room` rather than funky SQL.
|
||||
users = await self.get_users_in_room(room_id)
|
||||
return {get_domain_from_id(u) for u in users}
|
||||
|
||||
# For PostgreSQL we can use a regex to pull out the domains from the
|
||||
# joined users in `current_state_events` via regex.
|
||||
|
||||
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
|
||||
sql = """
|
||||
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
|
||||
FROM current_state_events
|
||||
WHERE
|
||||
type = 'm.room.member'
|
||||
AND membership = 'join'
|
||||
AND room_id = ?
|
||||
"""
|
||||
txn.execute(sql, (room_id,))
|
||||
return {d for d, in txn}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_current_hosts_in_room", get_current_hosts_in_room_txn
|
||||
)
|
||||
|
||||
@cached(iterable=True, max_entries=10000)
|
||||
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
|
||||
"""
|
||||
Get current hosts in room based on current state.
|
||||
|
||||
|
@ -944,48 +971,33 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
longest is good because they're most likely to have anything we ask
|
||||
about.
|
||||
|
||||
Uses `m.room.member`s in the room state at the current forward extremities to
|
||||
determine which hosts are in the room.
|
||||
For SQLite the returned list is not ordered, as SQLite doesn't support
|
||||
the appropriate SQL.
|
||||
|
||||
Will return inaccurate results for rooms with partial state, since the state for
|
||||
the forward extremities of those rooms will exclude most members. We may also
|
||||
calculate room state incorrectly for such rooms and believe that a host is or
|
||||
is not in the room when the opposite is true.
|
||||
Uses `m.room.member`s in the room state at the current forward
|
||||
extremities to determine which hosts are in the room.
|
||||
|
||||
Will return inaccurate results for rooms with partial state, since the
|
||||
state for the forward extremities of those rooms will exclude most
|
||||
members. We may also calculate room state incorrectly for such rooms and
|
||||
believe that a host is or is not in the room when the opposite is true.
|
||||
|
||||
Returns:
|
||||
Returns a list of servers sorted by longest in the room first. (aka.
|
||||
sorted by join with the lowest depth first).
|
||||
"""
|
||||
|
||||
# First we check if we already have `get_users_in_room` in the cache, as
|
||||
# we can just calculate result from that
|
||||
users = self.get_users_in_room.cache.get_immediate(
|
||||
(room_id,), None, update_metrics=False
|
||||
)
|
||||
if users is None and isinstance(self.database_engine, Sqlite3Engine):
|
||||
if isinstance(self.database_engine, Sqlite3Engine):
|
||||
# If we're using SQLite then let's just always use
|
||||
# `get_users_in_room` rather than funky SQL.
|
||||
users = await self.get_users_in_room(room_id)
|
||||
|
||||
if users is not None:
|
||||
# Because `users` is sorted from lowest -> highest depth, the list
|
||||
# of domains will also be sorted that way.
|
||||
domains: List[str] = []
|
||||
# We use a `Set` just for fast lookups
|
||||
domain_set: Set[str] = set()
|
||||
for u in users:
|
||||
if ":" not in u:
|
||||
continue
|
||||
domain = get_domain_from_id(u)
|
||||
if domain not in domain_set:
|
||||
domain_set.add(domain)
|
||||
domains.append(domain)
|
||||
return domains
|
||||
domains = await self.get_current_hosts_in_room(room_id)
|
||||
return list(domains)
|
||||
|
||||
# For PostgreSQL we can use a regex to pull out the domains from the
|
||||
# joined users in `current_state_events` via regex.
|
||||
|
||||
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]:
|
||||
def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
|
||||
# Returns a list of servers currently joined in the room sorted by
|
||||
# longest in the room first (aka. with the lowest depth). The
|
||||
# heuristic of sorting by servers who have been in the room the
|
||||
|
@ -1013,7 +1025,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
return [d for d, in txn if d is not None]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_current_hosts_in_room", get_current_hosts_in_room_txn
|
||||
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
|
||||
)
|
||||
|
||||
async def get_joined_hosts(
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Add a column to the users table to track whether the user needs to be approved by an
|
||||
-- administrator.
|
||||
-- A NULL column means the user was created before this feature was supported by Synapse,
|
||||
-- and should be considered as TRUE.
|
||||
ALTER TABLE users ADD COLUMN approved BOOLEAN;
|
|
@ -0,0 +1,28 @@
|
|||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Stores remote device lists we have received for remote users while a partial
|
||||
-- join is in progress.
|
||||
--
|
||||
-- This allows us to replay any device list updates if it turns out the remote
|
||||
-- user was in the partially joined room
|
||||
CREATE TABLE device_lists_remote_pending(
|
||||
stream_id BIGINT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- We only keep the most recent update for a given user/device pair.
|
||||
CREATE UNIQUE INDEX device_lists_remote_pending_user_device_id ON device_lists_remote_pending(user_id, device_id);
|
|
@ -66,6 +66,21 @@ def _is_dev_dependency(req: Requirement) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def _should_ignore_runtime_requirement(req: Requirement) -> bool:
|
||||
# This is a build-time dependency. Irritatingly, `poetry build` ignores the
|
||||
# requirements listed in the [build-system] section of pyproject.toml, so in order
|
||||
# to support `poetry install --no-dev` we have to mark it as a runtime dependency.
|
||||
# See discussion on https://github.com/python-poetry/poetry/issues/6154 (it sounds
|
||||
# like the poetry authors don't consider this a bug?)
|
||||
#
|
||||
# In any case, workaround this by ignoring setuptools_rust here. (It might be
|
||||
# slightly cleaner to put `setuptools_rust` in a `build` extra or similar, but for
|
||||
# now let's do something quick and dirty.
|
||||
if req.name == "setuptools_rust":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class Dependency(NamedTuple):
|
||||
requirement: Requirement
|
||||
must_be_installed: bool
|
||||
|
@ -77,7 +92,7 @@ def _generic_dependencies() -> Iterable[Dependency]:
|
|||
assert requirements is not None
|
||||
for raw_requirement in requirements:
|
||||
req = Requirement(raw_requirement)
|
||||
if _is_dev_dependency(req):
|
||||
if _is_dev_dependency(req) or _should_ignore_runtime_requirement(req):
|
||||
continue
|
||||
|
||||
# https://packaging.pypa.io/en/latest/markers.html#usage notes that
|
||||
|
|
|
@ -159,6 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
|||
|
||||
# Blow away caches (supported room versions can only change due to a restart).
|
||||
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
|
||||
self.store.get_rooms_for_user.invalidate_all()
|
||||
self.get_success(self.store._get_event_cache.clear())
|
||||
self.store._event_ref.clear()
|
||||
|
||||
|
|
|
@ -19,16 +19,18 @@ import frozendict
|
|||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.push import push_rule_evaluator
|
||||
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
from synapse.push.bulk_push_rule_evaluator import _flatten_dict
|
||||
from synapse.push.httppusher import tweaks_for_actions
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, register, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main.appservice import _make_exclusive_regex
|
||||
from synapse.types import JsonDict
|
||||
from synapse.synapse_rust.push import PushRuleEvaluator
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
@ -41,7 +43,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
|||
content: JsonDict,
|
||||
relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
|
||||
relations_match_enabled: bool = False,
|
||||
) -> PushRuleEvaluatorForEvent:
|
||||
) -> PushRuleEvaluator:
|
||||
event = FrozenEvent(
|
||||
{
|
||||
"event_id": "$event_id",
|
||||
|
@ -56,12 +58,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
|||
room_member_count = 0
|
||||
sender_power_level = 0
|
||||
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
|
||||
return PushRuleEvaluatorForEvent(
|
||||
event,
|
||||
return PushRuleEvaluator(
|
||||
_flatten_dict(event),
|
||||
room_member_count,
|
||||
sender_power_level,
|
||||
power_levels,
|
||||
relations or set(),
|
||||
power_levels.get("notifications", {}),
|
||||
relations or {},
|
||||
relations_match_enabled,
|
||||
)
|
||||
|
||||
|
@ -293,7 +295,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
|||
]
|
||||
|
||||
self.assertEqual(
|
||||
push_rule_evaluator.tweaks_for_actions(actions),
|
||||
tweaks_for_actions(actions),
|
||||
{"sound": "default", "highlight": True},
|
||||
)
|
||||
|
||||
|
@ -304,9 +306,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
|||
evaluator = self._get_evaluator(
|
||||
{}, {"m.annotation": {("@user:test", "m.reaction")}}
|
||||
)
|
||||
condition = {"kind": "relation_match"}
|
||||
# Oddly, an unknown condition always matches.
|
||||
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
|
||||
|
||||
# A push rule evaluator with the experimental rule enabled.
|
||||
evaluator = self._get_evaluator(
|
||||
|
@ -439,3 +438,80 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
self.assertEqual(len(users_with_push_actions), 0)
|
||||
|
||||
|
||||
class BulkPushRuleEvaluatorTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
) -> None:
|
||||
self.main_store = homeserver.get_datastores().main
|
||||
|
||||
self.user_id1 = self.register_user("user1", "password")
|
||||
self.tok1 = self.login(self.user_id1, "password")
|
||||
self.user_id2 = self.register_user("user2", "password")
|
||||
self.tok2 = self.login(self.user_id2, "password")
|
||||
|
||||
self.room_id = self.helper.create_room_as(tok=self.tok1)
|
||||
|
||||
# We want to test history visibility works correctly.
|
||||
self.helper.send_state(
|
||||
self.room_id,
|
||||
EventTypes.RoomHistoryVisibility,
|
||||
{"history_visibility": HistoryVisibility.JOINED},
|
||||
tok=self.tok1,
|
||||
)
|
||||
|
||||
def get_notif_count(self, user_id: str) -> int:
|
||||
return self.get_success(
|
||||
self.main_store.db_pool.simple_select_one_onecol(
|
||||
table="event_push_actions",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="COALESCE(SUM(notif), 0)",
|
||||
desc="get_staging_notif_count",
|
||||
)
|
||||
)
|
||||
|
||||
def test_plain_message(self) -> None:
|
||||
"""Test that sending a normal message in a room will trigger a
|
||||
notification
|
||||
"""
|
||||
|
||||
# Have user2 join the room and cle
|
||||
self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
|
||||
|
||||
# They start off with no notifications, but get them when messages are
|
||||
# sent.
|
||||
self.assertEqual(self.get_notif_count(self.user_id2), 0)
|
||||
|
||||
user1 = UserID.from_string(self.user_id1)
|
||||
self.create_and_send_event(self.room_id, user1)
|
||||
|
||||
self.assertEqual(self.get_notif_count(self.user_id2), 1)
|
||||
|
||||
def test_delayed_message(self) -> None:
|
||||
"""Test that a delayed message that was from before a user joined
|
||||
doesn't cause a notification for the joined user.
|
||||
"""
|
||||
user1 = UserID.from_string(self.user_id1)
|
||||
|
||||
# Send a message before user2 joins
|
||||
event_id1 = self.create_and_send_event(self.room_id, user1)
|
||||
|
||||
# Have user2 join the room
|
||||
self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
|
||||
|
||||
# They start off with no notifications
|
||||
self.assertEqual(self.get_notif_count(self.user_id2), 0)
|
||||
|
||||
# Send another message that references the event before the join to
|
||||
# simulate a "delayed" event
|
||||
self.create_and_send_event(self.room_id, user1, prev_event_ids=[event_id1])
|
||||
|
||||
# user2 should not be notified about it, because they can't see it.
|
||||
self.assertEqual(self.get_notif_count(self.user_id2), 0)
|
||||
|
|
|
@ -25,10 +25,10 @@ from parameterized import parameterized, parameterized_class
|
|||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
|
||||
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.rest.client import devices, login, logout, profile, room, sync
|
||||
from synapse.rest.client import devices, login, logout, profile, register, room, sync
|
||||
from synapse.rest.media.v1.filepath import MediaFilePaths
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
@ -578,6 +578,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
|
|||
_search_test(None, "foo", "user_id")
|
||||
_search_test(None, "bar", "user_id")
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"experimental_features": {
|
||||
"msc3866": {
|
||||
"enabled": True,
|
||||
"require_approval_for_new_accounts": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_invalid_parameter(self) -> None:
|
||||
"""
|
||||
If parameters are invalid, an error is returned.
|
||||
|
@ -623,6 +633,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(400, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
# invalid approved
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url + "?approved=not_bool",
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(400, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
# unkown order_by
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
|
@ -841,6 +861,69 @@ class UsersListTestCase(unittest.HomeserverTestCase):
|
|||
self._order_test([self.admin_user, user1, user2], "creation_ts", "f")
|
||||
self._order_test([user2, user1, self.admin_user], "creation_ts", "b")
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"experimental_features": {
|
||||
"msc3866": {
|
||||
"enabled": True,
|
||||
"require_approval_for_new_accounts": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_filter_out_approved(self) -> None:
|
||||
"""Tests that the endpoint can filter out approved users."""
|
||||
# Create our users.
|
||||
self._create_users(2)
|
||||
|
||||
# Get the list of users.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
# Exclude the admin, because we don't want to accidentally un-approve the admin.
|
||||
non_admin_user_ids = [
|
||||
user["name"]
|
||||
for user in channel.json_body["users"]
|
||||
if user["name"] != self.admin_user
|
||||
]
|
||||
|
||||
self.assertEqual(2, len(non_admin_user_ids), non_admin_user_ids)
|
||||
|
||||
# Select a user and un-approve them. We do this rather than the other way around
|
||||
# because, since these users are created by an admin, we consider them already
|
||||
# approved.
|
||||
not_approved_user = non_admin_user_ids[0]
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"/_synapse/admin/v2/users/{not_approved_user}",
|
||||
{"approved": False},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
# Now get the list of users again, this time filtering out approved users.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
self.url + "?approved=false",
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
non_admin_user_ids = [
|
||||
user["name"]
|
||||
for user in channel.json_body["users"]
|
||||
if user["name"] != self.admin_user
|
||||
]
|
||||
|
||||
# We should only have our unapproved user now.
|
||||
self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids)
|
||||
self.assertEqual(not_approved_user, non_admin_user_ids[0])
|
||||
|
||||
def _order_test(
|
||||
self,
|
||||
expected_user_list: List[str],
|
||||
|
@ -1272,6 +1355,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
register.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
|
@ -2536,6 +2620,104 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||
# Ensure they're still alive
|
||||
self.assertEqual(0, channel.json_body["deactivated"])
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"experimental_features": {
|
||||
"msc3866": {
|
||||
"enabled": True,
|
||||
"require_approval_for_new_accounts": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_approve_account(self) -> None:
|
||||
"""Tests that approving an account correctly sets the approved flag for the user."""
|
||||
url = self.url_prefix % "@bob:test"
|
||||
|
||||
# Create the user using the client-server API since otherwise the user will be
|
||||
# marked as approved automatically.
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"register",
|
||||
{
|
||||
"username": "bob",
|
||||
"password": "test",
|
||||
"auth": {"type": LoginType.DUMMY},
|
||||
},
|
||||
)
|
||||
self.assertEqual(403, channel.code, channel.result)
|
||||
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
|
||||
self.assertEqual(
|
||||
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
|
||||
)
|
||||
|
||||
# Get user
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
url,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||
self.assertIs(False, channel.json_body["approved"])
|
||||
|
||||
# Approve user
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
url,
|
||||
access_token=self.admin_user_tok,
|
||||
content={"approved": True},
|
||||
)
|
||||
|
||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||
self.assertIs(True, channel.json_body["approved"])
|
||||
|
||||
# Check that the user is now approved
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
url,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||
self.assertIs(True, channel.json_body["approved"])
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"experimental_features": {
|
||||
"msc3866": {
|
||||
"enabled": True,
|
||||
"require_approval_for_new_accounts": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_register_approved(self) -> None:
|
||||
url = self.url_prefix % "@bob:test"
|
||||
|
||||
# Create user
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
url,
|
||||
access_token=self.admin_user_tok,
|
||||
content={"password": "abc123", "approved": True},
|
||||
)
|
||||
|
||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual(1, channel.json_body["approved"])
|
||||
|
||||
# Get user
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
url,
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual(1, channel.json_body["approved"])
|
||||
|
||||
def _is_erased(self, user_id: str, expect: bool) -> None:
|
||||
"""Assert that the user is erased or not"""
|
||||
d = self.store.is_user_erased(user_id)
|
||||
|
|
|
@ -20,7 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||
from twisted.web.resource import Resource
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.constants import ApprovalNoticeMedium, LoginType
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
||||
from synapse.rest.client import account, auth, devices, login, logout, register
|
||||
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
||||
|
@ -567,6 +568,36 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
body={"auth": {"session": session_id}},
|
||||
)
|
||||
|
||||
@skip_unless(HAS_OIDC, "requires OIDC")
|
||||
@override_config(
|
||||
{
|
||||
"oidc_config": TEST_OIDC_CONFIG,
|
||||
"experimental_features": {
|
||||
"msc3866": {
|
||||
"enabled": True,
|
||||
"require_approval_for_new_accounts": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_sso_not_approved(self) -> None:
|
||||
"""Tests that if we register a user via SSO while requiring approval for new
|
||||
accounts, we still raise the correct error before logging the user in.
|
||||
"""
|
||||
login_resp = self.helper.login_via_oidc("username", expected_status=403)
|
||||
|
||||
self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
|
||||
self.assertEqual(
|
||||
ApprovalNoticeMedium.NONE, login_resp["approval_notice_medium"]
|
||||
)
|
||||
|
||||
# Check that we didn't register a device for the user during the login attempt.
|
||||
devices = self.get_success(
|
||||
self.hs.get_datastores().main.get_devices_by_user("@username:test")
|
||||
)
|
||||
|
||||
self.assertEqual(len(devices), 0)
|
||||
|
||||
|
||||
class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
|
|
|
@ -23,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||
from twisted.web.resource import Resource
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import ApprovalNoticeMedium, LoginType
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.rest.client import devices, login, logout, register
|
||||
from synapse.rest.client.account import WhoamiRestServlet
|
||||
|
@ -94,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|||
logout.register_servlets,
|
||||
devices.register_servlets,
|
||||
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
|
||||
register.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
|
@ -406,6 +409,44 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(channel.code, 400)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"experimental_features": {
|
||||
"msc3866": {
|
||||
"enabled": True,
|
||||
"require_approval_for_new_accounts": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_require_approval(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"register",
|
||||
{
|
||||
"username": "kermit",
|
||||
"password": "monkey",
|
||||
"auth": {"type": LoginType.DUMMY},
|
||||
},
|
||||
)
|
||||
self.assertEqual(403, channel.code, channel.result)
|
||||
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
|
||||
self.assertEqual(
|
||||
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
|
||||
)
|
||||
|
||||
params = {
|
||||
"type": LoginType.PASSWORD,
|
||||
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||
"password": "monkey",
|
||||
}
|
||||
channel = self.make_request("POST", LOGIN_URL, params)
|
||||
self.assertEqual(403, channel.code, channel.result)
|
||||
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
|
||||
self.assertEqual(
|
||||
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
|
||||
)
|
||||
|
||||
|
||||
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
|
||||
class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
|
|
|
@ -22,6 +22,8 @@ from synapse.util import Clock
|
|||
from tests import unittest
|
||||
from tests.unittest import override_config
|
||||
|
||||
endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
|
||||
|
||||
|
||||
class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
|
@ -45,18 +47,18 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
|
|||
self.password = "password"
|
||||
|
||||
def test_disabled(self) -> None:
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=None)
|
||||
channel = self.make_request("POST", endpoint, {}, access_token=None)
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
self.register_user(self.user, self.password)
|
||||
token = self.login(self.user, self.password)
|
||||
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=token)
|
||||
channel = self.make_request("POST", endpoint, {}, access_token=token)
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
@override_config({"experimental_features": {"msc3882_enabled": True}})
|
||||
def test_require_auth(self) -> None:
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=None)
|
||||
channel = self.make_request("POST", endpoint, {}, access_token=None)
|
||||
self.assertEqual(channel.code, 401)
|
||||
|
||||
@override_config({"experimental_features": {"msc3882_enabled": True}})
|
||||
|
@ -64,7 +66,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
|
|||
user_id = self.register_user(self.user, self.password)
|
||||
token = self.login(self.user, self.password)
|
||||
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=token)
|
||||
channel = self.make_request("POST", endpoint, {}, access_token=token)
|
||||
self.assertEqual(channel.code, 401)
|
||||
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
|
||||
|
||||
|
@ -79,7 +81,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
}
|
||||
|
||||
channel = self.make_request("POST", "/login/token", uia, access_token=token)
|
||||
channel = self.make_request("POST", endpoint, uia, access_token=token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["expires_in"], 300)
|
||||
|
||||
|
@ -100,7 +102,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
|
|||
user_id = self.register_user(self.user, self.password)
|
||||
token = self.login(self.user, self.password)
|
||||
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=token)
|
||||
channel = self.make_request("POST", endpoint, {}, access_token=token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["expires_in"], 300)
|
||||
|
||||
|
@ -127,6 +129,6 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
|
|||
self.register_user(self.user, self.password)
|
||||
token = self.login(self.user, self.password)
|
||||
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=token)
|
||||
channel = self.make_request("POST", endpoint, {}, access_token=token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["expires_in"], 15)
|
||||
|
|
|
@ -22,7 +22,11 @@ import pkg_resources
|
|||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
|
||||
from synapse.api.constants import (
|
||||
APP_SERVICE_REGISTRATION_TYPE,
|
||||
ApprovalNoticeMedium,
|
||||
LoginType,
|
||||
)
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.rest.client import account, account_validity, login, logout, register, sync
|
||||
|
@ -765,6 +769,32 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(channel.code, 400, channel.json_body)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"experimental_features": {
|
||||
"msc3866": {
|
||||
"enabled": True,
|
||||
"require_approval_for_new_accounts": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_require_approval(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"register",
|
||||
{
|
||||
"username": "kermit",
|
||||
"password": "monkey",
|
||||
"auth": {"type": LoginType.DUMMY},
|
||||
},
|
||||
)
|
||||
self.assertEqual(403, channel.code, channel.result)
|
||||
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
|
||||
self.assertEqual(
|
||||
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
|
||||
)
|
||||
|
||||
|
||||
class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
|
|
|
@ -543,8 +543,12 @@ class RestHelper:
|
|||
|
||||
return channel.json_body
|
||||
|
||||
def login_via_oidc(self, remote_user_id: str) -> JsonDict:
|
||||
"""Log in (as a new user) via OIDC
|
||||
def login_via_oidc(
|
||||
self,
|
||||
remote_user_id: str,
|
||||
expected_status: int = 200,
|
||||
) -> JsonDict:
|
||||
"""Log in via OIDC
|
||||
|
||||
Returns the result of the final token login.
|
||||
|
||||
|
@ -578,7 +582,9 @@ class RestHelper:
|
|||
"/login",
|
||||
content={"type": "m.login.token", "token": login_token},
|
||||
)
|
||||
assert channel.code == HTTPStatus.OK
|
||||
assert (
|
||||
channel.code == expected_status
|
||||
), f"unexpected status in response: {channel.code}"
|
||||
return channel.json_body
|
||||
|
||||
def auth_via_oidc(
|
||||
|
|
|
@ -754,18 +754,28 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
|
||||
def test_get_backfill_points_in_room(self):
|
||||
"""
|
||||
Test to make sure we get some backfill points
|
||||
Test to make sure only backfill points that are older and come before
|
||||
the `current_depth` are returned.
|
||||
"""
|
||||
setup_info = self._setup_room_for_backfill_tests()
|
||||
room_id = setup_info.room_id
|
||||
depth_map = setup_info.depth_map
|
||||
|
||||
# Try at "B"
|
||||
backfill_points = self.get_success(
|
||||
self.store.get_backfill_points_in_room(room_id)
|
||||
self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100)
|
||||
)
|
||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||
self.assertListEqual(
|
||||
backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]
|
||||
self.assertEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"])
|
||||
|
||||
# Try at "A"
|
||||
backfill_points = self.get_success(
|
||||
self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
|
||||
)
|
||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||
# Event "2" has a depth of 2 but is not included here because we only
|
||||
# know the approximate depth of 5 from our event "3".
|
||||
self.assertListEqual(backfill_event_ids, ["b3", "b2", "b1"])
|
||||
|
||||
def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
|
||||
self,
|
||||
|
@ -776,6 +786,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
"""
|
||||
setup_info = self._setup_room_for_backfill_tests()
|
||||
room_id = setup_info.room_id
|
||||
depth_map = setup_info.depth_map
|
||||
|
||||
# Record some attempts to backfill these events which will make
|
||||
# `get_backfill_points_in_room` exclude them because we
|
||||
|
@ -795,12 +806,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
|
||||
# No time has passed since we attempted to backfill ^
|
||||
|
||||
# Try at "B"
|
||||
backfill_points = self.get_success(
|
||||
self.store.get_backfill_points_in_room(room_id)
|
||||
self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100)
|
||||
)
|
||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||
# Only the backfill points that we didn't record earlier exist here.
|
||||
self.assertListEqual(backfill_event_ids, ["b6", "2", "b1"])
|
||||
self.assertEqual(backfill_event_ids, ["b6", "2", "b1"])
|
||||
|
||||
def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
|
||||
self,
|
||||
|
@ -812,6 +824,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
"""
|
||||
setup_info = self._setup_room_for_backfill_tests()
|
||||
room_id = setup_info.room_id
|
||||
depth_map = setup_info.depth_map
|
||||
|
||||
# Record some attempts to backfill these events which will make
|
||||
# `get_backfill_points_in_room` exclude them because we
|
||||
|
@ -839,27 +852,66 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
# visible regardless.
|
||||
self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
|
||||
|
||||
# Make sure that "b1" is not in the list because we've
|
||||
# Try at "A" and make sure that "b1" is not in the list because we've
|
||||
# already attempted many times
|
||||
backfill_points = self.get_success(
|
||||
self.store.get_backfill_points_in_room(room_id)
|
||||
self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
|
||||
)
|
||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||
self.assertListEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2"])
|
||||
self.assertEqual(backfill_event_ids, ["b3", "b2"])
|
||||
|
||||
# Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
|
||||
# see if we can now backfill it
|
||||
self.reactor.advance(datetime.timedelta(hours=20).total_seconds())
|
||||
|
||||
# Try again after we advanced enough time and we should see "b3" again
|
||||
# Try at "A" again after we advanced enough time and we should see "b3" again
|
||||
backfill_points = self.get_success(
|
||||
self.store.get_backfill_points_in_room(room_id)
|
||||
self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
|
||||
)
|
||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||
self.assertListEqual(
|
||||
backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]
|
||||
self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"])
|
||||
|
||||
def test_get_backfill_points_in_room_works_after_many_failed_pull_attempts_that_could_naively_overflow(
|
||||
self,
|
||||
) -> None:
|
||||
"""
|
||||
A test that reproduces #13929 (Postgres only).
|
||||
|
||||
Test to make sure we can still get backfill points after many failed pull
|
||||
attempts that cause us to backoff to the limit. Even if the backoff formula
|
||||
would tell us to wait for more seconds than can be expressed in a 32 bit
|
||||
signed int.
|
||||
"""
|
||||
setup_info = self._setup_room_for_backfill_tests()
|
||||
room_id = setup_info.room_id
|
||||
depth_map = setup_info.depth_map
|
||||
|
||||
# Pretend that we have tried and failed 10 times to backfill event b1.
|
||||
for _ in range(10):
|
||||
self.get_success(
|
||||
self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
|
||||
)
|
||||
|
||||
# If the backoff periods grow without limit:
|
||||
# After the first failed attempt, we would have backed off for 1 << 1 = 2 hours.
|
||||
# After the second failed attempt we would have backed off for 1 << 2 = 4 hours,
|
||||
# so after the 10th failed attempt we should backoff for 1 << 10 == 1024 hours.
|
||||
# Wait 1100 hours just so we have a nice round number.
|
||||
self.reactor.advance(datetime.timedelta(hours=1100).total_seconds())
|
||||
|
||||
# 1024 hours in milliseconds is 1024 * 3600000, which exceeds the largest 32 bit
|
||||
# signed integer. The bug we're reproducing is that this overflow causes an
|
||||
# error in postgres preventing us from fetching a set of backwards extremities
|
||||
# to retry fetching.
|
||||
backfill_points = self.get_success(
|
||||
self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
|
||||
)
|
||||
|
||||
# We should aim to fetch all backoff points: b1's latest backoff period has
|
||||
# expired, and we haven't tried the rest.
|
||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||
self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"])
|
||||
|
||||
def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo:
|
||||
"""
|
||||
Sets up a room with various insertion event backward extremities to test
|
||||
|
@ -938,18 +990,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
|
||||
def test_get_insertion_event_backward_extremities_in_room(self):
|
||||
"""
|
||||
Test to make sure insertion event backward extremities are returned.
|
||||
Test to make sure only insertion event backward extremities that are
|
||||
older and come before the `current_depth` are returned.
|
||||
"""
|
||||
setup_info = self._setup_room_for_insertion_backfill_tests()
|
||||
room_id = setup_info.room_id
|
||||
depth_map = setup_info.depth_map
|
||||
|
||||
# Try at "insertion_eventB"
|
||||
backfill_points = self.get_success(
|
||||
self.store.get_insertion_event_backward_extremities_in_room(room_id)
|
||||
self.store.get_insertion_event_backward_extremities_in_room(
|
||||
room_id, depth_map["insertion_eventB"], limit=100
|
||||
)
|
||||
)
|
||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||
self.assertListEqual(
|
||||
backfill_event_ids, ["insertion_eventB", "insertion_eventA"]
|
||||
self.assertEqual(backfill_event_ids, ["insertion_eventB", "insertion_eventA"])
|
||||
|
||||
# Try at "insertion_eventA"
|
||||
backfill_points = self.get_success(
|
||||
self.store.get_insertion_event_backward_extremities_in_room(
|
||||
room_id, depth_map["insertion_eventA"], limit=100
|
||||
)
|
||||
)
|
||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||
# Event "2" has a depth of 2 but is not included here because we only
|
||||
# know the approximate depth of 5 from our event "3".
|
||||
self.assertListEqual(backfill_event_ids, ["insertion_eventA"])
|
||||
|
||||
def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
|
||||
self,
|
||||
|
@ -961,6 +1027,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
"""
|
||||
setup_info = self._setup_room_for_insertion_backfill_tests()
|
||||
room_id = setup_info.room_id
|
||||
depth_map = setup_info.depth_map
|
||||
|
||||
# Record some attempts to backfill these events which will make
|
||||
# `get_insertion_event_backward_extremities_in_room` exclude them
|
||||
|
@ -973,12 +1040,15 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
|
||||
# No time has passed since we attempted to backfill ^
|
||||
|
||||
# Try at "insertion_eventB"
|
||||
backfill_points = self.get_success(
|
||||
self.store.get_insertion_event_backward_extremities_in_room(room_id)
|
||||
self.store.get_insertion_event_backward_extremities_in_room(
|
||||
room_id, depth_map["insertion_eventB"], limit=100
|
||||
)
|
||||
)
|
||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||
# Only the backfill points that we didn't record earlier exist here.
|
||||
self.assertListEqual(backfill_event_ids, ["insertion_eventB"])
|
||||
self.assertEqual(backfill_event_ids, ["insertion_eventB"])
|
||||
|
||||
def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
|
||||
self,
|
||||
|
@ -991,6 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
"""
|
||||
setup_info = self._setup_room_for_insertion_backfill_tests()
|
||||
room_id = setup_info.room_id
|
||||
depth_map = setup_info.depth_map
|
||||
|
||||
# Record some attempts to backfill these events which will make
|
||||
# `get_backfill_points_in_room` exclude them because we
|
||||
|
@ -1027,13 +1098,15 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
# because we haven't waited long enough for this many attempts.
|
||||
self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
|
||||
|
||||
# Make sure that "insertion_eventA" is not in the list because we've
|
||||
# already attempted many times
|
||||
# Try at "insertion_eventA" and make sure that "insertion_eventA" is not
|
||||
# in the list because we've already attempted many times
|
||||
backfill_points = self.get_success(
|
||||
self.store.get_insertion_event_backward_extremities_in_room(room_id)
|
||||
self.store.get_insertion_event_backward_extremities_in_room(
|
||||
room_id, depth_map["insertion_eventA"], limit=100
|
||||
)
|
||||
)
|
||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||
self.assertListEqual(backfill_event_ids, ["insertion_eventB"])
|
||||
self.assertEqual(backfill_event_ids, [])
|
||||
|
||||
# Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
|
||||
# see if we can now backfill it
|
||||
|
@ -1042,12 +1115,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
# Try at "insertion_eventA" again after we advanced enough time and we
|
||||
# should see "insertion_eventA" again
|
||||
backfill_points = self.get_success(
|
||||
self.store.get_insertion_event_backward_extremities_in_room(room_id)
|
||||
self.store.get_insertion_event_backward_extremities_in_room(
|
||||
room_id, depth_map["insertion_eventA"], limit=100
|
||||
)
|
||||
)
|
||||
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
|
||||
self.assertListEqual(
|
||||
backfill_event_ids, ["insertion_eventB", "insertion_eventA"]
|
||||
)
|
||||
self.assertEqual(backfill_event_ids, ["insertion_eventA"])
|
||||
|
||||
|
||||
@attr.s
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Collection, Optional
|
||||
|
||||
from synapse.api.constants import ReceiptTypes
|
||||
from synapse.types import UserID, create_requester
|
||||
|
@ -84,6 +85,33 @@ class ReceiptTestCase(HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
def get_last_unthreaded_receipt(
|
||||
self, receipt_types: Collection[str], room_id: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Fetch the event ID for the latest unthreaded receipt in the test room for the test user.
|
||||
|
||||
Args:
|
||||
receipt_types: The receipt types to fetch.
|
||||
|
||||
Returns:
|
||||
The latest receipt, if one exists.
|
||||
"""
|
||||
result = self.get_success(
|
||||
self.store.db_pool.runInteraction(
|
||||
"get_last_receipt_event_id_for_user",
|
||||
self.store.get_last_unthreaded_receipt_for_user_txn,
|
||||
OUR_USER_ID,
|
||||
room_id or self.room_id1,
|
||||
receipt_types,
|
||||
)
|
||||
)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
event_id, _ = result
|
||||
return event_id
|
||||
|
||||
def test_return_empty_with_no_data(self) -> None:
|
||||
res = self.get_success(
|
||||
self.store.get_receipts_for_user(
|
||||
|
@ -107,16 +135,10 @@ class ReceiptTestCase(HomeserverTestCase):
|
|||
)
|
||||
self.assertEqual(res, {})
|
||||
|
||||
res = self.get_success(
|
||||
self.store.get_last_receipt_event_id_for_user(
|
||||
OUR_USER_ID,
|
||||
self.room_id1,
|
||||
[
|
||||
ReceiptTypes.READ,
|
||||
ReceiptTypes.READ_PRIVATE,
|
||||
],
|
||||
)
|
||||
res = self.get_last_unthreaded_receipt(
|
||||
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
|
||||
)
|
||||
|
||||
self.assertEqual(res, None)
|
||||
|
||||
def test_get_receipts_for_user(self) -> None:
|
||||
|
@ -228,29 +250,17 @@ class ReceiptTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
# Test we get the latest event when we want both private and public receipts
|
||||
res = self.get_success(
|
||||
self.store.get_last_receipt_event_id_for_user(
|
||||
OUR_USER_ID,
|
||||
self.room_id1,
|
||||
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
|
||||
)
|
||||
res = self.get_last_unthreaded_receipt(
|
||||
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
|
||||
)
|
||||
self.assertEqual(res, event1_2_id)
|
||||
|
||||
# Test we get the older event when we want only public receipt
|
||||
res = self.get_success(
|
||||
self.store.get_last_receipt_event_id_for_user(
|
||||
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
|
||||
)
|
||||
)
|
||||
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
|
||||
self.assertEqual(res, event1_1_id)
|
||||
|
||||
# Test we get the latest event when we want only the private receipt
|
||||
res = self.get_success(
|
||||
self.store.get_last_receipt_event_id_for_user(
|
||||
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
|
||||
)
|
||||
)
|
||||
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE])
|
||||
self.assertEqual(res, event1_2_id)
|
||||
|
||||
# Test receipt updating
|
||||
|
@ -259,11 +269,7 @@ class ReceiptTestCase(HomeserverTestCase):
|
|||
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
|
||||
)
|
||||
)
|
||||
res = self.get_success(
|
||||
self.store.get_last_receipt_event_id_for_user(
|
||||
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
|
||||
)
|
||||
)
|
||||
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
|
||||
self.assertEqual(res, event1_2_id)
|
||||
|
||||
# Send some events into the second room
|
||||
|
@ -282,11 +288,7 @@ class ReceiptTestCase(HomeserverTestCase):
|
|||
{},
|
||||
)
|
||||
)
|
||||
res = self.get_success(
|
||||
self.store.get_last_receipt_event_id_for_user(
|
||||
OUR_USER_ID,
|
||||
self.room_id2,
|
||||
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
|
||||
)
|
||||
res = self.get_last_unthreaded_receipt(
|
||||
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2
|
||||
)
|
||||
self.assertEqual(res, event2_1_id)
|
||||
|
|
|
@ -16,9 +16,10 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import ThreepidValidationError
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
|
||||
class RegistrationStoreTestCase(HomeserverTestCase):
|
||||
|
@ -48,6 +49,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
|
|||
"user_type": None,
|
||||
"deactivated": 0,
|
||||
"shadow_banned": 0,
|
||||
"approved": 1,
|
||||
},
|
||||
(self.get_success(self.store.get_user_by_id(self.user_id))),
|
||||
)
|
||||
|
@ -166,3 +168,101 @@ class RegistrationStoreTestCase(HomeserverTestCase):
|
|||
ThreepidValidationError,
|
||||
)
|
||||
self.assertEqual(e.value.msg, "Validation token not found or has expired", e)
|
||||
|
||||
|
||||
class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
|
||||
# If there's already some config for this feature in the default config, it
|
||||
# means we're overriding it with @override_config. In this case we don't want
|
||||
# to do anything more with it.
|
||||
msc3866_config = config.get("experimental_features", {}).get("msc3866")
|
||||
if msc3866_config is not None:
|
||||
return config
|
||||
|
||||
# Require approval for all new accounts.
|
||||
config["experimental_features"] = {
|
||||
"msc3866": {
|
||||
"enabled": True,
|
||||
"require_approval_for_new_accounts": True,
|
||||
}
|
||||
}
|
||||
return config
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.user_id = "@my-user:test"
|
||||
self.pwhash = "{xx1}123456789"
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"experimental_features": {
|
||||
"msc3866": {
|
||||
"enabled": True,
|
||||
"require_approval_for_new_accounts": False,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_approval_not_required(self) -> None:
|
||||
"""Tests that if we don't require approval for new accounts, newly created
|
||||
accounts are automatically marked as approved.
|
||||
"""
|
||||
self.get_success(self.store.register_user(self.user_id, self.pwhash))
|
||||
|
||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||
assert user is not None
|
||||
self.assertTrue(user["approved"])
|
||||
|
||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||
self.assertTrue(approved)
|
||||
|
||||
def test_approval_required(self) -> None:
|
||||
"""Tests that if we require approval for new accounts, newly created accounts
|
||||
are not automatically marked as approved.
|
||||
"""
|
||||
self.get_success(self.store.register_user(self.user_id, self.pwhash))
|
||||
|
||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||
assert user is not None
|
||||
self.assertFalse(user["approved"])
|
||||
|
||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||
self.assertFalse(approved)
|
||||
|
||||
def test_override(self) -> None:
|
||||
"""Tests that if we require approval for new accounts, but we explicitly say the
|
||||
new user should be considered approved, they're marked as approved.
|
||||
"""
|
||||
self.get_success(
|
||||
self.store.register_user(
|
||||
self.user_id,
|
||||
self.pwhash,
|
||||
approved=True,
|
||||
)
|
||||
)
|
||||
|
||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||
self.assertIsNotNone(user)
|
||||
assert user is not None
|
||||
self.assertEqual(user["approved"], 1)
|
||||
|
||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||
self.assertTrue(approved)
|
||||
|
||||
def test_approve_user(self) -> None:
|
||||
"""Tests that approving the user updates their approval status."""
|
||||
self.get_success(self.store.register_user(self.user_id, self.pwhash))
|
||||
|
||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||
self.assertFalse(approved)
|
||||
|
||||
self.get_success(
|
||||
self.store.update_user_approval_status(
|
||||
UserID.from_string(self.user_id), True
|
||||
)
|
||||
)
|
||||
|
||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||
self.assertTrue(approved)
|
||||
|
|
|
@ -40,7 +40,10 @@ class TestDependencyChecker(TestCase):
|
|||
def mock_installed_package(
|
||||
self, distribution: Optional[DummyDistribution]
|
||||
) -> Generator[None, None, None]:
|
||||
"""Pretend that looking up any distribution yields the given `distribution`."""
|
||||
"""Pretend that looking up any package yields the given `distribution`.
|
||||
|
||||
If `distribution = None`, we pretend that the package is not installed.
|
||||
"""
|
||||
|
||||
def mock_distribution(name: str):
|
||||
if distribution is None:
|
||||
|
@ -81,7 +84,7 @@ class TestDependencyChecker(TestCase):
|
|||
self.assertRaises(DependencyException, check_requirements)
|
||||
|
||||
def test_checks_ignore_dev_dependencies(self) -> None:
|
||||
"""Bot generic and per-extra checks should ignore dev dependencies."""
|
||||
"""Both generic and per-extra checks should ignore dev dependencies."""
|
||||
with patch(
|
||||
"synapse.util.check_dependencies.metadata.requires",
|
||||
return_value=["dummypkg >= 1; extra == 'mypy'"],
|
||||
|
@ -142,3 +145,16 @@ class TestDependencyChecker(TestCase):
|
|||
with self.mock_installed_package(new_release_candidate):
|
||||
# should not raise
|
||||
check_requirements()
|
||||
|
||||
def test_setuptools_rust_ignored(self) -> None:
|
||||
"""Test a workaround for a `poetry build` problem. Reproduces #13926."""
|
||||
with patch(
|
||||
"synapse.util.check_dependencies.metadata.requires",
|
||||
return_value=["setuptools_rust >= 1.3"],
|
||||
):
|
||||
with self.mock_installed_package(None):
|
||||
# should not raise, even if setuptools_rust is not installed
|
||||
check_requirements()
|
||||
with self.mock_installed_package(old):
|
||||
# We also ignore old versions of setuptools_rust
|
||||
check_requirements()
|
||||
|
|
Loading…
Reference in New Issue