Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

dmr/revert-fts-changes-on-hotfix
Erik Johnston 2022-09-30 14:27:14 +01:00
commit a2b6ee7b00
92 changed files with 2781 additions and 861 deletions

View File

@ -0,0 +1 @@
Allow server admins to require a manual approval process before new accounts can be used (using [MSC3866](https://github.com/matrix-org/matrix-spec-proposals/pull/3866)).

1
changelog.d/13719.bugfix Normal file
View File

@ -0,0 +1 @@
Send invite push notifications for invite over federation.

1
changelog.d/13787.misc Normal file
View File

@ -0,0 +1 @@
Optimise get rooms for user calls. Contributed by Nick @ Beeper (@fizzadar).

1
changelog.d/13838.misc Normal file
View File

@ -0,0 +1 @@
Port push rules to using Rust.

1
changelog.d/13868.misc Normal file
View File

@ -0,0 +1 @@
Fix unstable MSC3882 endpoint being incorrectly available on stable API versions.

1
changelog.d/13879.misc Normal file
View File

@ -0,0 +1 @@
Only pull relevant backfill points from the database based on the current depth and limit (instead of all) every time we want to `/backfill`.

1
changelog.d/13890.misc Normal file
View File

@ -0,0 +1 @@
Improve backfill robustness by trying more servers when we get a `4xx` error back.

1
changelog.d/13904.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in 1.66 where some required fields in the pushrules sent to clients were not present anymore. Contributed by Nico.

1
changelog.d/13913.misc Normal file
View File

@ -0,0 +1 @@
Faster remote room joins: correctly handle remote device list updates during a partial join.

1
changelog.d/13924.misc Normal file
View File

@ -0,0 +1 @@
Update an innaccurate comment in Synapse's upsert database helper.

1
changelog.d/13928.doc Normal file
View File

@ -0,0 +1 @@
Add instruction to contributing guide for running unit tests in parallel. Contributed by @ashfame.

1
changelog.d/13930.doc Normal file
View File

@ -0,0 +1 @@
Update the man page for the `hash_password` script to correct the default number of bcrypt rounds performed.

1
changelog.d/13931.doc Normal file
View File

@ -0,0 +1 @@
Clarify that the `auto_join_rooms` config option can also be used with Space aliases.

View File

@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).

1
changelog.d/13934.misc Normal file
View File

@ -0,0 +1 @@
Correctly handle sending local device list updates to remote servers during a partial join.

View File

@ -0,0 +1 @@
Exponentially backoff from backfilling the same event over and over.

View File

@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).

View File

@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).

View File

@ -0,0 +1 @@
Add cache invalidation across workers to module API.

1
changelog.d/13952.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in v1.68.0 where Synapse would require `setuptools_rust` at runtime, even though the package is only required at build time.

View File

@ -0,0 +1 @@
Ask mail servers receiving emails from Synapse to not send automatic reply (e.g. out-of-office responses).

1
changelog.d/13972.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a performance regression in the `get_users_in_room` database query. Introduced in v1.67.0.

1
changelog.d/13973.misc Normal file
View File

@ -0,0 +1 @@
Speed up calculating push actions in large rooms.

1
changelog.d/13974.doc Normal file
View File

@ -0,0 +1 @@
Add some cross references to worker documentation.

View File

@ -10,7 +10,7 @@
.P
\fBhash_password\fR takes a password as an parameter either on the command line or the \fBSTDIN\fR if not supplied\.
.P
It accepts an YAML file which can be used to specify parameters like the number of rounds for bcrypt and password_config section having the pepper value used for the hashing\. By default \fBbcrypt_rounds\fR is set to \fB10\fR\.
It accepts an YAML file which can be used to specify parameters like the number of rounds for bcrypt and password_config section having the pepper value used for the hashing\. By default \fBbcrypt_rounds\fR is set to \fB12\fR\.
.P
The hashed password is written on the \fBSTDOUT\fR\.
.SH "FILES"

View File

@ -167,6 +167,12 @@ was broken. They are slower than the linters but will typically catch more error
poetry run trial tests
```
You can run unit tests in parallel by specifying `-jX` argument to `trial` where `X` is the number of parallel runners you want. To use 4 cpu cores, you would run them like:
```sh
poetry run trial -j4 tests
```
If you wish to only run *some* unit tests, you may specify
another module instead of `tests` - or a test class or a method:

View File

@ -0,0 +1,14 @@
worker_app: synapse.app.media_repository
worker_name: media_worker
# The replication listener on the main synapse process.
worker_replication_host: 127.0.0.1
worker_replication_http_port: 9093
worker_listeners:
- type: http
port: 8085
resources:
- names: [media]
worker_log_config: /etc/matrix-synapse/media-worker-log.yaml

View File

@ -88,6 +88,18 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
```
# Upgrading to v1.69.0
## Changes to the receipts replication streams
Synapse now includes information indicating if a receipt applies to a thread when
replicating it to other workers. This is a forwards- and backwards-incompatible
change: v1.68 and workers cannot process receipts replicated by v1.69 workers, and
vice versa.
Once all workers are upgraded to v1.69 (or downgraded to v1.68), receipts
replication will resume as normal.
# Upgrading to v1.68.0
Two changes announced in the upgrade notes for v1.67.0 have now landed in v1.68.0.

View File

@ -2229,6 +2229,9 @@ homeserver. If the room already exists, make certain it is a publicly joinable
room, i.e. the join rule of the room must be set to 'public'. You can find more options
relating to auto-joining rooms below.
As Spaces are just rooms under the hood, Space aliases may also be
used.
Example configuration:
```yaml
auto_join_rooms:
@ -2240,7 +2243,7 @@ auto_join_rooms:
Where `auto_join_rooms` are specified, setting this flag ensures that
the rooms exist by creating them when the first user on the
homeserver registers.
homeserver registers. This option will not create Spaces.
By default the auto-created rooms are publicly joinable from any federated
server. Use the `autocreate_auto_join_rooms_federated` and
@ -2258,7 +2261,7 @@ autocreate_auto_join_rooms: false
---
### `autocreate_auto_join_rooms_federated`
Whether the rooms listen in `auto_join_rooms` that are auto-created are available
Whether the rooms listed in `auto_join_rooms` that are auto-created are available
via federation. Only has an effect if `autocreate_auto_join_rooms` is true.
Note that whether a room is federated cannot be modified after

View File

@ -93,7 +93,6 @@ listener" for the main process; and secondly, you need to enable redis-based
replication. Optionally, a shared secret can be used to authenticate HTTP
traffic between workers. For example:
```yaml
# extend the existing `listeners` section. This defines the ports that the
# main process will listen on.
@ -129,7 +128,8 @@ In the config file for each worker, you must specify:
* The HTTP replication endpoint that it should talk to on the main synapse process
(`worker_replication_host` and `worker_replication_http_port`)
* If handling HTTP requests, a `worker_listeners` option with an `http`
listener, in the same way as the `listeners` option in the shared config.
listener, in the same way as the [`listeners`](usage/configuration/config_documentation.md#listeners)
option in the shared config.
* If handling the `^/_matrix/client/v3/keys/upload` endpoint, the HTTP URI for
the main process (`worker_main_http_uri`).
@ -285,8 +285,9 @@ For multiple workers not handling the SSO endpoints properly, see
[#7530](https://github.com/matrix-org/synapse/issues/7530) and
[#9427](https://github.com/matrix-org/synapse/issues/9427).
Note that a HTTP listener with `client` and `federation` resources must be
configured in the `worker_listeners` option in the worker config.
Note that a [HTTP listener](usage/configuration/config_documentation.md#listeners)
with `client` and `federation` `resources` must be configured in the `worker_listeners`
option in the worker config.
#### Load balancing
@ -326,7 +327,8 @@ effects of bursts of events from that bridge on events sent by normal users.
Additionally, the writing of specific streams (such as events) can be moved off
of the main process to a particular worker.
To enable this, the worker must have a HTTP replication listener configured,
To enable this, the worker must have a
[HTTP `replication` listener](usage/configuration/config_documentation.md#listeners) configured,
have a `worker_name` and be listed in the `instance_map` config. The same worker
can handle multiple streams, but unless otherwise documented, each stream can only
have a single writer.
@ -410,7 +412,7 @@ the stream writer for the `presence` stream:
There is also support for moving background tasks to a separate
worker. Background tasks are run periodically or started via replication. Exactly
which tasks are configured to run depends on your Synapse configuration (e.g. if
stats is enabled).
stats is enabled). This worker doesn't handle any REST endpoints itself.
To enable this, the worker must have a `worker_name` and can be configured to run
background tasks. For example, to move background tasks to a dedicated worker,
@ -457,8 +459,8 @@ worker application type.
#### Notifying Application Services
You can designate one generic worker to send output traffic to Application Services.
Specify its name in the shared configuration as follows:
Doesn't handle any REST endpoints itself, but you should specify its name in the
shared configuration as follows:
```yaml
notify_appservices_from_worker: worker_name
@ -536,16 +538,12 @@ file to stop the main synapse running background jobs related to managing the
media repository. Note that doing so will prevent the main process from being
able to handle the above endpoints.
In the `media_repository` worker configuration file, configure the http listener to
In the `media_repository` worker configuration file, configure the
[HTTP listener](usage/configuration/config_documentation.md#listeners) to
expose the `media` resource. For example:
```yaml
worker_listeners:
- type: http
port: 8085
resources:
- names:
- media
{{#include systemd-with-workers/workers/media_worker.yaml}}
```
Note that if running multiple media repositories they must be on the same server

View File

@ -11,7 +11,9 @@ rust-version = "1.58.1"
[lib]
name = "synapse"
crate-type = ["cdylib"]
# We generate a `cdylib` for Python and a standard `lib` for running
# tests/benchmarks.
crate-type = ["lib", "cdylib"]
[package.metadata.maturin]
# This is where we tell maturin where to place the built library.

149
rust/benches/evaluator.rs Normal file
View File

@ -0,0 +1,149 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![feature(test)]
use synapse::push::{
evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules,
};
use test::Bencher;
extern crate test;
#[bench]
fn bench_match_exact(b: &mut Bencher) {
let flattened_keys = [
("type".to_string(), "m.text".to_string()),
("room_id".to_string(), "!room:server".to_string()),
("content.body".to_string(), "test message".to_string()),
]
.into_iter()
.collect();
let eval = PushRuleEvaluator::py_new(
flattened_keys,
10,
0,
Default::default(),
Default::default(),
true,
)
.unwrap();
let condition = Condition::Known(synapse::push::KnownCondition::EventMatch(
EventMatchCondition {
key: "room_id".into(),
pattern: Some("!room:server".into()),
pattern_type: None,
},
));
let matched = eval.match_condition(&condition, None, None).unwrap();
assert!(matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
}
#[bench]
fn bench_match_word(b: &mut Bencher) {
let flattened_keys = [
("type".to_string(), "m.text".to_string()),
("room_id".to_string(), "!room:server".to_string()),
("content.body".to_string(), "test message".to_string()),
]
.into_iter()
.collect();
let eval = PushRuleEvaluator::py_new(
flattened_keys,
10,
0,
Default::default(),
Default::default(),
true,
)
.unwrap();
let condition = Condition::Known(synapse::push::KnownCondition::EventMatch(
EventMatchCondition {
key: "content.body".into(),
pattern: Some("test".into()),
pattern_type: None,
},
));
let matched = eval.match_condition(&condition, None, None).unwrap();
assert!(matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
}
#[bench]
fn bench_match_word_miss(b: &mut Bencher) {
let flattened_keys = [
("type".to_string(), "m.text".to_string()),
("room_id".to_string(), "!room:server".to_string()),
("content.body".to_string(), "test message".to_string()),
]
.into_iter()
.collect();
let eval = PushRuleEvaluator::py_new(
flattened_keys,
10,
0,
Default::default(),
Default::default(),
true,
)
.unwrap();
let condition = Condition::Known(synapse::push::KnownCondition::EventMatch(
EventMatchCondition {
key: "content.body".into(),
pattern: Some("foobar".into()),
pattern_type: None,
},
));
let matched = eval.match_condition(&condition, None, None).unwrap();
assert!(!matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
}
#[bench]
fn bench_eval_message(b: &mut Bencher) {
let flattened_keys = [
("type".to_string(), "m.text".to_string()),
("room_id".to_string(), "!room:server".to_string()),
("content.body".to_string(), "test message".to_string()),
]
.into_iter()
.collect();
let eval = PushRuleEvaluator::py_new(
flattened_keys,
10,
0,
Default::default(),
Default::default(),
true,
)
.unwrap();
let rules =
FilteredPushRules::py_new(PushRules::new(Vec::new()), Default::default(), false, false);
b.iter(|| eval.run(&rules, Some("bob"), Some("person")));
}

40
rust/benches/glob.rs Normal file
View File

@ -0,0 +1,40 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![feature(test)]
use synapse::push::utils::{glob_to_regex, GlobMatchType};
use test::Bencher;
extern crate test;
#[bench]
fn bench_whole(b: &mut Bencher) {
b.iter(|| glob_to_regex("test", GlobMatchType::Whole));
}
#[bench]
fn bench_word(b: &mut Bencher) {
b.iter(|| glob_to_regex("test", GlobMatchType::Word));
}
#[bench]
fn bench_whole_wildcard_run(b: &mut Bencher) {
b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Whole));
}
#[bench]
fn bench_word_wildcard_run(b: &mut Bencher) {
b.iter(|| glob_to_regex("test***??*?*?foo", GlobMatchType::Whole));
}

View File

@ -22,7 +22,7 @@ fn main() -> Result<(), std::io::Error> {
for entry in entries {
if entry.is_dir() {
dirs.push(entry)
dirs.push(entry);
} else {
paths.push(entry.to_str().expect("valid rust paths").to_string());
}

View File

@ -262,6 +262,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
priority_class: 1,
conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch {
rel_type: Cow::Borrowed("m.thread"),
event_type_pattern: None,
sender: None,
sender_type: Some(Cow::Borrowed("user_id")),
})]),

374
rust/src/push/evaluator.rs Normal file
View File

@ -0,0 +1,374 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
borrow::Cow,
collections::{BTreeMap, BTreeSet},
};
use anyhow::{Context, Error};
use lazy_static::lazy_static;
use log::warn;
use pyo3::prelude::*;
use regex::Regex;
use super::{
utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType},
Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition,
};
lazy_static! {
/// Used to parse the `is` clause in the room member count condition.
static ref INEQUALITY_EXPR: Regex = Regex::new(r"^([=<>]*)([0-9]+)$").expect("valid regex");
}
/// Allows running a set of push rules against a particular event.
#[pyclass]
pub struct PushRuleEvaluator {
/// A mapping of "flattened" keys to string values in the event, e.g.
/// includes things like "type" and "content.msgtype".
flattened_keys: BTreeMap<String, String>,
/// The "content.body", if any.
body: String,
/// The number of users in the room.
room_member_count: u64,
/// The `notifications` section of the current power levels in the room.
notification_power_levels: BTreeMap<String, i64>,
/// The relations related to the event as a mapping from relation type to
/// set of sender/event type 2-tuples.
relations: BTreeMap<String, BTreeSet<(String, String)>>,
/// Is running "relation" conditions enabled?
relation_match_enabled: bool,
/// The power level of the sender of the event, or None if event is an
/// outlier.
sender_power_level: Option<i64>,
}
#[pymethods]
impl PushRuleEvaluator {
/// Create a new `PushRuleEvaluator`. See struct docstring for details.
#[new]
pub fn py_new(
flattened_keys: BTreeMap<String, String>,
room_member_count: u64,
sender_power_level: Option<i64>,
notification_power_levels: BTreeMap<String, i64>,
relations: BTreeMap<String, BTreeSet<(String, String)>>,
relation_match_enabled: bool,
) -> Result<Self, Error> {
let body = flattened_keys
.get("content.body")
.cloned()
.unwrap_or_default();
Ok(PushRuleEvaluator {
flattened_keys,
body,
room_member_count,
notification_power_levels,
relations,
relation_match_enabled,
sender_power_level,
})
}
/// Run the evaluator with the given push rules, for the given user ID and
/// display name of the user.
///
/// Passing in None will skip evaluating rules matching user ID and display
/// name.
///
/// Returns the set of actions, if any, that match (filtering out any
/// `dont_notify` actions).
pub fn run(
&self,
push_rules: &FilteredPushRules,
user_id: Option<&str>,
display_name: Option<&str>,
) -> Vec<Action> {
'outer: for (push_rule, enabled) in push_rules.iter() {
if !enabled {
continue;
}
for condition in push_rule.conditions.iter() {
match self.match_condition(condition, user_id, display_name) {
Ok(true) => {}
Ok(false) => continue 'outer,
Err(err) => {
warn!("Condition match failed {err}");
continue 'outer;
}
}
}
let actions = push_rule
.actions
.iter()
// Filter out "dont_notify" actions, as we don't store them.
.filter(|a| **a != Action::DontNotify)
.cloned()
.collect();
return actions;
}
Vec::new()
}
/// Check if the given condition matches.
fn matches(
&self,
condition: Condition,
user_id: Option<&str>,
display_name: Option<&str>,
) -> bool {
match self.match_condition(&condition, user_id, display_name) {
Ok(true) => true,
Ok(false) => false,
Err(err) => {
warn!("Condition match failed {err}");
false
}
}
}
}
impl PushRuleEvaluator {
/// Match a given `Condition` for a push rule.
pub fn match_condition(
&self,
condition: &Condition,
user_id: Option<&str>,
display_name: Option<&str>,
) -> Result<bool, Error> {
let known_condition = match condition {
Condition::Known(known) => known,
Condition::Unknown(_) => {
return Ok(false);
}
};
let result = match known_condition {
KnownCondition::EventMatch(event_match) => {
self.match_event_match(event_match, user_id)?
}
KnownCondition::ContainsDisplayName => {
if let Some(dn) = display_name {
if !dn.is_empty() {
get_glob_matcher(dn, GlobMatchType::Word)?.is_match(&self.body)?
} else {
// We specifically ignore empty display names, as otherwise
// they would always match.
false
}
} else {
false
}
}
KnownCondition::RoomMemberCount { is } => {
if let Some(is) = is {
self.match_member_count(is)?
} else {
false
}
}
KnownCondition::SenderNotificationPermission { key } => {
if let Some(sender_power_level) = &self.sender_power_level {
let required_level = self
.notification_power_levels
.get(key.as_ref())
.copied()
.unwrap_or(50);
*sender_power_level >= required_level
} else {
false
}
}
KnownCondition::RelationMatch {
rel_type,
event_type_pattern,
sender,
sender_type,
} => {
self.match_relations(rel_type, sender, sender_type, user_id, event_type_pattern)?
}
};
Ok(result)
}
/// Evaluates a relation condition.
fn match_relations(
&self,
rel_type: &str,
sender: &Option<Cow<str>>,
sender_type: &Option<Cow<str>>,
user_id: Option<&str>,
event_type_pattern: &Option<Cow<str>>,
) -> Result<bool, Error> {
// First check if relation matching is enabled...
if !self.relation_match_enabled {
return Ok(false);
}
// ... and if there are any relations to match against.
let relations = if let Some(relations) = self.relations.get(rel_type) {
relations
} else {
return Ok(false);
};
// Extract the sender pattern from the condition
let sender_pattern = if let Some(sender) = sender {
Some(sender.as_ref())
} else if let Some(sender_type) = sender_type {
if sender_type == "user_id" {
if let Some(user_id) = user_id {
Some(user_id)
} else {
return Ok(false);
}
} else {
warn!("Unrecognized sender_type: {sender_type}");
return Ok(false);
}
} else {
None
};
let mut sender_compiled_pattern = if let Some(pattern) = sender_pattern {
Some(get_glob_matcher(pattern, GlobMatchType::Whole)?)
} else {
None
};
let mut type_compiled_pattern = if let Some(pattern) = event_type_pattern {
Some(get_glob_matcher(pattern, GlobMatchType::Whole)?)
} else {
None
};
for (relation_sender, event_type) in relations {
if let Some(pattern) = &mut sender_compiled_pattern {
if !pattern.is_match(relation_sender)? {
continue;
}
}
if let Some(pattern) = &mut type_compiled_pattern {
if !pattern.is_match(event_type)? {
continue;
}
}
return Ok(true);
}
Ok(false)
}
/// Evaluates a `event_match` condition.
fn match_event_match(
&self,
event_match: &EventMatchCondition,
user_id: Option<&str>,
) -> Result<bool, Error> {
let pattern = if let Some(pattern) = &event_match.pattern {
pattern
} else if let Some(pattern_type) = &event_match.pattern_type {
// The `pattern_type` can either be "user_id" or "user_localpart",
// either way if we don't have a `user_id` then the condition can't
// match.
let user_id = if let Some(user_id) = user_id {
user_id
} else {
return Ok(false);
};
match &**pattern_type {
"user_id" => user_id,
"user_localpart" => get_localpart_from_id(user_id)?,
_ => return Ok(false),
}
} else {
return Ok(false);
};
let haystack = if let Some(haystack) = self.flattened_keys.get(&*event_match.key) {
haystack
} else {
return Ok(false);
};
// For the content.body we match against "words", but for everything
// else we match against the entire value.
let match_type = if event_match.key == "content.body" {
GlobMatchType::Word
} else {
GlobMatchType::Whole
};
let mut compiled_pattern = get_glob_matcher(pattern, match_type)?;
compiled_pattern.is_match(haystack)
}
/// Match the member count against an 'is' condition
/// The `is` condition can be things like '>2', '==3' or even just '4'.
fn match_member_count(&self, is: &str) -> Result<bool, Error> {
let captures = INEQUALITY_EXPR.captures(is).context("bad 'is' clause")?;
let ineq = captures.get(1).map_or("==", |m| m.as_str());
let rhs: u64 = captures
.get(2)
.context("missing number")?
.as_str()
.parse()?;
let matches = match ineq {
"" | "==" => self.room_member_count == rhs,
"<" => self.room_member_count < rhs,
">" => self.room_member_count > rhs,
">=" => self.room_member_count >= rhs,
"<=" => self.room_member_count <= rhs,
_ => false,
};
Ok(matches)
}
}
#[test]
fn push_rule_evaluator() {
let mut flattened_keys = BTreeMap::new();
flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
10,
Some(0),
BTreeMap::new(),
BTreeMap::new(),
true,
)
.unwrap();
let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
assert_eq!(result.len(), 3);
}

View File

@ -42,7 +42,6 @@
//!
//! The set of "base rules" are the list of rules that every user has by default. A
//! user can modify their copy of the push rules in one of three ways:
//!
//! 1. Adding a new push rule of a certain kind
//! 2. Changing the actions of a base rule
//! 3. Enabling/disabling a base rule.
@ -58,12 +57,16 @@ use std::collections::{BTreeMap, HashMap, HashSet};
use anyhow::{Context, Error};
use log::warn;
use pyo3::prelude::*;
use pythonize::pythonize;
use pythonize::{depythonize, pythonize};
use serde::de::Error as _;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use self::evaluator::PushRuleEvaluator;
mod base_rules;
pub mod evaluator;
pub mod utils;
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
@ -71,6 +74,7 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
child_module.add_class::<PushRule>()?;
child_module.add_class::<PushRules>()?;
child_module.add_class::<FilteredPushRules>()?;
child_module.add_class::<PushRuleEvaluator>()?;
child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?;
m.add_submodule(child_module)?;
@ -274,6 +278,8 @@ pub enum KnownCondition {
#[serde(rename = "org.matrix.msc3772.relation_match")]
RelationMatch {
rel_type: Cow<'static, str>,
#[serde(skip_serializing_if = "Option::is_none", rename = "type")]
event_type_pattern: Option<Cow<'static, str>>,
#[serde(skip_serializing_if = "Option::is_none")]
sender: Option<Cow<'static, str>>,
#[serde(skip_serializing_if = "Option::is_none")]
@ -287,20 +293,26 @@ impl IntoPy<PyObject> for Condition {
}
}
impl<'source> FromPyObject<'source> for Condition {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
Ok(depythonize(ob)?)
}
}
/// The body of a [`Condition::EventMatch`]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EventMatchCondition {
key: Cow<'static, str>,
pub key: Cow<'static, str>,
#[serde(skip_serializing_if = "Option::is_none")]
pattern: Option<Cow<'static, str>>,
pub pattern: Option<Cow<'static, str>>,
#[serde(skip_serializing_if = "Option::is_none")]
pattern_type: Option<Cow<'static, str>>,
pub pattern_type: Option<Cow<'static, str>>,
}
/// The collection of push rules for a user.
#[derive(Debug, Clone, Default)]
#[pyclass(frozen)]
struct PushRules {
pub struct PushRules {
/// Custom push rules that override a base rule.
overridden_base_rules: HashMap<Cow<'static, str>, PushRule>,
@ -319,7 +331,7 @@ struct PushRules {
#[pymethods]
impl PushRules {
#[new]
fn new(rules: Vec<PushRule>) -> PushRules {
pub fn new(rules: Vec<PushRule>) -> PushRules {
let mut push_rules: PushRules = Default::default();
for rule in rules {
@ -396,7 +408,7 @@ pub struct FilteredPushRules {
#[pymethods]
impl FilteredPushRules {
#[new]
fn py_new(
pub fn py_new(
push_rules: PushRules,
enabled_map: BTreeMap<String, bool>,
msc3786_enabled: bool,

215
rust/src/push/utils.rs Normal file
View File

@ -0,0 +1,215 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::bail;
use anyhow::Context;
use anyhow::Error;
use lazy_static::lazy_static;
use regex;
use regex::Regex;
use regex::RegexBuilder;
lazy_static! {
/// Matches runs of non-wildcard characters followed by wildcard characters.
static ref WILDCARD_RUN: Regex = Regex::new(r"([^\?\*]*)([\?\*]*)").expect("valid regex");
}
/// Extract the localpart from a Matrix style ID
pub(crate) fn get_localpart_from_id(id: &str) -> Result<&str, Error> {
let (localpart, _) = id
.split_once(':')
.with_context(|| format!("ID does not contain colon: {id}"))?;
// We need to strip off the first character, which is the ID type.
if localpart.is_empty() {
bail!("Invalid ID {id}");
}
Ok(&localpart[1..])
}
/// Used by `glob_to_regex` to specify what to match the regex against.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GlobMatchType {
/// The generated regex will match against the entire input.
Whole,
/// The generated regex will match against words.
Word,
}
/// Convert a "glob" style expression to a regex, anchoring either to the entire
/// input or to individual words.
pub fn glob_to_regex(glob: &str, match_type: GlobMatchType) -> Result<Regex, Error> {
let mut chunks = Vec::new();
// Patterns with wildcards must be simplified to avoid performance cliffs
// - The glob `?**?**?` is equivalent to the glob `???*`
// - The glob `???*` is equivalent to the regex `.{3,}`
for captures in WILDCARD_RUN.captures_iter(glob) {
if let Some(chunk) = captures.get(1) {
chunks.push(regex::escape(chunk.as_str()));
}
if let Some(wildcards) = captures.get(2) {
if wildcards.as_str() == "" {
continue;
}
let question_marks = wildcards.as_str().chars().filter(|c| *c == '?').count();
if wildcards.as_str().contains('*') {
chunks.push(format!(".{{{question_marks},}}"));
} else {
chunks.push(format!(".{{{question_marks}}}"));
}
}
}
let joined = chunks.join("");
let regex_str = match match_type {
GlobMatchType::Whole => format!(r"\A{joined}\z"),
// `^|\W` and `\W|$` handle the case where `pattern` starts or ends with a non-word
// character.
GlobMatchType::Word => format!(r"(?:^|\b|\W){joined}(?:\b|\W|$)"),
};
Ok(RegexBuilder::new(&regex_str)
.case_insensitive(true)
.build()?)
}
/// Compiles the glob into a `Matcher`.
pub fn get_glob_matcher(glob: &str, match_type: GlobMatchType) -> Result<Matcher, Error> {
// There are a number of shortcuts we can make if the glob doesn't contain a
// wild card.
let matcher = if glob.contains(['*', '?']) {
let regex = glob_to_regex(glob, match_type)?;
Matcher::Regex(regex)
} else if match_type == GlobMatchType::Whole {
// If there aren't any wildcards and we're matching the whole thing,
// then we simply can do a case-insensitive string match.
Matcher::Whole(glob.to_lowercase())
} else {
// Otherwise, if we're matching against words then can first check
// if the haystack contains the glob at all.
Matcher::Word {
word: glob.to_lowercase(),
regex: None,
}
};
Ok(matcher)
}
/// Matches against a glob
pub enum Matcher {
/// Plain regex matching.
Regex(Regex),
/// Case-insensitive equality.
Whole(String),
/// Word matching. `regex` is a cache of calling [`glob_to_regex`] on word.
Word { word: String, regex: Option<Regex> },
}
impl Matcher {
/// Checks if the glob matches the given haystack.
pub fn is_match(&mut self, haystack: &str) -> Result<bool, Error> {
// We want to to do case-insensitive matching, so we convert to
// lowercase first.
let haystack = haystack.to_lowercase();
match self {
Matcher::Regex(regex) => Ok(regex.is_match(&haystack)),
Matcher::Whole(whole) => Ok(whole == &haystack),
Matcher::Word { word, regex } => {
// If we're looking for a literal word, then we first check if
// the haystack contains the word as a substring.
if !haystack.contains(&*word) {
return Ok(false);
}
// If it does contain the word as a substring, then we need to
// check if it is an actual word by testing it against the regex.
let regex = if let Some(regex) = regex {
regex
} else {
let compiled_regex = glob_to_regex(word, GlobMatchType::Word)?;
regex.insert(compiled_regex)
};
Ok(regex.is_match(&haystack))
}
}
}
}
#[test]
fn test_get_domain_from_id() {
get_localpart_from_id("").unwrap_err();
get_localpart_from_id(":").unwrap_err();
get_localpart_from_id(":asd").unwrap_err();
get_localpart_from_id("::as::asad").unwrap_err();
assert_eq!(get_localpart_from_id("@test:foo").unwrap(), "test");
assert_eq!(get_localpart_from_id("@:").unwrap(), "");
assert_eq!(get_localpart_from_id("@test:foo:907").unwrap(), "test");
}
#[test]
fn tset_glob() -> Result<(), Error> {
assert_eq!(
glob_to_regex("simple", GlobMatchType::Whole)?.as_str(),
r"\Asimple\z"
);
assert_eq!(
glob_to_regex("simple*", GlobMatchType::Whole)?.as_str(),
r"\Asimple.{0,}\z"
);
assert_eq!(
glob_to_regex("simple?", GlobMatchType::Whole)?.as_str(),
r"\Asimple.{1}\z"
);
assert_eq!(
glob_to_regex("simple?*?*", GlobMatchType::Whole)?.as_str(),
r"\Asimple.{2,}\z"
);
assert_eq!(
glob_to_regex("simple???", GlobMatchType::Whole)?.as_str(),
r"\Asimple.{3}\z"
);
assert_eq!(
glob_to_regex("escape.", GlobMatchType::Whole)?.as_str(),
r"\Aescape\.\z"
);
assert!(glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simple"));
assert!(!glob_to_regex("simple", GlobMatchType::Whole)?.is_match("simples"));
assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simples"));
assert!(glob_to_regex("simple?", GlobMatchType::Whole)?.is_match("simples"));
assert!(glob_to_regex("simple*", GlobMatchType::Whole)?.is_match("simple"));
assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("some simple."));
assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("simple"));
assert!(!glob_to_regex("simple", GlobMatchType::Word)?.is_match("simples"));
assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("Some @user:foo test"));
assert!(glob_to_regex("@user:foo", GlobMatchType::Word)?.is_match("@user:foo"));
Ok(())
}

View File

@ -1,4 +1,4 @@
from typing import Any, Collection, Dict, Mapping, Sequence, Tuple, Union
from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union
from synapse.types import JsonDict
@ -35,3 +35,20 @@ class FilteredPushRules:
def rules(self) -> Collection[Tuple[PushRule, bool]]: ...
def get_base_rule_ids() -> Collection[str]: ...
class PushRuleEvaluator:
def __init__(
self,
flattened_keys: Mapping[str, str],
room_member_count: int,
sender_power_level: Optional[int],
notification_power_levels: Mapping[str, int],
relations: Mapping[str, Set[Tuple[str, str]]],
relation_match_enabled: bool,
): ...
def run(
self,
push_rules: FilteredPushRules,
user_id: Optional[str],
display_name: Optional[str],
) -> Collection[dict]: ...

View File

@ -107,7 +107,7 @@ BOOLEAN_COLUMNS = {
"redactions": ["have_censored"],
"room_stats_state": ["is_federatable"],
"local_media_repository": ["safe_from_quarantine"],
"users": ["shadow_banned"],
"users": ["shadow_banned", "approved"],
"e2e_fallback_keys_json": ["used"],
"access_tokens": ["used"],
"device_lists_changes_in_room": ["converted_to_destinations"],

View File

@ -269,3 +269,14 @@ class PublicRoomsFilterFields:
GENERIC_SEARCH_TERM: Final = "generic_search_term"
ROOM_TYPES: Final = "room_types"
class ApprovalNoticeMedium:
"""Identifier for the medium this server will use to serve notice of approval for a
specific user's registration.
As defined in https://github.com/matrix-org/matrix-spec-proposals/blob/babolivier/m_not_approved/proposals/3866-user-not-approved-error.md
"""
NONE = "org.matrix.msc3866.none"
EMAIL = "org.matrix.msc3866.email"

View File

@ -106,6 +106,8 @@ class Codes(str, Enum):
# Part of MSC3895.
UNABLE_DUE_TO_PARTIAL_STATE = "ORG.MATRIX.MSC3895_UNABLE_DUE_TO_PARTIAL_STATE"
USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL"
class CodeMessageException(RuntimeError):
"""An exception with integer code and message string attributes.
@ -566,6 +568,20 @@ class UnredactedContentDeletedError(SynapseError):
return cs_error(self.msg, self.errcode, **extra)
class NotApprovedError(SynapseError):
def __init__(
self,
msg: str,
approval_notice_medium: str,
):
super().__init__(
code=403,
msg=msg,
errcode=Codes.USER_AWAITING_APPROVAL,
additional_fields={"approval_notice_medium": approval_notice_medium},
)
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict":
"""Utility method for constructing an error response for client-server
interactions.

View File

@ -14,10 +14,25 @@
from typing import Any
import attr
from synapse.config._base import Config
from synapse.types import JsonDict
@attr.s(auto_attribs=True, frozen=True, slots=True)
class MSC3866Config:
"""Configuration for MSC3866 (mandating approval for new users)"""
# Whether the base support for the approval process is enabled. This includes the
# ability for administrators to check and update the approval of users, even if no
# approval is currently required.
enabled: bool = False
# Whether to require that new users are approved by an admin before their account
# can be used. Note that this setting is ignored if 'enabled' is false.
require_approval_for_new_accounts: bool = False
class ExperimentalConfig(Config):
"""Config section for enabling experimental features"""
@ -97,6 +112,10 @@ class ExperimentalConfig(Config):
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
# MSC3866: M_USER_AWAITING_APPROVAL error code
raw_msc3866_config = experimental.get("msc3866", {})
self.msc3866 = MSC3866Config(**raw_msc3866_config)
# MSC3881: Remotely toggle push notifications for another client
self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)

View File

@ -289,6 +289,10 @@ class _EventInternalMetadata:
"""
return self._dict.get("historical", False)
def is_notifiable(self) -> bool:
"""Whether this event can trigger a push notification"""
return not self.is_outlier() or self.is_out_of_band_membership()
class EventBase(metaclass=abc.ABCMeta):
@property

View File

@ -32,6 +32,7 @@ class AdminHandler:
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
async def get_whois(self, user: UserID) -> JsonDict:
connections = []
@ -75,6 +76,10 @@ class AdminHandler:
"is_guest",
}
if self._msc3866_enabled:
# Only include the approved flag if support for MSC3866 is enabled.
user_info_to_return.add("approved")
# Restrict returned keys to a known set.
user_info_dict = {
key: value

View File

@ -1009,6 +1009,17 @@ class AuthHandler:
return res[0]
return None
async def is_user_approved(self, user_id: str) -> bool:
"""Checks if a user is approved and therefore can be allowed to log in.
Args:
user_id: the user to check the approval status of.
Returns:
A boolean that is True if the user is approved, False otherwise.
"""
return await self.store.is_user_approved(user_id)
async def _find_user_id_and_pwd_hash(
self, user_id: str
) -> Optional[Tuple[str, str]]:

View File

@ -273,11 +273,9 @@ class DeviceWorkerHandler:
possibly_left = possibly_changed | possibly_left
# Double check if we still share rooms with the given user.
users_rooms = await self.store.get_rooms_for_users_with_stream_ordering(
possibly_left
)
users_rooms = await self.store.get_rooms_for_users(possibly_left)
for changed_user_id, entries in users_rooms.items():
if any(e.room_id in room_ids for e in entries):
if any(rid in room_ids for rid in entries):
possibly_left.discard(changed_user_id)
else:
possibly_joined.discard(changed_user_id)
@ -309,6 +307,17 @@ class DeviceWorkerHandler:
"self_signing_key": self_signing_key,
}
async def handle_room_un_partial_stated(self, room_id: str) -> None:
"""Handles sending appropriate device list updates in a room that has
gone from partial to full state.
"""
# TODO(faster_joins): worker mode support
# https://github.com/matrix-org/synapse/issues/12994
logger.error(
"Trying handling device list state for partial join: not supported on workers."
)
class DeviceHandler(DeviceWorkerHandler):
def __init__(self, hs: "HomeServer"):
@ -746,6 +755,95 @@ class DeviceHandler(DeviceWorkerHandler):
finally:
self._handle_new_device_update_is_processing = False
async def handle_room_un_partial_stated(self, room_id: str) -> None:
"""Handles sending appropriate device list updates in a room that has
gone from partial to full state.
"""
# We defer to the device list updater to handle pending remote device
# list updates.
await self.device_list_updater.handle_room_un_partial_stated(room_id)
# Replay local updates.
(
join_event_id,
device_lists_stream_id,
) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
room_id
)
# Get the local device list changes that have happened in the room since
# we started joining. If there are no updates there's nothing left to do.
changes = await self.store.get_device_list_changes_in_room(
room_id, device_lists_stream_id
)
local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
if not local_changes:
return
# Note: We have persisted the full state at this point, we just haven't
# cleared the `partial_room` flag.
join_state_ids = await self._state_storage.get_state_ids_for_event(
join_event_id, await_full_state=False
)
current_state_ids = await self.store.get_partial_current_state_ids(room_id)
# Now we need to work out all servers that might have been in the room
# at any point during our join.
# First we look for any membership states that have changed between the
# initial join and now...
all_keys = set(join_state_ids)
all_keys.update(current_state_ids)
potentially_changed_hosts = set()
for etype, state_key in all_keys:
if etype != EventTypes.Member:
continue
prev = join_state_ids.get((etype, state_key))
current = current_state_ids.get((etype, state_key))
if prev != current:
potentially_changed_hosts.add(get_domain_from_id(state_key))
# ... then we add all the hosts that are currently joined to the room...
current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
potentially_changed_hosts.update(current_hosts_in_room)
# ... and finally we remove any hosts that we were told about, as we
# will have sent device list updates to those hosts when they happened.
known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
room_id
)
potentially_changed_hosts.difference_update(known_hosts_at_join)
potentially_changed_hosts.discard(self.server_name)
if not potentially_changed_hosts:
# Nothing to do.
return
logger.info(
"Found %d changed hosts to send device list updates to",
len(potentially_changed_hosts),
)
for user_id, device_id in local_changes:
await self.store.add_device_list_outbound_pokes(
user_id=user_id,
device_id=device_id,
room_id=room_id,
stream_id=None,
hosts=potentially_changed_hosts,
context=None,
)
# Notify things that device lists need to be sent out.
self.notifier.notify_replication()
for host in potentially_changed_hosts:
self.federation_sender.send_device_messages(host, immediate=False)
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
@ -836,6 +934,16 @@ class DeviceListUpdater:
)
return
# Check if we are partially joining any rooms. If so we need to store
# all device list updates so that we can handle them correctly once we
# know who is in the room.
partial_rooms = await self.store.get_partial_state_rooms_and_servers()
if partial_rooms:
await self.store.add_remote_device_list_to_pending(
user_id,
device_id,
)
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
@ -1175,3 +1283,35 @@ class DeviceListUpdater:
device_ids.append(verify_key.version)
return device_ids
async def handle_room_un_partial_stated(self, room_id: str) -> None:
"""Handles sending appropriate device list updates in a room that has
gone from partial to full state.
"""
pending_updates = (
await self.store.get_pending_remote_device_list_updates_for_room(room_id)
)
for user_id, device_id in pending_updates:
logger.info(
"Got pending device list update in room %s: %s / %s",
room_id,
user_id,
device_id,
)
position = await self.store.add_device_change_to_streams(
user_id,
[device_id],
room_ids=[room_id],
)
if not position:
# This should only happen if there are no updates, which
# shouldn't happen when we've passed in a non-empty set of
# device IDs.
continue
self.device_handler.notifier.on_new_event(
StreamKeyType.DEVICE_LIST, position, rooms=[room_id]
)

View File

@ -38,7 +38,7 @@ from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
from synapse import event_auth
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
from synapse.api.errors import (
AuthError,
CodeMessageException,
@ -149,6 +149,8 @@ class FederationHandler:
self.http_client = hs.get_proxied_blacklisted_http_client()
self._replication = hs.get_replication_data_handler()
self._federation_event_handler = hs.get_federation_event_handler()
self._device_handler = hs.get_device_handler()
self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
hs
@ -209,7 +211,7 @@ class FederationHandler:
current_depth: int,
limit: int,
*,
processing_start_time: int,
processing_start_time: Optional[int],
) -> bool:
"""
Checks whether the `current_depth` is at or approaching any backfill
@ -221,12 +223,23 @@ class FederationHandler:
room_id: The room to backfill in.
current_depth: The depth to check at for any upcoming backfill points.
limit: The max number of events to request from the remote federated server.
processing_start_time: The time when `maybe_backfill` started
processing. Only used for timing.
processing_start_time: The time when `maybe_backfill` started processing.
Only used for timing. If `None`, no timing observation will be made.
"""
backwards_extremities = [
_BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
for event_id, depth in await self.store.get_backfill_points_in_room(room_id)
for event_id, depth in await self.store.get_backfill_points_in_room(
room_id=room_id,
current_depth=current_depth,
# We only need to end up with 5 extremities combined with the
# insertion event extremities to make the `/backfill` request
# but fetch an order of magnitude more to make sure there is
# enough even after we filter them by whether visible in the
# history. This isn't fool-proof as all backfill points within
# our limit could be filtered out but seems like a good amount
# to try with at least.
limit=50,
)
]
insertion_events_to_be_backfilled: List[_BackfillPoint] = []
@ -234,7 +247,12 @@ class FederationHandler:
insertion_events_to_be_backfilled = [
_BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT)
for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room(
room_id
room_id=room_id,
current_depth=current_depth,
# We only need to end up with 5 extremities combined with
# the backfill points to make the `/backfill` request ...
# (see the other comment above for more context).
limit=50,
)
]
logger.debug(
@ -243,10 +261,6 @@ class FederationHandler:
insertion_events_to_be_backfilled,
)
if not backwards_extremities and not insertion_events_to_be_backfilled:
logger.debug("Not backfilling as no extremeties found.")
return False
# we now have a list of potential places to backpaginate from. We prefer to
# start with the most recent (ie, max depth), so let's sort the list.
sorted_backfill_points: List[_BackfillPoint] = sorted(
@ -267,6 +281,33 @@ class FederationHandler:
sorted_backfill_points,
)
# If we have no backfill points lower than the `current_depth` then
# either we can a) bail or b) still attempt to backfill. We opt to try
# backfilling anyway just in case we do get relevant events.
if not sorted_backfill_points and current_depth != MAX_DEPTH:
logger.debug(
"_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points."
)
return await self._maybe_backfill_inner(
room_id=room_id,
# We use `MAX_DEPTH` so that we find all backfill points next
# time (all events are below the `MAX_DEPTH`)
current_depth=MAX_DEPTH,
limit=limit,
# We don't want to start another timing observation from this
# nested recursive call. The top-most call can record the time
# overall otherwise the smaller one will throw off the results.
processing_start_time=None,
)
# Even after recursing with `MAX_DEPTH`, we didn't find any
# backward extremities to backfill from.
if not sorted_backfill_points:
logger.debug(
"_maybe_backfill_inner: Not backfilling as no backward extremeties found."
)
return False
# If we're approaching an extremity we trigger a backfill, otherwise we
# no-op.
#
@ -276,47 +317,16 @@ class FederationHandler:
# chose more than one times the limit in case of failure, but choosing a
# much larger factor will result in triggering a backfill request much
# earlier than necessary.
#
# XXX: shouldn't we do this *after* the filter by depth below? Again, we don't
# care about events that have happened after our current position.
#
max_depth = sorted_backfill_points[0].depth
if current_depth - 2 * limit > max_depth:
max_depth_of_backfill_points = sorted_backfill_points[0].depth
if current_depth - 2 * limit > max_depth_of_backfill_points:
logger.debug(
"Not backfilling as we don't need to. %d < %d - 2 * %d",
max_depth,
max_depth_of_backfill_points,
current_depth,
limit,
)
return False
# We ignore extremities that have a greater depth than our current depth
# as:
# 1. we don't really care about getting events that have happened
# after our current position; and
# 2. we have likely previously tried and failed to backfill from that
# extremity, so to avoid getting "stuck" requesting the same
# backfill repeatedly we drop those extremities.
#
# However, we need to check that the filtered extremities are non-empty.
# If they are empty then either we can a) bail or b) still attempt to
# backfill. We opt to try backfilling anyway just in case we do get
# relevant events.
#
filtered_sorted_backfill_points = [
t for t in sorted_backfill_points if t.depth <= current_depth
]
if filtered_sorted_backfill_points:
logger.debug(
"_maybe_backfill_inner: backfill points before current depth: %s",
filtered_sorted_backfill_points,
)
sorted_backfill_points = filtered_sorted_backfill_points
else:
logger.debug(
"_maybe_backfill_inner: all backfill points are *after* current depth. Backfilling anyway."
)
# For performance's sake, we only want to paginate from a particular extremity
# if we can actually see the events we'll get. Otherwise, we'd just spend a lot
# of resources to get redacted events. We check each extremity in turn and
@ -402,11 +412,22 @@ class FederationHandler:
# First we try hosts that are already in the room.
# TODO: HEURISTIC ALERT.
likely_domains = (
await self._storage_controllers.state.get_current_hosts_in_room(room_id)
await self._storage_controllers.state.get_current_hosts_in_room_ordered(
room_id
)
)
async def try_backfill(domains: Collection[str]) -> bool:
# TODO: Should we try multiple of these at a time?
# Number of contacted remote homeservers that have denied our backfill
# request with a 4xx code.
denied_count = 0
# Maximum number of contacted remote homeservers that can deny our
# backfill request with 4xx codes before we give up.
max_denied_count = 5
for dom in domains:
# We don't want to ask our own server for information we don't have
if dom == self.server_name:
@ -425,13 +446,33 @@ class FederationHandler:
continue
except HttpResponseException as e:
if 400 <= e.code < 500:
raise e.to_synapse_error()
logger.warning(
"Backfill denied from %s because %s [%d/%d]",
dom,
e,
denied_count,
max_denied_count,
)
denied_count += 1
if denied_count >= max_denied_count:
return False
continue
logger.info("Failed to backfill from %s because %s", dom, e)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
logger.warning(
"Backfill denied from %s because %s [%d/%d]",
dom,
e,
denied_count,
max_denied_count,
)
denied_count += 1
if denied_count >= max_denied_count:
return False
continue
logger.info("Failed to backfill from %s because %s", dom, e)
continue
@ -450,10 +491,15 @@ class FederationHandler:
return False
processing_end_time = self.clock.time_msec()
backfill_processing_before_timer.observe(
(processing_end_time - processing_start_time) / 1000
)
# If we have the `processing_start_time`, then we can make an
# observation. We wouldn't have the `processing_start_time` in the case
# where `_maybe_backfill_inner` is recursively called to find any
# backfill points regardless of `current_depth`.
if processing_start_time is not None:
processing_end_time = self.clock.time_msec()
backfill_processing_before_timer.observe(
(processing_end_time - processing_start_time) / 1000
)
success = await try_backfill(likely_domains)
if success:
@ -956,9 +1002,15 @@ class FederationHandler:
)
context = EventContext.for_outlier(self._storage_controllers)
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context)
try:
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
except Exception:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
return event
@ -1624,6 +1676,9 @@ class FederationHandler:
# https://github.com/matrix-org/synapse/issues/12994
await self.state_handler.update_current_state(room_id)
logger.info("Handling any pending device list updates")
await self._device_handler.handle_room_un_partial_stated(room_id)
logger.info("Clearing partial-state flag for %s", room_id)
success = await self.store.clear_partial_state_room(room_id)
if success:

View File

@ -2170,6 +2170,7 @@ class FederationEventHandler:
if instance != self._instance_name:
# Limit the number of events sent over replication. We choose 200
# here as that is what we default to in `max_request_body_size(..)`
result = {}
try:
for batch in batch_iter(event_and_contexts, 200):
result = await self._send_events(

View File

@ -220,6 +220,7 @@ class RegistrationHandler:
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
auth_provider_id: Optional[str] = None,
approved: bool = False,
) -> str:
"""Registers a new client on the server.
@ -246,6 +247,8 @@ class RegistrationHandler:
user_agent_ips: Tuples of user-agents and IP addresses used
during the registration process.
auth_provider_id: The SSO IdP the user used, if any.
approved: True if the new user should be considered already
approved by an administrator.
Returns:
The registered user_id.
Raises:
@ -307,6 +310,7 @@ class RegistrationHandler:
user_type=user_type,
address=address,
shadow_banned=shadow_banned,
approved=approved,
)
profile = await self.store.get_profileinfo(localpart)
@ -695,6 +699,7 @@ class RegistrationHandler:
user_type: Optional[str] = None,
address: Optional[str] = None,
shadow_banned: bool = False,
approved: bool = False,
) -> None:
"""Register user in the datastore.
@ -713,6 +718,7 @@ class RegistrationHandler:
api.constants.UserTypes, or None for a normal user.
address: the IP address used to perform the registration.
shadow_banned: Whether to shadow-ban the user
approved: Whether to mark the user as approved by an administrator
"""
if self.hs.config.worker.worker_app:
await self._register_client(
@ -726,6 +732,7 @@ class RegistrationHandler:
user_type=user_type,
address=address,
shadow_banned=shadow_banned,
approved=approved,
)
else:
await self.store.register_user(
@ -738,6 +745,7 @@ class RegistrationHandler:
admin=admin,
user_type=user_type,
shadow_banned=shadow_banned,
approved=approved,
)
# Only call the account validity module(s) on the main process, to avoid

View File

@ -1540,7 +1540,9 @@ class TimestampLookupHandler:
)
likely_domains = (
await self._storage_controllers.state.get_current_hosts_in_room(room_id)
await self._storage_controllers.state.get_current_hosts_in_room_ordered(
room_id
)
)
# Loop through each homeserver candidate until we get a succesful response

View File

@ -187,6 +187,19 @@ class SendEmailHandler:
multipart_msg["To"] = email_address
multipart_msg["Date"] = email.utils.formatdate()
multipart_msg["Message-ID"] = email.utils.make_msgid()
# Discourage automatic responses to Synapse's emails.
# Per RFC 3834, automatic responses should not be sent if the "Auto-Submitted"
# header is present with any value other than "no". See
# https://www.rfc-editor.org/rfc/rfc3834.html#section-5.1
multipart_msg["Auto-Submitted"] = "auto-generated"
# Also include a Microsoft-Exchange specific header:
# https://learn.microsoft.com/en-us/openspecs/exchange_server_protocols/ms-oxcmail/ced68690-498a-4567-9d14-5c01f974d8b1
# which suggests it can take the value "All" to "suppress all auto-replies",
# or a comma separated list of auto-reply classes to suppress.
# The following stack overflow question has a little more context:
# https://stackoverflow.com/a/25324691/5252017
# https://stackoverflow.com/a/61646381/5252017
multipart_msg["X-Auto-Response-Suppress"] = "All"
multipart_msg.attach(text_part)
multipart_msg.attach(html_part)

View File

@ -1490,16 +1490,14 @@ class SyncHandler:
since_token.device_list_key
)
if changed_users is not None:
result = await self.store.get_rooms_for_users_with_stream_ordering(
changed_users
)
result = await self.store.get_rooms_for_users(changed_users)
for changed_user_id, entries in result.items():
# Check if the changed user shares any rooms with the user,
# or if the changed user is the syncing user (as we always
# want to include device list updates of their own devices).
if user_id == changed_user_id or any(
e.room_id in joined_rooms for e in entries
rid in joined_rooms for rid in entries
):
users_that_have_changed.add(changed_user_id)
else:
@ -1533,13 +1531,9 @@ class SyncHandler:
newly_left_users.update(left_users)
# Remove any users that we still share a room with.
left_users_rooms = (
await self.store.get_rooms_for_users_with_stream_ordering(
newly_left_users
)
)
left_users_rooms = await self.store.get_rooms_for_users(newly_left_users)
for user_id, entries in left_users_rooms.items():
if any(e.room_id in joined_rooms for e in entries):
if any(rid in joined_rooms for rid in entries):
newly_left_users.discard(user_id)
return DeviceListUpdates(

View File

@ -842,6 +842,8 @@ class ModuleApi:
however invalidation that needs to go to other workers needs to call `invalidate_cache`
on the module API instead.
Added in Synapse v1.69.0.
Args:
cached_function: The cached function that will be registered to receive invalidation
locally and from other workers.
@ -856,6 +858,8 @@ class ModuleApi:
"""Invalidate a cache entry of a cached function across workers. The cached function
needs to be registered on all workers first with `register_cached_function`.
Added in Synapse v1.69.0.
Args:
cached_function: The cached function that needs an invalidation
keys: keys of the entry to invalidate, usually matching the arguments of the

View File

@ -17,6 +17,7 @@ import itertools
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
@ -37,13 +38,11 @@ from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.storage.state import StateFilter
from synapse.synapse_rust.push import FilteredPushRules, PushRule
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_event_for_clients_with_state
from .push_rule_evaluator import PushRuleEvaluatorForEvent
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -173,7 +172,11 @@ class BulkPushRuleEvaluator:
async def _get_power_levels_and_sender_level(
self, event: EventBase, context: EventContext
) -> Tuple[dict, int]:
) -> Tuple[dict, Optional[int]]:
# There are no power levels and sender levels possible to get from outlier
if event.internal_metadata.is_outlier():
return {}, None
event_types = auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
@ -250,8 +253,8 @@ class BulkPushRuleEvaluator:
should increment the unread count, and insert the results into the
event_push_actions_staging table.
"""
if event.internal_metadata.is_outlier():
# This can happen due to out of band memberships
if not event.internal_metadata.is_notifiable():
# Push rules for events that aren't notifiable can't be processed by this
return
# Disable counting as unread unless the experimental configuration is
@ -286,11 +289,11 @@ class BulkPushRuleEvaluator:
if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id
evaluator = PushRuleEvaluatorForEvent(
event,
evaluator = PushRuleEvaluator(
_flatten_dict(event),
room_member_count,
sender_power_level,
power_levels,
power_levels.get("notifications", {}),
relations,
self._relations_match_enabled,
)
@ -300,20 +303,10 @@ class BulkPushRuleEvaluator:
event.room_id, users
)
# This is a check for the case where user joins a room without being
# allowed to see history, and then the server receives a delayed event
# from before the user joined, which they should not be pushed for
uids_with_visibility = await filter_event_for_clients_with_state(
self.store, users, event, context
)
for uid, rules in rules_by_user.items():
if event.sender == uid:
continue
if uid not in uids_with_visibility:
continue
display_name = None
profile = profiles.get(uid)
if profile:
@ -334,17 +327,25 @@ class BulkPushRuleEvaluator:
# current user, it'll be added to the dict later.
actions_by_user[uid] = []
for rule, enabled in rules.rules():
if not enabled:
continue
actions = evaluator.run(rules, uid, display_name)
if "notify" in actions:
# Push rules say we should notify the user of this event
actions_by_user[uid] = actions
matches = evaluator.check_conditions(rule.conditions, uid, display_name)
if matches:
actions = [x for x in rule.actions if x != "dont_notify"]
if actions and "notify" in actions:
# Push rules say we should notify the user of this event
actions_by_user[uid] = actions
break
# This is a check for the case where user joins a room without being
# allowed to see history, and then the server receives a delayed event
# from before the user joined, which they should not be pushed for
#
# We do this *after* calculating the push actions as a) its unlikely
# that we'll filter anyone out and b) for large rooms its likely that
# most users will have push disabled and so the set of users to check is
# much smaller.
uids_with_visibility = await filter_event_for_clients_with_state(
self.store, actions_by_user.keys(), event, context
)
for user_id in set(actions_by_user).difference(uids_with_visibility):
actions_by_user.pop(user_id, None)
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
@ -361,3 +362,21 @@ MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
def _flatten_dict(
d: Union[EventBase, Mapping[str, Any]],
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
if prefix is None:
prefix = []
if result is None:
result = {}
for key, value in d.items():
if isinstance(value, str):
result[".".join(prefix + [key])] = value.lower()
elif isinstance(value, Mapping):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
return result

View File

@ -102,10 +102,8 @@ def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]:
# with PRIORITY_CLASS_INVERSE_MAP.
raise ValueError("Unexpected template_name: %s" % (template_name,))
if unscoped_rule_id:
templaterule["rule_id"] = unscoped_rule_id
if rule.default:
templaterule["default"] = True
templaterule["rule_id"] = unscoped_rule_id
templaterule["default"] = rule.default
return templaterule

View File

@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
from prometheus_client import Counter
@ -28,7 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.storage.databases.main.event_push_actions import HttpPushAction
from . import push_rule_evaluator, push_tools
from . import push_tools
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -56,6 +56,39 @@ http_badges_failed_counter = Counter(
)
def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
"""
Converts a list of actions into a `tweaks` dict (which can then be passed to
the push gateway).
This function ignores all actions other than `set_tweak` actions, and treats
absent `value`s as `True`, which agrees with the only spec-defined treatment
of absent `value`s (namely, for `highlight` tweaks).
Args:
actions: list of actions
e.g. [
{"set_tweak": "a", "value": "AAA"},
{"set_tweak": "b", "value": "BBB"},
{"set_tweak": "highlight"},
"notify"
]
Returns:
dictionary of tweaks for those actions
e.g. {"a": "AAA", "b": "BBB", "highlight": True}
"""
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if "set_tweak" in a:
# value is allowed to be absent in which case the value assumed
# should be True.
tweaks[a["set_tweak"]] = a.get("value", True)
return tweaks
class HttpPusher(Pusher):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
MAX_BACKOFF_SEC = 60 * 60
@ -286,7 +319,7 @@ class HttpPusher(Pusher):
if "notify" not in push_action.actions:
return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions)
tweaks = tweaks_for_actions(push_action.actions)
badge = await push_tools.get_badge_count(
self.hs.get_datastores().main,
self.user_id,

View File

@ -1,361 +0,0 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import (
Any,
Dict,
List,
Mapping,
Optional,
Pattern,
Sequence,
Set,
Tuple,
Union,
)
from matrix_common.regex import glob_to_regex, to_word_pattern
from synapse.events import EventBase
from synapse.types import UserID
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]")
IS_GLOB = re.compile(r"[\?\*\[\]]")
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def _room_member_count(
ev: EventBase, condition: Mapping[str, Any], room_member_count: int
) -> bool:
return _test_ineq_condition(condition, room_member_count)
def _sender_notification_permission(
ev: EventBase,
condition: Mapping[str, Any],
sender_power_level: int,
power_levels: Dict[str, Union[int, Dict[str, int]]],
) -> bool:
notif_level_key = condition.get("key")
if notif_level_key is None:
return False
notif_levels = power_levels.get("notifications", {})
assert isinstance(notif_levels, dict)
room_notif_level = notif_levels.get(notif_level_key, 50)
return sender_power_level >= room_notif_level
def _test_ineq_condition(condition: Mapping[str, Any], number: int) -> bool:
if "is" not in condition:
return False
m = INEQUALITY_EXPR.match(condition["is"])
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs_int = int(rhs)
if ineq == "" or ineq == "==":
return number == rhs_int
elif ineq == "<":
return number < rhs_int
elif ineq == ">":
return number > rhs_int
elif ineq == ">=":
return number >= rhs_int
elif ineq == "<=":
return number <= rhs_int
else:
return False
def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
"""
Converts a list of actions into a `tweaks` dict (which can then be passed to
the push gateway).
This function ignores all actions other than `set_tweak` actions, and treats
absent `value`s as `True`, which agrees with the only spec-defined treatment
of absent `value`s (namely, for `highlight` tweaks).
Args:
actions: list of actions
e.g. [
{"set_tweak": "a", "value": "AAA"},
{"set_tweak": "b", "value": "BBB"},
{"set_tweak": "highlight"},
"notify"
]
Returns:
dictionary of tweaks for those actions
e.g. {"a": "AAA", "b": "BBB", "highlight": True}
"""
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if "set_tweak" in a:
# value is allowed to be absent in which case the value assumed
# should be True.
tweaks[a["set_tweak"]] = a.get("value", True)
return tweaks
class PushRuleEvaluatorForEvent:
def __init__(
self,
event: EventBase,
room_member_count: int,
sender_power_level: int,
power_levels: Dict[str, Union[int, Dict[str, int]]],
relations: Dict[str, Set[Tuple[str, str]]],
relations_match_enabled: bool,
):
self._event = event
self._room_member_count = room_member_count
self._sender_power_level = sender_power_level
self._power_levels = power_levels
self._relations = relations
self._relations_match_enabled = relations_match_enabled
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
# Maps cache keys to final values.
self._condition_cache: Dict[str, bool] = {}
def check_conditions(
self, conditions: Sequence[Mapping], uid: str, display_name: Optional[str]
) -> bool:
"""
Returns true if a user's conditions/user ID/display name match the event.
Args:
conditions: The user's conditions to match.
uid: The user's MXID.
display_name: The display name.
Returns:
True if all conditions match the event, False otherwise.
"""
for cond in conditions:
_cache_key = cond.get("_cache_key", None)
if _cache_key:
res = self._condition_cache.get(_cache_key, None)
if res is False:
return False
elif res is True:
continue
res = self.matches(cond, uid, display_name)
if _cache_key:
self._condition_cache[_cache_key] = bool(res)
if not res:
return False
return True
def matches(
self, condition: Mapping[str, Any], user_id: str, display_name: Optional[str]
) -> bool:
"""
Returns true if a user's condition/user ID/display name match the event.
Args:
condition: The user's condition to match.
uid: The user's MXID.
display_name: The display name, or None if there is not one.
Returns:
True if the condition matches the event, False otherwise.
"""
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
elif condition["kind"] == "contains_display_name":
return self._contains_display_name(display_name)
elif condition["kind"] == "room_member_count":
return _room_member_count(self._event, condition, self._room_member_count)
elif condition["kind"] == "sender_notification_permission":
return _sender_notification_permission(
self._event, condition, self._sender_power_level, self._power_levels
)
elif (
condition["kind"] == "org.matrix.msc3772.relation_match"
and self._relations_match_enabled
):
return self._relation_match(condition, user_id)
else:
# XXX This looks incorrect -- we have reached an unknown condition
# kind and are unconditionally returning that it matches. Note
# that it seems possible to provide a condition to the /pushrules
# endpoint with an unknown kind, see _rule_tuple_from_request_object.
return True
def _event_match(self, condition: Mapping, user_id: str) -> bool:
"""
Check an "event_match" push rule condition.
Args:
condition: The "event_match" push rule condition to match.
user_id: The user's MXID.
Returns:
True if the condition matches the event, False otherwise.
"""
pattern = condition.get("pattern", None)
if not pattern:
pattern_type = condition.get("pattern_type", None)
if pattern_type == "user_id":
pattern = user_id
elif pattern_type == "user_localpart":
pattern = UserID.from_string(user_id).localpart
if not pattern:
logger.warning("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition["key"] == "content.body":
body = self._event.content.get("body", None)
if not body or not isinstance(body, str):
return False
return _glob_matches(pattern, body, word_boundary=True)
else:
haystack = self._value_cache.get(condition["key"], None)
if haystack is None:
return False
return _glob_matches(pattern, haystack)
def _contains_display_name(self, display_name: Optional[str]) -> bool:
"""
Check an "event_match" push rule condition.
Args:
display_name: The display name, or None if there is not one.
Returns:
True if the display name is found in the event body, False otherwise.
"""
if not display_name:
return False
body = self._event.content.get("body", None)
if not body or not isinstance(body, str):
return False
# Similar to _glob_matches, but do not treat display_name as a glob.
r = regex_cache.get((display_name, False, True), None)
if not r:
r1 = re.escape(display_name)
r1 = to_word_pattern(r1)
r = re.compile(r1, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r
return bool(r.search(body))
def _relation_match(self, condition: Mapping, user_id: str) -> bool:
"""
Check an "relation_match" push rule condition.
Args:
condition: The "event_match" push rule condition to match.
user_id: The user's MXID.
Returns:
True if the condition matches the event, False otherwise.
"""
rel_type = condition.get("rel_type")
if not rel_type:
logger.warning("relation_match condition missing rel_type")
return False
sender_pattern = condition.get("sender")
if sender_pattern is None:
sender_type = condition.get("sender_type")
if sender_type == "user_id":
sender_pattern = user_id
type_pattern = condition.get("type")
# If any other relations matches, return True.
for sender, event_type in self._relations.get(rel_type, ()):
if sender_pattern and not _glob_matches(sender_pattern, sender):
continue
if type_pattern and not _glob_matches(type_pattern, event_type):
continue
# All values must have matched.
return True
# No relations matched.
return False
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
50000, "regex_push_cache"
)
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
"""Tests if value matches glob.
Args:
glob
value: String to test against glob.
word_boundary: Whether to match against word boundaries or entire
string. Defaults to False.
"""
try:
r = regex_cache.get((glob, True, word_boundary), None)
if not r:
r = glob_to_regex(glob, word_boundary=word_boundary)
regex_cache[(glob, True, word_boundary)] = r
return bool(r.search(value))
except re.error:
logger.warning("Failed to parse glob to regex: %r", glob)
return False
def _flatten_dict(
d: Union[EventBase, Mapping[str, Any]],
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
if prefix is None:
prefix = []
if result is None:
result = {}
for key, value in d.items():
if isinstance(value, str):
result[".".join(prefix + [key])] = value.lower()
elif isinstance(value, Mapping):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
return result

View File

@ -51,6 +51,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type: Optional[str],
address: Optional[str],
shadow_banned: bool,
approved: bool,
) -> JsonDict:
"""
Args:
@ -68,6 +69,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
or None for a normal user.
address: the IP address used to perform the regitration.
shadow_banned: Whether to shadow-ban the user
approved: Whether the user should be considered already approved by an
administrator.
"""
return {
"password_hash": password_hash,
@ -79,6 +82,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"user_type": user_type,
"address": address,
"shadow_banned": shadow_banned,
"approved": approved,
}
async def _handle_request( # type: ignore[override]
@ -99,6 +103,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type=content["user_type"],
address=content["address"],
shadow_banned=content["shadow_banned"],
approved=content["approved"],
)
return 200, {}

View File

@ -69,6 +69,7 @@ class UsersRestServletV2(RestServlet):
self.store = hs.get_datastores().main
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
@ -95,6 +96,13 @@ class UsersRestServletV2(RestServlet):
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
# If support for MSC3866 is not enabled, apply no filtering based on the
# `approved` column.
if self._msc3866_enabled:
approved = parse_boolean(request, "approved", default=True)
else:
approved = True
order_by = parse_string(
request,
"order_by",
@ -115,8 +123,22 @@ class UsersRestServletV2(RestServlet):
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
users, total = await self.store.get_users_paginate(
start, limit, user_id, name, guests, deactivated, order_by, direction
start,
limit,
user_id,
name,
guests,
deactivated,
order_by,
direction,
approved,
)
# If support for MSC3866 is not enabled, don't show the approval flag.
if not self._msc3866_enabled:
for user in users:
del user["approved"]
ret = {"users": users, "total": total}
if (start + limit) < total:
ret["next_token"] = str(start + len(users))
@ -163,6 +185,7 @@ class UserRestServletV2(RestServlet):
self.deactivate_account_handler = hs.get_deactivate_account_handler()
self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool()
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
async def on_GET(
self, request: SynapseRequest, user_id: str
@ -239,6 +262,15 @@ class UserRestServletV2(RestServlet):
HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
)
approved: Optional[bool] = None
if "approved" in body and self._msc3866_enabled:
approved = body["approved"]
if not isinstance(approved, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"'approved' parameter is not of type boolean",
)
# convert List[Dict[str, str]] into List[Tuple[str, str]]
if external_ids is not None:
new_external_ids = [
@ -343,6 +375,9 @@ class UserRestServletV2(RestServlet):
if "user_type" in body:
await self.store.set_user_type(target_user, user_type)
if approved is not None:
await self.store.update_user_approval_status(target_user, approved)
user = await self.admin_handler.get_user(target_user)
assert user is not None
@ -355,6 +390,10 @@ class UserRestServletV2(RestServlet):
if password is not None:
password_hash = await self.auth_handler.hash(password)
new_user_approved = True
if self._msc3866_enabled and approved is not None:
new_user_approved = approved
user_id = await self.registration_handler.register_user(
localpart=target_user.localpart,
password_hash=password_hash,
@ -362,6 +401,7 @@ class UserRestServletV2(RestServlet):
default_display_name=displayname,
user_type=user_type,
by_admin=True,
approved=new_user_approved,
)
if threepids is not None:
@ -550,6 +590,7 @@ class UserRegisterServlet(RestServlet):
user_type=user_type,
default_display_name=displayname,
by_admin=True,
approved=True,
)
result = await register._create_registration_details(user_id, body)

View File

@ -28,7 +28,14 @@ from typing import (
from typing_extensions import TypedDict
from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError
from synapse.api.constants import ApprovalNoticeMedium
from synapse.api.errors import (
Codes,
InvalidClientTokenError,
LoginError,
NotApprovedError,
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.appservice import ApplicationService
@ -55,11 +62,11 @@ logger = logging.getLogger(__name__)
class LoginResponse(TypedDict, total=False):
user_id: str
access_token: str
access_token: Optional[str]
home_server: str
expires_in_ms: Optional[int]
refresh_token: Optional[str]
device_id: str
device_id: Optional[str]
well_known: Optional[Dict[str, Any]]
@ -92,6 +99,12 @@ class LoginRestServlet(RestServlet):
hs.config.registration.refreshable_access_token_lifetime is not None
)
# Whether we need to check if the user has been approved or not.
self._require_approval = (
hs.config.experimental.msc3866.enabled
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@ -220,6 +233,14 @@ class LoginRestServlet(RestServlet):
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
if self._require_approval:
approved = await self.auth_handler.is_user_approved(result["user_id"])
if not approved:
raise NotApprovedError(
msg="This account is pending approval by a server administrator.",
approval_notice_medium=ApprovalNoticeMedium.NONE,
)
well_known_data = self._well_known_builder.get_well_known()
if well_known_data:
result["well_known"] = well_known_data
@ -356,6 +377,16 @@ class LoginRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
if self._require_approval:
approved = await self.auth_handler.is_user_approved(user_id)
if not approved:
# If the user isn't approved (and needs to be) we won't allow them to
# actually log in, so we don't want to create a device/access token.
return LoginResponse(
user_id=user_id,
home_server=self.hs.hostname,
)
initial_display_name = login_submission.get("initial_device_display_name")
(
device_id,

View File

@ -47,7 +47,9 @@ class LoginTokenRequestServlet(RestServlet):
}
"""
PATTERNS = client_patterns("/login/token$")
PATTERNS = client_patterns(
"/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True
)
def __init__(self, hs: "HomeServer"):
super().__init__()

View File

@ -21,10 +21,15 @@ from twisted.web.server import Request
import synapse
import synapse.api.auth
import synapse.types
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.constants import (
APP_SERVICE_REGISTRATION_TYPE,
ApprovalNoticeMedium,
LoginType,
)
from synapse.api.errors import (
Codes,
InteractiveAuthIncompleteError,
NotApprovedError,
SynapseError,
ThreepidValidationError,
UnrecognizedRequestError,
@ -414,6 +419,11 @@ class RegisterRestServlet(RestServlet):
hs.config.registration.inhibit_user_in_use_error
)
self._require_approval = (
hs.config.experimental.msc3866.enabled
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
)
@ -734,6 +744,12 @@ class RegisterRestServlet(RestServlet):
access_token=return_dict.get("access_token"),
)
if self._require_approval:
raise NotApprovedError(
msg="This account needs to be approved by an administrator before it can be used.",
approval_notice_medium=ApprovalNoticeMedium.NONE,
)
return 200, return_dict
async def _do_appservice_registration(
@ -778,7 +794,9 @@ class RegisterRestServlet(RestServlet):
"user_id": user_id,
"home_server": self.hs.hostname,
}
if not params.get("inhibit_login", False):
# We don't want to log the user in if we're going to deny them access because
# they need to be approved first.
if not params.get("inhibit_login", False) and not self._require_approval:
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
(

View File

@ -94,6 +94,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache(
"get_rooms_for_user_with_stream_ordering", (user_id,)
)
self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,))
# Purge other caches based on room state.
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))

View File

@ -423,16 +423,18 @@ class EventsPersistenceStorageController:
for d in ret_vals:
replaced_events.update(d)
events = []
persisted_events = []
for event, _ in events_and_contexts:
existing_event_id = replaced_events.get(event.event_id)
if existing_event_id:
events.append(await self.main_store.get_event(existing_event_id))
persisted_events.append(
await self.main_store.get_event(existing_event_id)
)
else:
events.append(event)
persisted_events.append(event)
return (
events,
persisted_events,
self.main_store.get_room_max_token(),
)

View File

@ -23,7 +23,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)
@ -529,7 +529,18 @@ class StateStorageController:
)
return state_map.get(key)
async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms
with partial state.
"""
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
"""Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms
@ -542,11 +553,11 @@ class StateStorageController:
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)
return await self.stores.main.get_current_hosts_in_room_ordered(room_id)
async def get_current_hosts_in_room_or_partial_state_approximation(
self, room_id: str
) -> Sequence[str]:
) -> Collection[str]:
"""Get approximation of current hosts in room based on current state.
For rooms with full state, this is equivalent to `get_current_hosts_in_room`,
@ -566,14 +577,9 @@ class StateStorageController:
)
hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
hosts_from_state_set = set(hosts_from_state)
# First take the list of hosts based on the current state.
# For rooms with partial state, this will be missing most hosts.
hosts = list(hosts_from_state)
# Then add in the list of hosts in the room at the time we joined.
# This will be an empty list for rooms with full state.
hosts.extend(host for host in hosts_at_join if host not in hosts_from_state_set)
hosts = set(hosts_at_join)
hosts.update(hosts_from_state)
return hosts

View File

@ -1141,17 +1141,57 @@ class DatabasePool:
desc: str = "simple_upsert",
lock: bool = True,
) -> bool:
"""
"""Insert a row with values + insertion_values; on conflict, update with values.
`lock` should generally be set to True (the default), but can be set
to False if either of the following are true:
1. there is a UNIQUE INDEX on the key columns. In this case a conflict
will cause an IntegrityError in which case this function will retry
the update.
2. we somehow know that we are the only thread which will be updating
this table.
As an additional note, this parameter only matters for old SQLite versions
because we will use native upserts otherwise.
All of our supported databases accept the nonstandard "upsert" statement in
their dialect of SQL. We call this a "native upsert". The syntax looks roughly
like:
INSERT INTO table VALUES (values + insertion_values)
ON CONFLICT (keyvalues)
DO UPDATE SET (values); -- overwrite `values` columns only
If (values) is empty, the resulting query is slighlty simpler:
INSERT INTO table VALUES (insertion_values)
ON CONFLICT (keyvalues)
DO NOTHING; -- do not overwrite any columns
This function is a helper to build such queries.
In order for upserts to make sense, the database must be able to determine when
an upsert CONFLICTs with an existing row. Postgres and SQLite ensure this by
requiring that a unique index exist on the column names used to detect a
conflict (i.e. `keyvalues.keys()`).
If there is no such index, we can "emulate" an upsert with a SELECT followed
by either an INSERT or an UPDATE. This is unsafe: we cannot make the same
atomicity guarantees that a native upsert can and are very vulnerable to races
and crashes. Therefore if we wish to upsert without an appropriate unique index,
we must either:
1. Acquire a table-level lock before the emulated upsert (`lock=True`), or
2. VERY CAREFULLY ensure that we are the only thread and worker which will be
writing to this table, in which case we can proceed without a lock
(`lock=False`).
Generally speaking, you should use `lock=True`. If the table in question has a
unique index[*], this class will use a native upsert (which is atomic and so can
ignore the `lock` argument). Otherwise this class will use an emulated upsert,
in which case we want the safer option unless we been VERY CAREFUL.
[*]: Some tables have unique indices added to them in the background. Those
tables `T` are keys in the dictionary UNIQUE_INDEX_BACKGROUND_UPDATES,
where `T` maps to the background update that adds a unique index to `T`.
This dictionary is maintained by hand.
At runtime, we constantly check to see if each of these background updates
has run. If so, we deem the coresponding table safe to upsert into, because
we can now use a native insert to do so. If not, we deem the table unsafe
to upsert into and require an emulated upsert.
Tables that do not appear in this dictionary are assumed to have an
appropriate unique index and therefore be safe to upsert into.
Args:
table: The table to upsert into

View File

@ -203,6 +203,7 @@ class DataStore(
deactivated: bool = False,
order_by: str = UserSortOrder.USER_ID.value,
direction: str = "f",
approved: bool = True,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
@ -217,6 +218,7 @@ class DataStore(
deactivated: whether to include deactivated users
order_by: the sort order of the returned list
direction: sort ascending or descending
approved: whether to include approved users
Returns:
A tuple of a list of mappings from user to information and a count of total users.
"""
@ -249,6 +251,11 @@ class DataStore(
if not deactivated:
filters.append("deactivated = 0")
if not approved:
# We ignore NULL values for the approved flag because these should only
# be already existing users that we consider as already approved.
filters.append("approved IS FALSE")
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
sql_base = f"""
@ -262,7 +269,7 @@ class DataStore(
sql = f"""
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
displayname, avatar_url, creation_ts * 1000 as creation_ts
displayname, avatar_url, creation_ts * 1000 as creation_ts, approved
{sql_base}
ORDER BY {order_by_column} {order}, u.name ASC
LIMIT ? OFFSET ?

View File

@ -205,6 +205,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_rooms_for_user_with_stream_ordering.invalidate(
(data.state_key,)
)
self.get_rooms_for_user.invalidate((data.state_key,))
else:
raise Exception("Unknown events stream row type %s" % (row.type,))

View File

@ -1307,6 +1307,33 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return changes
async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
"""Get all device list changes that happened in the room since the given
stream ID.
Returns:
Collection of user ID/device ID tuples of all devices that have
changed
"""
sql = """
SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
WHERE room_id = ? AND stream_id > ?
"""
def get_device_list_changes_in_room_txn(
txn: LoggingTransaction,
) -> Collection[Tuple[str, str]]:
txn.execute(sql, (room_id, min_stream_id))
return cast(Collection[Tuple[str, str]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_device_list_changes_in_room",
get_device_list_changes_in_room_txn,
)
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
@ -1946,14 +1973,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
room_id: str,
stream_id: int,
stream_id: Optional[int],
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
Marks the associated row in `device_lists_changes_in_room` as handled.
Marks the associated row in `device_lists_changes_in_room` as handled,
if `stream_id` is provided.
"""
def add_device_list_outbound_pokes_txn(
@ -1969,17 +1997,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context=context,
)
self.db_pool.simple_update_txn(
txn,
table="device_lists_changes_in_room",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"room_id": room_id,
},
updatevalues={"converted_to_destinations": True},
)
if stream_id:
self.db_pool.simple_update_txn(
txn,
table="device_lists_changes_in_room",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"room_id": room_id,
},
updatevalues={"converted_to_destinations": True},
)
if not hosts:
# If there are no hosts then we don't try and generate stream IDs.
@ -1995,3 +2024,58 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
add_device_list_outbound_pokes_txn,
stream_ids,
)
async def add_remote_device_list_to_pending(
self, user_id: str, device_id: str
) -> None:
"""Add a device list update to the table tracking remote device list
updates during partial joins.
"""
async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
await self.db_pool.simple_upsert(
table="device_lists_remote_pending",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={"stream_id": stream_id},
desc="add_remote_device_list_to_pending",
)
async def get_pending_remote_device_list_updates_for_room(
self, room_id: str
) -> Collection[Tuple[str, str]]:
"""Get the set of remote device list updates from the pending table for
the room.
"""
min_device_stream_id = await self.db_pool.simple_select_one_onecol(
table="partial_state_rooms",
keyvalues={
"room_id": room_id,
},
retcol="device_lists_stream_id",
desc="get_pending_remote_device_list_updates_for_room_device",
)
sql = """
SELECT user_id, device_id FROM device_lists_remote_pending AS d
INNER JOIN current_state_events AS c ON
type = 'm.room.member'
AND state_key = user_id
AND membership = 'join'
WHERE
room_id = ? AND stream_id > ?
"""
def get_pending_remote_device_list_updates_for_room_txn(
txn: LoggingTransaction,
) -> Collection[Tuple[str, str]]:
txn.execute(sql, (room_id, min_device_stream_id))
return cast(Collection[Tuple[str, str]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_pending_remote_device_list_updates_for_room",
get_pending_remote_device_list_updates_for_room_txn,
)

View File

@ -73,13 +73,30 @@ pdus_pruned_from_federation_queue = Counter(
logger = logging.getLogger(__name__)
BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS: int = int(
datetime.timedelta(days=7).total_seconds()
)
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS: int = int(
datetime.timedelta(hours=1).total_seconds()
# Parameters controlling exponential backoff between backfill failures.
# After the first failure to backfill, we wait 2 hours before trying again. If the
# second attempt fails, we wait 4 hours before trying again. If the third attempt fails,
# we wait 8 hours before trying again, ... and so on.
#
# Each successive backoff period is twice as long as the last. However we cap this
# period at a maximum of 2^8 = 256 hours: a little over 10 days. (This is the smallest
# power of 2 which yields a maximum backoff period of at least 7 days---which was the
# original maximum backoff period.) Even when we hit this cap, we will continue to
# make backfill attempts once every 10 days.
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS = 8
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS = int(
datetime.timedelta(hours=1).total_seconds() * 1000
)
# We need a cap on the power of 2 or else the backoff period
# 2^N * (milliseconds per hour)
# will overflow when calcuated within the database. We ensure overflow does not occur
# by checking that the largest backoff period fits in a 32-bit signed integer.
_LONGEST_BACKOFF_PERIOD_MILLISECONDS = (
2**BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS
) * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS
assert 0 < _LONGEST_BACKOFF_PERIOD_MILLISECONDS <= ((2**31) - 1)
# All the info we need while iterating the DAG while backfilling
@attr.s(frozen=True, slots=True, auto_attribs=True)
@ -726,17 +743,35 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
async def get_backfill_points_in_room(
self,
room_id: str,
current_depth: int,
limit: int,
) -> List[Tuple[str, int]]:
"""
Gets the oldest events(backwards extremities) in the room along with the
approximate depth. Sorted by depth, highest to lowest (descending).
Get the backward extremities to backfill from in the room along with the
approximate depth.
Only returns events that are at a depth lower than or
equal to the `current_depth`. Sorted by depth, highest to lowest (descending)
so the closest events to the `current_depth` are first in the list.
We ignore extremities that are newer than the user's current scroll position
(ie, those with depth greater than `current_depth`) as:
1. we don't really care about getting events that have happened
after our current position; and
2. by the nature of paginating and scrolling back, we have likely
previously tried and failed to backfill from that extremity, so
to avoid getting "stuck" requesting the same backfill repeatedly
we drop those extremities.
Args:
room_id: Room where we want to find the oldest events
current_depth: The depth at the user's current scrollback position
limit: The max number of backfill points to return
Returns:
List of (event_id, depth) tuples. Sorted by depth, highest to lowest
(descending)
(descending) so the closest events to the `current_depth` are first
in the list.
"""
def get_backfill_points_in_room_txn(
@ -749,7 +784,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# persisted in our database yet (meaning we don't know their depth
# specifically). So we need to look for the approximate depth from
# the events connected to the current backwards extremeties.
sql = """
if isinstance(self.database_engine, PostgresEngine):
least_function = "LEAST"
elif isinstance(self.database_engine, Sqlite3Engine):
least_function = "MIN"
else:
raise RuntimeError("Unknown database engine")
sql = f"""
SELECT backward_extrem.event_id, event.depth FROM events AS event
/**
* Get the edge connections from the event_edges table
@ -784,6 +827,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
* necessarily safe to assume that it will have been completed.
*/
AND edge.is_state is ? /* False */
/**
* We only want backwards extremities that are older than or at
* the same position of the given `current_depth` (where older
* means less than the given depth) because we're looking backwards
* from the `current_depth` when backfilling.
*
* current_depth (ignore events that come after this, ignore 2-4)
* |
*
* <oldest-in-time> [0]<--[1]<--[2]<--[3]<--[4] <newest-in-time>
*/
AND event.depth <= ? /* current_depth */
/**
* Exponential back-off (up to the upper bound) so we don't retry the
* same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
@ -795,31 +850,31 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
*/
AND (
failed_backfill_attempt_info.event_id IS NULL
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */)
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + (
(1 << {least_function}(failed_backfill_attempt_info.num_attempts, ? /* max doubling steps */))
* ? /* step */
)
)
/**
* Sort from highest to the lowest depth. Then tie-break on
* alphabetical order of the event_ids so we get a consistent
* ordering which is nice when asserting things in tests.
* Sort from highest (closest to the `current_depth`) to the lowest depth
* because the closest are most relevant to backfill from first.
* Then tie-break on alphabetical order of the event_ids so we get a
* consistent ordering which is nice when asserting things in tests.
*/
ORDER BY event.depth DESC, backward_extrem.event_id DESC
LIMIT ?
"""
if isinstance(self.database_engine, PostgresEngine):
least_function = "least"
elif isinstance(self.database_engine, Sqlite3Engine):
least_function = "min"
else:
raise RuntimeError("Unknown database engine")
txn.execute(
sql % (least_function,),
sql,
(
room_id,
False,
current_depth,
self._clock.time_msec(),
1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS,
1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS,
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS,
limit,
),
)
@ -835,24 +890,47 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
async def get_insertion_event_backward_extremities_in_room(
self,
room_id: str,
current_depth: int,
limit: int,
) -> List[Tuple[str, int]]:
"""
Get the insertion events we know about that we haven't backfilled yet
along with the approximate depth. Sorted by depth, highest to lowest
(descending).
along with the approximate depth. Only returns insertion events that are
at a depth lower than or equal to the `current_depth`. Sorted by depth,
highest to lowest (descending) so the closest events to the
`current_depth` are first in the list.
We ignore insertion events that are newer than the user's current scroll
position (ie, those with depth greater than `current_depth`) as:
1. we don't really care about getting events that have happened
after our current position; and
2. by the nature of paginating and scrolling back, we have likely
previously tried and failed to backfill from that insertion event, so
to avoid getting "stuck" requesting the same backfill repeatedly
we drop those insertion event.
Args:
room_id: Room where we want to find the oldest events
current_depth: The depth at the user's current scrollback position
limit: The max number of insertion event extremities to return
Returns:
List of (event_id, depth) tuples. Sorted by depth, highest to lowest
(descending)
(descending) so the closest events to the `current_depth` are first
in the list.
"""
def get_insertion_event_backward_extremities_in_room_txn(
txn: LoggingTransaction, room_id: str
) -> List[Tuple[str, int]]:
sql = """
if isinstance(self.database_engine, PostgresEngine):
least_function = "LEAST"
elif isinstance(self.database_engine, Sqlite3Engine):
least_function = "MIN"
else:
raise RuntimeError("Unknown database engine")
sql = f"""
SELECT
insertion_event_extremity.event_id, event.depth
/* We only want insertion events that are also marked as backwards extremities */
@ -869,6 +947,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
AND failed_backfill_attempt_info.event_id = insertion_event_extremity.event_id
WHERE
insertion_event_extremity.room_id = ?
/**
* We only want extremities that are older than or at
* the same position of the given `current_depth` (where older
* means less than the given depth) because we're looking backwards
* from the `current_depth` when backfilling.
*
* current_depth (ignore events that come after this, ignore 2-4)
* |
*
* <oldest-in-time> [0]<--[1]<--[2]<--[3]<--[4] <newest-in-time>
*/
AND event.depth <= ? /* current_depth */
/**
* Exponential back-off (up to the upper bound) so we don't retry the
* same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc
@ -880,30 +970,30 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
*/
AND (
failed_backfill_attempt_info.event_id IS NULL
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */)
OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + (
(1 << {least_function}(failed_backfill_attempt_info.num_attempts, ? /* max doubling steps */))
* ? /* step */
)
)
/**
* Sort from highest to the lowest depth. Then tie-break on
* alphabetical order of the event_ids so we get a consistent
* ordering which is nice when asserting things in tests.
* Sort from highest (closest to the `current_depth`) to the lowest depth
* because the closest are most relevant to backfill from first.
* Then tie-break on alphabetical order of the event_ids so we get a
* consistent ordering which is nice when asserting things in tests.
*/
ORDER BY event.depth DESC, insertion_event_extremity.event_id DESC
LIMIT ?
"""
if isinstance(self.database_engine, PostgresEngine):
least_function = "least"
elif isinstance(self.database_engine, Sqlite3Engine):
least_function = "min"
else:
raise RuntimeError("Unknown database engine")
txn.execute(
sql % (least_function,),
sql,
(
room_id,
current_depth,
self._clock.time_msec(),
1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS,
1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS,
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS,
limit,
),
)
return cast(List[Tuple[str, int]], txn.fetchall())

View File

@ -366,14 +366,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str,
) -> NotifCounts:
# Get the stream ordering of the user's latest receipt in the room.
result = self.get_last_receipt_for_user_txn(
result = self.get_last_unthreaded_receipt_for_user_txn(
txn,
user_id,
room_id,
receipt_types=(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
),
receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
if result:
@ -574,10 +571,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
),
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
sql = f"""
@ -1074,7 +1068,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
limit,
),
)
rows = txn.fetchall()
rows = cast(List[Tuple[int, str, str, int]], txn.fetchall())
# For each new read receipt we delete push actions from before it and
# recalculate the summary.
@ -1119,18 +1113,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# We always update `event_push_summary_last_receipt_stream_id` to
# ensure that we don't rescan the same receipts for remote users.
upper_limit = max_receipts_stream_id
receipts_last_processed_stream_id = max_receipts_stream_id
if len(rows) >= limit:
# If we pulled out a limited number of rows we only update the
# position to the last receipt we processed, so we continue
# processing the rest next iteration.
upper_limit = rows[-1][0]
receipts_last_processed_stream_id = rows[-1][0]
self.db_pool.simple_update_txn(
txn,
table="event_push_summary_last_receipt_stream_id",
keyvalues={},
updatevalues={"stream_id": upper_limit},
updatevalues={"stream_id": receipts_last_processed_stream_id},
)
return len(rows) < limit

View File

@ -2134,13 +2134,13 @@ class PersistEventsStore:
appear in events_and_context.
"""
# Only non outlier events will have push actions associated with them,
# Only notifiable events will have push actions associated with them,
# so let's filter them out. (This makes joining large rooms faster, as
# these queries took seconds to process all the state events).
non_outlier_events = [
notifiable_events = [
event
for event, _ in events_and_contexts
if not event.internal_metadata.is_outlier()
if event.internal_metadata.is_notifiable()
]
sql = """
@ -2153,7 +2153,7 @@ class PersistEventsStore:
WHERE event_id = ?
"""
if non_outlier_events:
if notifiable_events:
txn.execute_batch(
sql,
(
@ -2163,7 +2163,7 @@ class PersistEventsStore:
event.depth,
event.event_id,
)
for event in non_outlier_events
for event in notifiable_events
),
)

View File

@ -135,34 +135,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()
async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_types: Collection[str]
) -> Optional[str]:
"""
Fetch the event ID for the latest receipt in a room with one of the given receipt types.
Args:
user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for.
receipt_type: The receipt types to fetch.
Returns:
The latest receipt, if one exists.
"""
result = await self.db_pool.runInteraction(
"get_last_receipt_event_id_for_user",
self.get_last_receipt_for_user_txn,
user_id,
room_id,
receipt_types,
)
if not result:
return None
event_id, _ = result
return event_id
def get_last_receipt_for_user_txn(
def get_last_unthreaded_receipt_for_user_txn(
self,
txn: LoggingTransaction,
user_id: str,
@ -170,13 +143,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_types: Collection[str],
) -> Optional[Tuple[str, int]]:
"""
Fetch the event ID and stream_ordering for the latest receipt in a room
with one of the given receipt types.
Fetch the event ID and stream_ordering for the latest unthreaded receipt
in a room with one of the given receipt types.
Args:
user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for.
receipt_type: The receipt types to fetch.
receipt_types: The receipt types to fetch.
Returns:
The event ID and stream ordering of the latest receipt, if one exists.
@ -193,6 +166,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
WHERE {clause}
AND user_id = ?
AND room_id = ?
AND thread_id IS NULL
ORDER BY stream_ordering DESC
LIMIT 1
"""

View File

@ -166,27 +166,49 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
"name",
"password_hash",
"is_guest",
"admin",
"consent_version",
"consent_ts",
"consent_server_notice_sent",
"appservice_id",
"creation_ts",
"user_type",
"deactivated",
"shadow_banned",
],
allow_none=True,
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# We could technically use simple_select_one here, but it would not perform
# the COALESCEs (unless hacked into the column names), which could yield
# confusing results.
txn.execute(
"""
SELECT
name, password_hash, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved
FROM users
WHERE name = ?
""",
(user_id,),
)
rows = self.db_pool.cursor_to_dict(txn)
if len(rows) == 0:
return None
return rows[0]
row = await self.db_pool.runInteraction(
desc="get_user_by_id",
func=get_user_by_id_txn,
)
if row is not None:
# If we're using SQLite our boolean values will be integers. Because we
# present some of this data as is to e.g. server admins via REST APIs, we
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
for column in boolean_columns:
if not isinstance(row[column], bool):
row[column] = bool(row[column])
return row
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get a UserInfo object for a user by user ID.
@ -1779,6 +1801,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return res if res else False
@cached()
async def is_user_approved(self, user_id: str) -> bool:
"""Checks if a user is approved and therefore can be allowed to log in.
If the user's 'approved' column is NULL, we consider it as true given it means
the user was registered when support for an approval flow was either disabled
or nonexistent.
Args:
user_id: the user to check the approval status of.
Returns:
A boolean that is True if the user is approved, False otherwise.
"""
def is_user_approved_txn(txn: LoggingTransaction) -> bool:
txn.execute(
"""
SELECT COALESCE(approved, TRUE) AS approved FROM users WHERE name = ?
""",
(user_id,),
)
rows = self.db_pool.cursor_to_dict(txn)
# We cast to bool because the value returned by the database engine might
# be an integer if we're using SQLite.
return bool(rows[0]["approved"])
return await self.db_pool.runInteraction(
desc="is_user_pending_approval",
func=is_user_approved_txn,
)
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@ -1916,6 +1972,29 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
def update_user_approval_status_txn(
self, txn: LoggingTransaction, user_id: str, approved: bool
) -> None:
"""Set the user's 'approved' flag to the given value.
The boolean is turned into an int because the column is a smallint.
Args:
txn: the current database transaction.
user_id: the user to update the flag for.
approved: the value to set the flag to.
"""
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"approved": approved},
)
# Invalidate the caches of methods that read the value of the 'approved' flag.
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(
@ -1933,6 +2012,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
# If support for MSC3866 is enabled and configured to require approval for new
# account, we will create new users with an 'approved' flag set to false.
self._require_approval = (
hs.config.experimental.msc3866.enabled
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
async def add_access_token_to_user(
self,
user_id: str,
@ -2065,6 +2151,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool = False,
user_type: Optional[str] = None,
shadow_banned: bool = False,
approved: bool = False,
) -> None:
"""Attempts to register an account.
@ -2083,6 +2170,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
or None for a normal user.
shadow_banned: Whether the user is shadow-banned, i.e. they may be
told their requests succeeded but we ignore them.
approved: Whether to consider the user has already been approved by an
administrator.
Raises:
StoreError if the user_id could not be registered.
@ -2099,6 +2188,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin,
user_type,
shadow_banned,
approved,
)
def _register_user(
@ -2113,11 +2203,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool,
user_type: Optional[str],
shadow_banned: bool,
approved: bool,
) -> None:
user_id_obj = UserID.from_string(user_id)
now = int(self._clock.time())
user_approved = approved or not self._require_approval
try:
if was_guest:
# Ensure that the guest user actually exists
@ -2143,6 +2236,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"admin": 1 if admin else 0,
"user_type": user_type,
"shadow_banned": shadow_banned,
"approved": user_approved,
},
)
else:
@ -2158,6 +2252,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"admin": 1 if admin else 0,
"user_type": user_type,
"shadow_banned": shadow_banned,
"approved": user_approved,
},
)
@ -2503,6 +2598,25 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
async def update_user_approval_status(
self, user_id: UserID, approved: bool
) -> None:
"""Set the user's 'approved' flag to the given value.
The boolean will be turned into an int (in update_user_approval_status_txn)
because the column is a smallint.
Args:
user_id: the user to update the flag for.
approved: the value to set the flag to.
"""
await self.db_pool.runInteraction(
"update_user_approval_status",
self.update_user_approval_status_txn,
user_id.to_string(),
approved,
)
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""

View File

@ -1217,6 +1217,26 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
# We now delete anything from `device_lists_remote_pending` with a
# stream ID less than the minimum
# `partial_state_rooms.device_lists_stream_id`, as we no longer need them.
device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn(
txn,
table="partial_state_rooms",
keyvalues={},
retcol="MIN(device_lists_stream_id)",
allow_none=True,
)
if device_lists_stream_id is None:
# There are no rooms being currently partially joined, so we delete everything.
txn.execute("DELETE FROM device_lists_remote_pending")
else:
sql = """
DELETE FROM device_lists_remote_pending
WHERE stream_id <= ?
"""
txn.execute(sql, (device_lists_stream_id,))
@cached()
async def is_partial_state_room(self, room_id: str) -> bool:
"""Checks if this room has partial state.
@ -1236,6 +1256,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return entry is not None
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
self, room_id: str
) -> Tuple[str, int]:
"""Get the event ID of the initial join that started the partial
join, and the device list stream ID at the point we started the partial
join.
"""
result = await self.db_pool.simple_select_one(
table="partial_state_rooms",
keyvalues={"room_id": room_id},
retcols=("join_event_id", "device_lists_stream_id"),
desc="get_join_event_id_for_partial_state",
)
return result["join_event_id"], result["device_lists_stream_id"]
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"

View File

@ -15,7 +15,6 @@
import logging
from typing import (
TYPE_CHECKING,
Callable,
Collection,
Dict,
FrozenSet,
@ -52,7 +51,6 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@ -148,42 +146,37 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]:
"""
Returns a list of users in the room sorted by longest in the room first
(aka. with the lowest depth). This is done to match the sort in
`get_current_hosts_in_room()` and so we can re-use the cache but it's
not horrible to have here either.
Uses `m.room.member`s in the room state at the current forward extremities to
determine which users are in the room.
"""Returns a list of users in the room.
Will return inaccurate results for rooms with partial state, since the state for
the forward extremities of those rooms will exclude most members. We may also
calculate room state incorrectly for such rooms and believe that a member is or
is not in the room when the opposite is true.
"""
return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
return await self.db_pool.simple_select_onecol(
table="current_state_events",
keyvalues={
"type": EventTypes.Member,
"room_id": room_id,
"membership": Membership.JOIN,
},
retcol="state_key",
desc="get_users_in_room",
)
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
"""
Returns a list of users in the room sorted by longest in the room first
(aka. with the lowest depth). This is done to match the sort in
`get_current_hosts_in_room()` and so we can re-use the cache but it's
not horrible to have here either.
"""
sql = """
SELECT c.state_key FROM current_state_events as c
/* Get the depth of the event from the events table */
INNER JOIN events AS e USING (event_id)
WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ?
/* Sorted by lowest depth first */
ORDER BY e.depth ASC;
"""
"""Returns a list of users in the room."""
txn.execute(sql, (room_id, Membership.JOIN))
return [r[0] for r in txn]
return self.db_pool.simple_select_onecol_txn(
txn,
table="current_state_events",
keyvalues={
"type": EventTypes.Member,
"room_id": room_id,
"membership": Membership.JOIN,
},
retcol="state_key",
)
@cached()
def get_user_in_room_with_profile(
@ -600,58 +593,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
for room_id, instance, stream_id in txn
)
@cachedList(
cached_method_name="get_rooms_for_user_with_stream_ordering",
list_name="user_ids",
)
async def get_rooms_for_users_with_stream_ordering(
self, user_ids: Collection[str]
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
"""A batched version of `get_rooms_for_user_with_stream_ordering`.
Returns:
Map from user_id to set of rooms that is currently in.
"""
return await self.db_pool.runInteraction(
"get_rooms_for_users_with_stream_ordering",
self._get_rooms_for_users_with_stream_ordering_txn,
user_ids,
)
def _get_rooms_for_users_with_stream_ordering_txn(
self, txn: LoggingTransaction, user_ids: Collection[str]
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
clause, args = make_in_list_sql_clause(
self.database_engine,
"c.state_key",
user_ids,
)
sql = f"""
SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
FROM current_state_events AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND c.membership = ?
AND {clause}
"""
txn.execute(sql, [Membership.JOIN] + args)
result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
user_id: set() for user_id in user_ids
}
for user_id, room_id, instance, stream_id in txn:
result[user_id].add(
GetRoomsForUserWithStreamOrdering(
room_id, PersistedEventPosition(instance, stream_id)
)
)
return {user_id: frozenset(v) for user_id, v in result.items()}
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
) -> Set[str]:
@ -693,19 +634,68 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {row[0] for row in txn}
@cancellable
async def get_rooms_for_user(
self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
) -> FrozenSet[str]:
@cached(max_entries=500000, iterable=True)
async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
"""
rooms = await self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate
rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate(
(user_id,),
None,
update_metrics=False,
)
return frozenset(r.room_id for r in rooms)
if rooms:
return frozenset(r.room_id for r in rooms)
room_ids = await self.db_pool.simple_select_onecol(
table="current_state_events",
keyvalues={
"type": EventTypes.Member,
"membership": Membership.JOIN,
"state_key": user_id,
},
retcol="room_id",
desc="get_rooms_for_user",
)
return frozenset(room_ids)
@cachedList(
cached_method_name="get_rooms_for_user",
list_name="user_ids",
)
async def get_rooms_for_users(
self, user_ids: Collection[str]
) -> Dict[str, FrozenSet[str]]:
"""A batched version of `get_rooms_for_user`.
Returns:
Map from user_id to set of rooms that is currently in.
"""
rows = await self.db_pool.simple_select_many_batch(
table="current_state_events",
column="state_key",
iterable=user_ids,
retcols=(
"state_key",
"room_id",
),
keyvalues={
"type": EventTypes.Member,
"membership": Membership.JOIN,
},
desc="get_rooms_for_users",
)
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
for row in rows:
user_rooms[row["state_key"]].add(row["room_id"])
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
@cached(max_entries=10000)
async def does_pair_of_users_share_a_room(
@ -936,7 +926,44 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room based on current state."""
# First we check if we already have `get_users_in_room` in the cache, as
# we can just calculate result from that
users = self.get_users_in_room.cache.get_immediate(
(room_id,), None, update_metrics=False
)
if users is not None:
return {get_domain_from_id(u) for u in users}
if isinstance(self.database_engine, Sqlite3Engine):
# If we're using SQLite then let's just always use
# `get_users_in_room` rather than funky SQL.
users = await self.get_users_in_room(room_id)
return {get_domain_from_id(u) for u in users}
# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
sql = """
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
FROM current_state_events
WHERE
type = 'm.room.member'
AND membership = 'join'
AND room_id = ?
"""
txn.execute(sql, (room_id,))
return {d for d, in txn}
return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
)
@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
"""
Get current hosts in room based on current state.
@ -944,48 +971,33 @@ class RoomMemberWorkerStore(EventsWorkerStore):
longest is good because they're most likely to have anything we ask
about.
Uses `m.room.member`s in the room state at the current forward extremities to
determine which hosts are in the room.
For SQLite the returned list is not ordered, as SQLite doesn't support
the appropriate SQL.
Will return inaccurate results for rooms with partial state, since the state for
the forward extremities of those rooms will exclude most members. We may also
calculate room state incorrectly for such rooms and believe that a host is or
is not in the room when the opposite is true.
Uses `m.room.member`s in the room state at the current forward
extremities to determine which hosts are in the room.
Will return inaccurate results for rooms with partial state, since the
state for the forward extremities of those rooms will exclude most
members. We may also calculate room state incorrectly for such rooms and
believe that a host is or is not in the room when the opposite is true.
Returns:
Returns a list of servers sorted by longest in the room first. (aka.
sorted by join with the lowest depth first).
"""
# First we check if we already have `get_users_in_room` in the cache, as
# we can just calculate result from that
users = self.get_users_in_room.cache.get_immediate(
(room_id,), None, update_metrics=False
)
if users is None and isinstance(self.database_engine, Sqlite3Engine):
if isinstance(self.database_engine, Sqlite3Engine):
# If we're using SQLite then let's just always use
# `get_users_in_room` rather than funky SQL.
users = await self.get_users_in_room(room_id)
if users is not None:
# Because `users` is sorted from lowest -> highest depth, the list
# of domains will also be sorted that way.
domains: List[str] = []
# We use a `Set` just for fast lookups
domain_set: Set[str] = set()
for u in users:
if ":" not in u:
continue
domain = get_domain_from_id(u)
if domain not in domain_set:
domain_set.add(domain)
domains.append(domain)
return domains
domains = await self.get_current_hosts_in_room(room_id)
return list(domains)
# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]:
def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
# Returns a list of servers currently joined in the room sorted by
# longest in the room first (aka. with the lowest depth). The
# heuristic of sorting by servers who have been in the room the
@ -1013,7 +1025,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [d for d, in txn if d is not None]
return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
)
async def get_joined_hosts(

View File

@ -0,0 +1,20 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a column to the users table to track whether the user needs to be approved by an
-- administrator.
-- A NULL column means the user was created before this feature was supported by Synapse,
-- and should be considered as TRUE.
ALTER TABLE users ADD COLUMN approved BOOLEAN;

View File

@ -0,0 +1,28 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Stores remote device lists we have received for remote users while a partial
-- join is in progress.
--
-- This allows us to replay any device list updates if it turns out the remote
-- user was in the partially joined room
CREATE TABLE device_lists_remote_pending(
stream_id BIGINT PRIMARY KEY,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL
);
-- We only keep the most recent update for a given user/device pair.
CREATE UNIQUE INDEX device_lists_remote_pending_user_device_id ON device_lists_remote_pending(user_id, device_id);

View File

@ -66,6 +66,21 @@ def _is_dev_dependency(req: Requirement) -> bool:
)
def _should_ignore_runtime_requirement(req: Requirement) -> bool:
# This is a build-time dependency. Irritatingly, `poetry build` ignores the
# requirements listed in the [build-system] section of pyproject.toml, so in order
# to support `poetry install --no-dev` we have to mark it as a runtime dependency.
# See discussion on https://github.com/python-poetry/poetry/issues/6154 (it sounds
# like the poetry authors don't consider this a bug?)
#
# In any case, workaround this by ignoring setuptools_rust here. (It might be
# slightly cleaner to put `setuptools_rust` in a `build` extra or similar, but for
# now let's do something quick and dirty.
if req.name == "setuptools_rust":
return True
return False
class Dependency(NamedTuple):
requirement: Requirement
must_be_installed: bool
@ -77,7 +92,7 @@ def _generic_dependencies() -> Iterable[Dependency]:
assert requirements is not None
for raw_requirement in requirements:
req = Requirement(raw_requirement)
if _is_dev_dependency(req):
if _is_dev_dependency(req) or _should_ignore_runtime_requirement(req):
continue
# https://packaging.pypa.io/en/latest/markers.html#usage notes that

View File

@ -159,6 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Blow away caches (supported room versions can only change due to a restart).
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
self.store.get_rooms_for_user.invalidate_all()
self.get_success(self.store._get_event_cache.clear())
self.store._event_ref.clear()

View File

@ -19,16 +19,18 @@ import frozendict
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.room_versions import RoomVersions
from synapse.appservice import ApplicationService
from synapse.events import FrozenEvent
from synapse.push import push_rule_evaluator
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.push.bulk_push_rule_evaluator import _flatten_dict
from synapse.push.httppusher import tweaks_for_actions
from synapse.rest import admin
from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict
from synapse.synapse_rust.push import PushRuleEvaluator
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
@ -41,7 +43,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
content: JsonDict,
relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
relations_match_enabled: bool = False,
) -> PushRuleEvaluatorForEvent:
) -> PushRuleEvaluator:
event = FrozenEvent(
{
"event_id": "$event_id",
@ -56,12 +58,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
room_member_count = 0
sender_power_level = 0
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluatorForEvent(
event,
return PushRuleEvaluator(
_flatten_dict(event),
room_member_count,
sender_power_level,
power_levels,
relations or set(),
power_levels.get("notifications", {}),
relations or {},
relations_match_enabled,
)
@ -293,7 +295,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
]
self.assertEqual(
push_rule_evaluator.tweaks_for_actions(actions),
tweaks_for_actions(actions),
{"sound": "default", "highlight": True},
)
@ -304,9 +306,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
evaluator = self._get_evaluator(
{}, {"m.annotation": {("@user:test", "m.reaction")}}
)
condition = {"kind": "relation_match"}
# Oddly, an unknown condition always matches.
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
# A push rule evaluator with the experimental rule enabled.
evaluator = self._get_evaluator(
@ -439,3 +438,80 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
)
self.assertEqual(len(users_with_push_actions), 0)
class BulkPushRuleEvaluatorTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.main_store = homeserver.get_datastores().main
self.user_id1 = self.register_user("user1", "password")
self.tok1 = self.login(self.user_id1, "password")
self.user_id2 = self.register_user("user2", "password")
self.tok2 = self.login(self.user_id2, "password")
self.room_id = self.helper.create_room_as(tok=self.tok1)
# We want to test history visibility works correctly.
self.helper.send_state(
self.room_id,
EventTypes.RoomHistoryVisibility,
{"history_visibility": HistoryVisibility.JOINED},
tok=self.tok1,
)
def get_notif_count(self, user_id: str) -> int:
return self.get_success(
self.main_store.db_pool.simple_select_one_onecol(
table="event_push_actions",
keyvalues={"user_id": user_id},
retcol="COALESCE(SUM(notif), 0)",
desc="get_staging_notif_count",
)
)
def test_plain_message(self) -> None:
"""Test that sending a normal message in a room will trigger a
notification
"""
# Have user2 join the room and cle
self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
# They start off with no notifications, but get them when messages are
# sent.
self.assertEqual(self.get_notif_count(self.user_id2), 0)
user1 = UserID.from_string(self.user_id1)
self.create_and_send_event(self.room_id, user1)
self.assertEqual(self.get_notif_count(self.user_id2), 1)
def test_delayed_message(self) -> None:
"""Test that a delayed message that was from before a user joined
doesn't cause a notification for the joined user.
"""
user1 = UserID.from_string(self.user_id1)
# Send a message before user2 joins
event_id1 = self.create_and_send_event(self.room_id, user1)
# Have user2 join the room
self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
# They start off with no notifications
self.assertEqual(self.get_notif_count(self.user_id2), 0)
# Send another message that references the event before the join to
# simulate a "delayed" event
self.create_and_send_event(self.room_id, user1, prev_event_ids=[event_id1])
# user2 should not be notified about it, because they can't see it.
self.assertEqual(self.get_notif_count(self.user_id2), 0)

View File

@ -25,10 +25,10 @@ from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client import devices, login, logout, profile, room, sync
from synapse.rest.client import devices, login, logout, profile, register, room, sync
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
@ -578,6 +578,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id")
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_invalid_parameter(self) -> None:
"""
If parameters are invalid, an error is returned.
@ -623,6 +633,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid approved
channel = self.make_request(
"GET",
self.url + "?approved=not_bool",
access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# unkown order_by
channel = self.make_request(
"GET",
@ -841,6 +861,69 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self._order_test([self.admin_user, user1, user2], "creation_ts", "f")
self._order_test([user2, user1, self.admin_user], "creation_ts", "b")
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_filter_out_approved(self) -> None:
"""Tests that the endpoint can filter out approved users."""
# Create our users.
self._create_users(2)
# Get the list of users.
channel = self.make_request(
"GET",
self.url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, channel.result)
# Exclude the admin, because we don't want to accidentally un-approve the admin.
non_admin_user_ids = [
user["name"]
for user in channel.json_body["users"]
if user["name"] != self.admin_user
]
self.assertEqual(2, len(non_admin_user_ids), non_admin_user_ids)
# Select a user and un-approve them. We do this rather than the other way around
# because, since these users are created by an admin, we consider them already
# approved.
not_approved_user = non_admin_user_ids[0]
channel = self.make_request(
"PUT",
f"/_synapse/admin/v2/users/{not_approved_user}",
{"approved": False},
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, channel.result)
# Now get the list of users again, this time filtering out approved users.
channel = self.make_request(
"GET",
self.url + "?approved=false",
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, channel.result)
non_admin_user_ids = [
user["name"]
for user in channel.json_body["users"]
if user["name"] != self.admin_user
]
# We should only have our unapproved user now.
self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids)
self.assertEqual(not_approved_user, non_admin_user_ids[0])
def _order_test(
self,
expected_user_list: List[str],
@ -1272,6 +1355,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
register.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@ -2536,6 +2620,104 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_approve_account(self) -> None:
"""Tests that approving an account correctly sets the approved flag for the user."""
url = self.url_prefix % "@bob:test"
# Create the user using the client-server API since otherwise the user will be
# marked as approved automatically.
channel = self.make_request(
"POST",
"register",
{
"username": "bob",
"password": "test",
"auth": {"type": LoginType.DUMMY},
},
)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
self.assertEqual(
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
# Get user
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIs(False, channel.json_body["approved"])
# Approve user
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content={"approved": True},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIs(True, channel.json_body["approved"])
# Check that the user is now approved
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIs(True, channel.json_body["approved"])
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_register_approved(self) -> None:
url = self.url_prefix % "@bob:test"
# Create user
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content={"password": "abc123", "approved": True},
)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual(1, channel.json_body["approved"])
# Get user
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual(1, channel.json_body["approved"])
def _is_erased(self, user_id: str, expect: bool) -> None:
"""Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id)

View File

@ -20,7 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
@ -567,6 +568,36 @@ class UIAuthTests(unittest.HomeserverTestCase):
body={"auth": {"session": session_id}},
)
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config(
{
"oidc_config": TEST_OIDC_CONFIG,
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
},
}
)
def test_sso_not_approved(self) -> None:
"""Tests that if we register a user via SSO while requiring approval for new
accounts, we still raise the correct error before logging the user in.
"""
login_resp = self.helper.login_via_oidc("username", expected_status=403)
self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
self.assertEqual(
ApprovalNoticeMedium.NONE, login_resp["approval_notice_medium"]
)
# Check that we didn't register a device for the user during the login attempt.
devices = self.get_success(
self.hs.get_datastores().main.get_devices_by_user("@username:test")
)
self.assertEqual(len(devices), 0)
class RefreshAuthTests(unittest.HomeserverTestCase):
servlets = [

View File

@ -23,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet
@ -94,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
logout.register_servlets,
devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
register.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@ -406,6 +409,44 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_require_approval(self) -> None:
channel = self.make_request(
"POST",
"register",
{
"username": "kermit",
"password": "monkey",
"auth": {"type": LoginType.DUMMY},
},
)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
self.assertEqual(
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
params = {
"type": LoginType.PASSWORD,
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
channel = self.make_request("POST", LOGIN_URL, params)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
self.assertEqual(
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):

View File

@ -22,6 +22,8 @@ from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
@ -45,18 +47,18 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
self.password = "password"
def test_disabled(self) -> None:
channel = self.make_request("POST", "/login/token", {}, access_token=None)
channel = self.make_request("POST", endpoint, {}, access_token=None)
self.assertEqual(channel.code, 400)
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
channel = self.make_request("POST", "/login/token", {}, access_token=token)
channel = self.make_request("POST", endpoint, {}, access_token=token)
self.assertEqual(channel.code, 400)
@override_config({"experimental_features": {"msc3882_enabled": True}})
def test_require_auth(self) -> None:
channel = self.make_request("POST", "/login/token", {}, access_token=None)
channel = self.make_request("POST", endpoint, {}, access_token=None)
self.assertEqual(channel.code, 401)
@override_config({"experimental_features": {"msc3882_enabled": True}})
@ -64,7 +66,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
user_id = self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
channel = self.make_request("POST", "/login/token", {}, access_token=token)
channel = self.make_request("POST", endpoint, {}, access_token=token)
self.assertEqual(channel.code, 401)
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
@ -79,7 +81,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
},
}
channel = self.make_request("POST", "/login/token", uia, access_token=token)
channel = self.make_request("POST", endpoint, uia, access_token=token)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 300)
@ -100,7 +102,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
user_id = self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
channel = self.make_request("POST", "/login/token", {}, access_token=token)
channel = self.make_request("POST", endpoint, {}, access_token=token)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 300)
@ -127,6 +129,6 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
channel = self.make_request("POST", "/login/token", {}, access_token=token)
channel = self.make_request("POST", endpoint, {}, access_token=token)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 15)

View File

@ -22,7 +22,11 @@ import pkg_resources
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.constants import (
APP_SERVICE_REGISTRATION_TYPE,
ApprovalNoticeMedium,
LoginType,
)
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync
@ -765,6 +769,32 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_require_approval(self) -> None:
channel = self.make_request(
"POST",
"register",
{
"username": "kermit",
"password": "monkey",
"auth": {"type": LoginType.DUMMY},
},
)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
self.assertEqual(
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
class AccountValidityTestCase(unittest.HomeserverTestCase):

View File

@ -543,8 +543,12 @@ class RestHelper:
return channel.json_body
def login_via_oidc(self, remote_user_id: str) -> JsonDict:
"""Log in (as a new user) via OIDC
def login_via_oidc(
self,
remote_user_id: str,
expected_status: int = 200,
) -> JsonDict:
"""Log in via OIDC
Returns the result of the final token login.
@ -578,7 +582,9 @@ class RestHelper:
"/login",
content={"type": "m.login.token", "token": login_token},
)
assert channel.code == HTTPStatus.OK
assert (
channel.code == expected_status
), f"unexpected status in response: {channel.code}"
return channel.json_body
def auth_via_oidc(

View File

@ -754,18 +754,28 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_backfill_points_in_room(self):
"""
Test to make sure we get some backfill points
Test to make sure only backfill points that are older and come before
the `current_depth` are returned.
"""
setup_info = self._setup_room_for_backfill_tests()
room_id = setup_info.room_id
depth_map = setup_info.depth_map
# Try at "B"
backfill_points = self.get_success(
self.store.get_backfill_points_in_room(room_id)
self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100)
)
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertListEqual(
backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]
self.assertEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"])
# Try at "A"
backfill_points = self.get_success(
self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
)
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
# Event "2" has a depth of 2 but is not included here because we only
# know the approximate depth of 5 from our event "3".
self.assertListEqual(backfill_event_ids, ["b3", "b2", "b1"])
def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
self,
@ -776,6 +786,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"""
setup_info = self._setup_room_for_backfill_tests()
room_id = setup_info.room_id
depth_map = setup_info.depth_map
# Record some attempts to backfill these events which will make
# `get_backfill_points_in_room` exclude them because we
@ -795,12 +806,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# No time has passed since we attempted to backfill ^
# Try at "B"
backfill_points = self.get_success(
self.store.get_backfill_points_in_room(room_id)
self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100)
)
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
# Only the backfill points that we didn't record earlier exist here.
self.assertListEqual(backfill_event_ids, ["b6", "2", "b1"])
self.assertEqual(backfill_event_ids, ["b6", "2", "b1"])
def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
self,
@ -812,6 +824,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"""
setup_info = self._setup_room_for_backfill_tests()
room_id = setup_info.room_id
depth_map = setup_info.depth_map
# Record some attempts to backfill these events which will make
# `get_backfill_points_in_room` exclude them because we
@ -839,27 +852,66 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# visible regardless.
self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
# Make sure that "b1" is not in the list because we've
# Try at "A" and make sure that "b1" is not in the list because we've
# already attempted many times
backfill_points = self.get_success(
self.store.get_backfill_points_in_room(room_id)
self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
)
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertListEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2"])
self.assertEqual(backfill_event_ids, ["b3", "b2"])
# Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
# see if we can now backfill it
self.reactor.advance(datetime.timedelta(hours=20).total_seconds())
# Try again after we advanced enough time and we should see "b3" again
# Try at "A" again after we advanced enough time and we should see "b3" again
backfill_points = self.get_success(
self.store.get_backfill_points_in_room(room_id)
self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
)
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertListEqual(
backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]
self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"])
def test_get_backfill_points_in_room_works_after_many_failed_pull_attempts_that_could_naively_overflow(
self,
) -> None:
"""
A test that reproduces #13929 (Postgres only).
Test to make sure we can still get backfill points after many failed pull
attempts that cause us to backoff to the limit. Even if the backoff formula
would tell us to wait for more seconds than can be expressed in a 32 bit
signed int.
"""
setup_info = self._setup_room_for_backfill_tests()
room_id = setup_info.room_id
depth_map = setup_info.depth_map
# Pretend that we have tried and failed 10 times to backfill event b1.
for _ in range(10):
self.get_success(
self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
)
# If the backoff periods grow without limit:
# After the first failed attempt, we would have backed off for 1 << 1 = 2 hours.
# After the second failed attempt we would have backed off for 1 << 2 = 4 hours,
# so after the 10th failed attempt we should backoff for 1 << 10 == 1024 hours.
# Wait 1100 hours just so we have a nice round number.
self.reactor.advance(datetime.timedelta(hours=1100).total_seconds())
# 1024 hours in milliseconds is 1024 * 3600000, which exceeds the largest 32 bit
# signed integer. The bug we're reproducing is that this overflow causes an
# error in postgres preventing us from fetching a set of backwards extremities
# to retry fetching.
backfill_points = self.get_success(
self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100)
)
# We should aim to fetch all backoff points: b1's latest backoff period has
# expired, and we haven't tried the rest.
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"])
def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo:
"""
Sets up a room with various insertion event backward extremities to test
@ -938,18 +990,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_insertion_event_backward_extremities_in_room(self):
"""
Test to make sure insertion event backward extremities are returned.
Test to make sure only insertion event backward extremities that are
older and come before the `current_depth` are returned.
"""
setup_info = self._setup_room_for_insertion_backfill_tests()
room_id = setup_info.room_id
depth_map = setup_info.depth_map
# Try at "insertion_eventB"
backfill_points = self.get_success(
self.store.get_insertion_event_backward_extremities_in_room(room_id)
self.store.get_insertion_event_backward_extremities_in_room(
room_id, depth_map["insertion_eventB"], limit=100
)
)
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertListEqual(
backfill_event_ids, ["insertion_eventB", "insertion_eventA"]
self.assertEqual(backfill_event_ids, ["insertion_eventB", "insertion_eventA"])
# Try at "insertion_eventA"
backfill_points = self.get_success(
self.store.get_insertion_event_backward_extremities_in_room(
room_id, depth_map["insertion_eventA"], limit=100
)
)
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
# Event "2" has a depth of 2 but is not included here because we only
# know the approximate depth of 5 from our event "3".
self.assertListEqual(backfill_event_ids, ["insertion_eventA"])
def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
self,
@ -961,6 +1027,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"""
setup_info = self._setup_room_for_insertion_backfill_tests()
room_id = setup_info.room_id
depth_map = setup_info.depth_map
# Record some attempts to backfill these events which will make
# `get_insertion_event_backward_extremities_in_room` exclude them
@ -973,12 +1040,15 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# No time has passed since we attempted to backfill ^
# Try at "insertion_eventB"
backfill_points = self.get_success(
self.store.get_insertion_event_backward_extremities_in_room(room_id)
self.store.get_insertion_event_backward_extremities_in_room(
room_id, depth_map["insertion_eventB"], limit=100
)
)
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
# Only the backfill points that we didn't record earlier exist here.
self.assertListEqual(backfill_event_ids, ["insertion_eventB"])
self.assertEqual(backfill_event_ids, ["insertion_eventB"])
def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
self,
@ -991,6 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"""
setup_info = self._setup_room_for_insertion_backfill_tests()
room_id = setup_info.room_id
depth_map = setup_info.depth_map
# Record some attempts to backfill these events which will make
# `get_backfill_points_in_room` exclude them because we
@ -1027,13 +1098,15 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# because we haven't waited long enough for this many attempts.
self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
# Make sure that "insertion_eventA" is not in the list because we've
# already attempted many times
# Try at "insertion_eventA" and make sure that "insertion_eventA" is not
# in the list because we've already attempted many times
backfill_points = self.get_success(
self.store.get_insertion_event_backward_extremities_in_room(room_id)
self.store.get_insertion_event_backward_extremities_in_room(
room_id, depth_map["insertion_eventA"], limit=100
)
)
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertListEqual(backfill_event_ids, ["insertion_eventB"])
self.assertEqual(backfill_event_ids, [])
# Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
# see if we can now backfill it
@ -1042,12 +1115,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# Try at "insertion_eventA" again after we advanced enough time and we
# should see "insertion_eventA" again
backfill_points = self.get_success(
self.store.get_insertion_event_backward_extremities_in_room(room_id)
self.store.get_insertion_event_backward_extremities_in_room(
room_id, depth_map["insertion_eventA"], limit=100
)
)
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertListEqual(
backfill_event_ids, ["insertion_eventB", "insertion_eventA"]
)
self.assertEqual(backfill_event_ids, ["insertion_eventA"])
@attr.s

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Collection, Optional
from synapse.api.constants import ReceiptTypes
from synapse.types import UserID, create_requester
@ -84,6 +85,33 @@ class ReceiptTestCase(HomeserverTestCase):
)
)
def get_last_unthreaded_receipt(
self, receipt_types: Collection[str], room_id: Optional[str] = None
) -> Optional[str]:
"""
Fetch the event ID for the latest unthreaded receipt in the test room for the test user.
Args:
receipt_types: The receipt types to fetch.
Returns:
The latest receipt, if one exists.
"""
result = self.get_success(
self.store.db_pool.runInteraction(
"get_last_receipt_event_id_for_user",
self.store.get_last_unthreaded_receipt_for_user_txn,
OUR_USER_ID,
room_id or self.room_id1,
receipt_types,
)
)
if not result:
return None
event_id, _ = result
return event_id
def test_return_empty_with_no_data(self) -> None:
res = self.get_success(
self.store.get_receipts_for_user(
@ -107,16 +135,10 @@ class ReceiptTestCase(HomeserverTestCase):
)
self.assertEqual(res, {})
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id1,
[
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
],
)
res = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
self.assertEqual(res, None)
def test_get_receipts_for_user(self) -> None:
@ -228,29 +250,17 @@ class ReceiptTestCase(HomeserverTestCase):
)
# Test we get the latest event when we want both private and public receipts
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id1,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
)
res = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
self.assertEqual(res, event1_2_id)
# Test we get the older event when we want only public receipt
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
)
)
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_1_id)
# Test we get the latest event when we want only the private receipt
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
)
)
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE])
self.assertEqual(res, event1_2_id)
# Test receipt updating
@ -259,11 +269,7 @@ class ReceiptTestCase(HomeserverTestCase):
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
)
)
res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
self.assertEqual(res, event1_2_id)
# Send some events into the second room
@ -282,11 +288,7 @@ class ReceiptTestCase(HomeserverTestCase):
{},
)
)
res = self.get_success(
self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID,
self.room_id2,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
)
res = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2
)
self.assertEqual(res, event2_1_id)

View File

@ -16,9 +16,10 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
from tests.unittest import HomeserverTestCase, override_config
class RegistrationStoreTestCase(HomeserverTestCase):
@ -48,6 +49,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
"user_type": None,
"deactivated": 0,
"shadow_banned": 0,
"approved": 1,
},
(self.get_success(self.store.get_user_by_id(self.user_id))),
)
@ -166,3 +168,101 @@ class RegistrationStoreTestCase(HomeserverTestCase):
ThreepidValidationError,
)
self.assertEqual(e.value.msg, "Validation token not found or has expired", e)
class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
def default_config(self) -> JsonDict:
config = super().default_config()
# If there's already some config for this feature in the default config, it
# means we're overriding it with @override_config. In this case we don't want
# to do anything more with it.
msc3866_config = config.get("experimental_features", {}).get("msc3866")
if msc3866_config is not None:
return config
# Require approval for all new accounts.
config["experimental_features"] = {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = "@my-user:test"
self.pwhash = "{xx1}123456789"
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": False,
}
}
}
)
def test_approval_not_required(self) -> None:
"""Tests that if we don't require approval for new accounts, newly created
accounts are automatically marked as approved.
"""
self.get_success(self.store.register_user(self.user_id, self.pwhash))
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
self.assertTrue(user["approved"])
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)
def test_approval_required(self) -> None:
"""Tests that if we require approval for new accounts, newly created accounts
are not automatically marked as approved.
"""
self.get_success(self.store.register_user(self.user_id, self.pwhash))
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
self.assertFalse(user["approved"])
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertFalse(approved)
def test_override(self) -> None:
"""Tests that if we require approval for new accounts, but we explicitly say the
new user should be considered approved, they're marked as approved.
"""
self.get_success(
self.store.register_user(
self.user_id,
self.pwhash,
approved=True,
)
)
user = self.get_success(self.store.get_user_by_id(self.user_id))
self.assertIsNotNone(user)
assert user is not None
self.assertEqual(user["approved"], 1)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)
def test_approve_user(self) -> None:
"""Tests that approving the user updates their approval status."""
self.get_success(self.store.register_user(self.user_id, self.pwhash))
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertFalse(approved)
self.get_success(
self.store.update_user_approval_status(
UserID.from_string(self.user_id), True
)
)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)

View File

@ -40,7 +40,10 @@ class TestDependencyChecker(TestCase):
def mock_installed_package(
self, distribution: Optional[DummyDistribution]
) -> Generator[None, None, None]:
"""Pretend that looking up any distribution yields the given `distribution`."""
"""Pretend that looking up any package yields the given `distribution`.
If `distribution = None`, we pretend that the package is not installed.
"""
def mock_distribution(name: str):
if distribution is None:
@ -81,7 +84,7 @@ class TestDependencyChecker(TestCase):
self.assertRaises(DependencyException, check_requirements)
def test_checks_ignore_dev_dependencies(self) -> None:
"""Bot generic and per-extra checks should ignore dev dependencies."""
"""Both generic and per-extra checks should ignore dev dependencies."""
with patch(
"synapse.util.check_dependencies.metadata.requires",
return_value=["dummypkg >= 1; extra == 'mypy'"],
@ -142,3 +145,16 @@ class TestDependencyChecker(TestCase):
with self.mock_installed_package(new_release_candidate):
# should not raise
check_requirements()
def test_setuptools_rust_ignored(self) -> None:
"""Test a workaround for a `poetry build` problem. Reproduces #13926."""
with patch(
"synapse.util.check_dependencies.metadata.requires",
return_value=["setuptools_rust >= 1.3"],
):
with self.mock_installed_package(None):
# should not raise, even if setuptools_rust is not installed
check_requirements()
with self.mock_installed_package(old):
# We also ignore old versions of setuptools_rust
check_requirements()