Merge branch 'release-v1.24.0' of github.com:matrix-org/synapse into matrix-org-hotfixes
commit
16744644f6
|
@ -6,7 +6,7 @@
|
|||
set -ex
|
||||
|
||||
apt-get update
|
||||
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
|
||||
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
|
||||
|
||||
export LANG="C.UTF-8"
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Simplify the way the `HomeServer` object caches its internal attributes.
|
|
@ -0,0 +1 @@
|
|||
Allow per-room profiles to be used for the server notice user.
|
|
@ -0,0 +1 @@
|
|||
Add additional error checking for OpenID Connect and SAML mapping providers.
|
|
@ -0,0 +1 @@
|
|||
Allow Date header through CORS. Contributed by Nicolas Chamo.
|
|
@ -0,0 +1 @@
|
|||
Remove unnecessary function arguments and add typing to several membership replication classes.
|
|
@ -0,0 +1 @@
|
|||
Add tests for `password_auth_provider`s.
|
|
@ -0,0 +1 @@
|
|||
Add a config option, `push.group_by_unread_count`, which controls whether unread message counts in push notifications are defined as "the number of rooms with unread messages" or "total unread messages".
|
|
@ -0,0 +1 @@
|
|||
Disable pretty printing JSON responses for curl. Users who want pretty-printed output should use [jq](https://stedolan.github.io/jq/) in combination with curl. Contributed by @tulir.
|
|
@ -0,0 +1 @@
|
|||
Fix minor long-standing bug in login, where we would offer the `password` login type if a custom auth provider supported it, even if password login was disabled.
|
|
@ -0,0 +1 @@
|
|||
Add `force_purge` option to delete-room admin api.
|
|
@ -0,0 +1 @@
|
|||
Drop redundant database index on `event_json`.
|
|
@ -0,0 +1 @@
|
|||
Simplify `uk.half-shot.msc2778.login.application_service` login handler.
|
|
@ -0,0 +1 @@
|
|||
Fix a long-standing bug which caused Synapse to require unspecified parameters during user-interactive authentication.
|
|
@ -0,0 +1 @@
|
|||
Refactor `password_auth_provider` support code.
|
|
@ -0,0 +1 @@
|
|||
Add missing `ordering` to background database updates.
|
|
@ -0,0 +1 @@
|
|||
Simplify the way the `HomeServer` object caches its internal attributes.
|
|
@ -0,0 +1 @@
|
|||
Allow for specifying a room version when creating a room in unit tests via `RestHelper.create_room_as`.
|
|
@ -0,0 +1 @@
|
|||
Add support for re-trying generation of a localpart for OpenID Connect mapping providers.
|
|
@ -382,7 +382,7 @@ the new room. Users on other servers will be unaffected.
|
|||
|
||||
The API is:
|
||||
|
||||
```json
|
||||
```
|
||||
POST /_synapse/admin/v1/rooms/<room_id>/delete
|
||||
```
|
||||
|
||||
|
@ -439,6 +439,10 @@ The following JSON body parameters are available:
|
|||
future attempts to join the room. Defaults to `false`.
|
||||
* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database.
|
||||
Defaults to `true`.
|
||||
* `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it
|
||||
will force a purge to go ahead even if there are local users still in the room. Do not
|
||||
use this unless a regular `purge` operation fails, as it could leave those users'
|
||||
clients in a confused state.
|
||||
|
||||
The JSON body must not be empty. The body must be at least `{}`.
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ Password auth provider classes must provide the following methods:
|
|||
|
||||
It should perform any appropriate sanity checks on the provided
|
||||
configuration, and return an object which is then passed into
|
||||
`__init__`.
|
||||
|
||||
This method should have the `@staticmethod` decoration.
|
||||
|
||||
|
|
|
@ -2271,6 +2271,16 @@ push:
|
|||
#
|
||||
#include_content: false
|
||||
|
||||
# When a push notification is received, an unread count is also sent.
|
||||
# This number can either be calculated as the number of unread messages
|
||||
# for the user, or the number of *rooms* the user has unread messages in.
|
||||
#
|
||||
# The default value is "true", meaning push clients will see the number of
|
||||
# rooms with unread messages in them. Uncomment to instead send the number
|
||||
# of unread messages.
|
||||
#
|
||||
#group_unread_count_by_room: false
|
||||
|
||||
|
||||
# Spam checkers are third-party modules that can block specific actions
|
||||
# of local users, such as creating rooms and registering undesirable
|
||||
|
|
1
mypy.ini
1
mypy.ini
|
@ -80,6 +80,7 @@ files =
|
|||
synapse/util/metrics.py,
|
||||
tests/replication,
|
||||
tests/test_utils,
|
||||
tests/handlers/test_password_providers.py,
|
||||
tests/rest/client/v2_alpha/test_auth.py,
|
||||
tests/util/test_stream_change_cache.py
|
||||
|
||||
|
|
|
@ -23,6 +23,9 @@ class PushConfig(Config):
|
|||
def read_config(self, config, **kwargs):
|
||||
push_config = config.get("push") or {}
|
||||
self.push_include_content = push_config.get("include_content", True)
|
||||
self.push_group_unread_count_by_room = push_config.get(
|
||||
"group_unread_count_by_room", True
|
||||
)
|
||||
|
||||
pusher_instances = config.get("pusher_instances") or []
|
||||
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
|
||||
|
@ -68,4 +71,14 @@ class PushConfig(Config):
|
|||
# include the event ID and room ID in push notification payloads.
|
||||
#
|
||||
#include_content: false
|
||||
|
||||
# When a push notification is received, an unread count is also sent.
|
||||
# This number can either be calculated as the number of unread messages
|
||||
# for the user, or the number of *rooms* the user has unread messages in.
|
||||
#
|
||||
# The default value is "true", meaning push clients will see the number of
|
||||
# rooms with unread messages in them. Uncomment to instead send the number
|
||||
# of unread messages.
|
||||
#
|
||||
#group_unread_count_by_room: false
|
||||
"""
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -25,6 +26,7 @@ from typing import (
|
|||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
|
@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
|
|||
# better way to break the loop
|
||||
account_handler = ModuleApi(hs, self)
|
||||
|
||||
self.password_providers = []
|
||||
for module, config in hs.config.password_providers:
|
||||
try:
|
||||
self.password_providers.append(
|
||||
module(config=config, account_handler=account_handler)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error while initializing %r: %s", module, e)
|
||||
raise
|
||||
self.password_providers = [
|
||||
PasswordProvider.load(module, config, account_handler)
|
||||
for module, config in hs.config.password_providers
|
||||
]
|
||||
|
||||
logger.info("Extra password_providers: %r", self.password_providers)
|
||||
logger.info("Extra password_providers: %s", self.password_providers)
|
||||
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
|
@ -205,15 +202,23 @@ class AuthHandler(BaseHandler):
|
|||
# type in the list. (NB that the spec doesn't require us to do so and
|
||||
# clients which favour types that they don't understand over those that
|
||||
# they do are technically broken)
|
||||
|
||||
# start out by assuming PASSWORD is enabled; we will remove it later if not.
|
||||
login_types = []
|
||||
if self._password_enabled:
|
||||
if hs.config.password_localdb_enabled:
|
||||
login_types.append(LoginType.PASSWORD)
|
||||
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "get_supported_login_types"):
|
||||
for t in provider.get_supported_login_types().keys():
|
||||
if t not in login_types:
|
||||
login_types.append(t)
|
||||
|
||||
if not self._password_enabled:
|
||||
login_types.remove(LoginType.PASSWORD)
|
||||
|
||||
self._supported_login_types = login_types
|
||||
|
||||
# Login types and UI Auth types have a heavy overlap, but are not
|
||||
# necessarily identical. Login types have SSO (and other login types)
|
||||
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
|
||||
|
@ -230,6 +235,13 @@ class AuthHandler(BaseHandler):
|
|||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
)
|
||||
|
||||
# Ratelimitier for failed /login attempts
|
||||
self._failed_login_attempts_ratelimiter = Ratelimiter(
|
||||
clock=hs.get_clock(),
|
||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
)
|
||||
|
||||
self._clock = self.hs.get_clock()
|
||||
|
||||
# Expire old UI auth sessions after a period of time.
|
||||
|
@ -642,14 +654,8 @@ class AuthHandler(BaseHandler):
|
|||
res = await checker.check_auth(authdict, clientip=clientip)
|
||||
return res
|
||||
|
||||
# build a v1-login-style dict out of the authdict and fall back to the
|
||||
# v1 code
|
||||
user_id = authdict.get("user")
|
||||
|
||||
if user_id is None:
|
||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||
|
||||
(canonical_id, callback) = await self.validate_login(user_id, authdict)
|
||||
# fall back to the v1 login flow
|
||||
canonical_id, _ = await self.validate_login(authdict)
|
||||
return canonical_id
|
||||
|
||||
def _get_params_recaptcha(self) -> dict:
|
||||
|
@ -824,15 +830,157 @@ class AuthHandler(BaseHandler):
|
|||
return self._supported_login_types
|
||||
|
||||
async def validate_login(
|
||||
self, username: str, login_submission: Dict[str, Any]
|
||||
self, login_submission: Dict[str, Any], ratelimit: bool = False,
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
|
||||
"""Authenticates the user for the /login API
|
||||
|
||||
Also used by the user-interactive auth flow to validate
|
||||
m.login.password auth types.
|
||||
Also used by the user-interactive auth flow to validate auth types which don't
|
||||
have an explicit UIA handler, including m.password.auth.
|
||||
|
||||
Args:
|
||||
username: username supplied by the user
|
||||
login_submission: the whole of the login submission
|
||||
(including 'type' and other relevant fields)
|
||||
ratelimit: whether to apply the failed_login_attempt ratelimiter
|
||||
Returns:
|
||||
A tuple of the canonical user id, and optional callback
|
||||
to be called once the access token and device id are issued
|
||||
Raises:
|
||||
StoreError if there was a problem accessing the database
|
||||
SynapseError if there was a problem with the request
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
login_type = login_submission.get("type")
|
||||
if not isinstance(login_type, str):
|
||||
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
|
||||
|
||||
# ideally, we wouldn't be checking the identifier unless we know we have a login
|
||||
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
|
||||
#
|
||||
# But the auth providers' check_auth interface requires a username, so in
|
||||
# practice we can only support login methods which we can map to a username
|
||||
# anyway.
|
||||
|
||||
# special case to check for "password" for the check_password interface
|
||||
# for the auth providers
|
||||
password = login_submission.get("password")
|
||||
if login_type == LoginType.PASSWORD:
|
||||
if not self._password_enabled:
|
||||
raise SynapseError(400, "Password login has been disabled.")
|
||||
if not isinstance(password, str):
|
||||
raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
|
||||
|
||||
# map old-school login fields into new-school "identifier" fields.
|
||||
identifier_dict = convert_client_dict_legacy_fields_to_identifier(
|
||||
login_submission
|
||||
)
|
||||
|
||||
# convert phone type identifiers to generic threepids
|
||||
if identifier_dict["type"] == "m.id.phone":
|
||||
identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
|
||||
|
||||
# convert threepid identifiers to user IDs
|
||||
if identifier_dict["type"] == "m.id.thirdparty":
|
||||
address = identifier_dict.get("address")
|
||||
medium = identifier_dict.get("medium")
|
||||
|
||||
if medium is None or address is None:
|
||||
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||
|
||||
# For emails, canonicalise the address.
|
||||
# We store all email addresses canonicalised in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
if medium == "email":
|
||||
try:
|
||||
address = canonicalise_email(address)
|
||||
except ValueError as e:
|
||||
raise SynapseError(400, str(e))
|
||||
|
||||
# We also apply account rate limiting using the 3PID as a key, as
|
||||
# otherwise using 3PID bypasses the ratelimiting based on user ID.
|
||||
if ratelimit:
|
||||
self._failed_login_attempts_ratelimiter.ratelimit(
|
||||
(medium, address), update=False
|
||||
)
|
||||
|
||||
# Check for login providers that support 3pid login types
|
||||
if login_type == LoginType.PASSWORD:
|
||||
# we've already checked that there is a (valid) password field
|
||||
assert isinstance(password, str)
|
||||
(
|
||||
canonical_user_id,
|
||||
callback_3pid,
|
||||
) = await self.check_password_provider_3pid(medium, address, password)
|
||||
if canonical_user_id:
|
||||
# Authentication through password provider and 3pid succeeded
|
||||
return canonical_user_id, callback_3pid
|
||||
|
||||
# No password providers were able to handle this 3pid
|
||||
# Check local store
|
||||
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
|
||||
medium, address
|
||||
)
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"unknown 3pid identifier medium %s, address %r", medium, address
|
||||
)
|
||||
# We mark that we've failed to log in here, as
|
||||
# `check_password_provider_3pid` might have returned `None` due
|
||||
# to an incorrect password, rather than the account not
|
||||
# existing.
|
||||
#
|
||||
# If it returned None but the 3PID was bound then we won't hit
|
||||
# this code path, which is fine as then the per-user ratelimit
|
||||
# will kick in below.
|
||||
if ratelimit:
|
||||
self._failed_login_attempts_ratelimiter.can_do_action(
|
||||
(medium, address)
|
||||
)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
identifier_dict = {"type": "m.id.user", "user": user_id}
|
||||
|
||||
# by this point, the identifier should be an m.id.user: if it's anything
|
||||
# else, we haven't understood it.
|
||||
if identifier_dict["type"] != "m.id.user":
|
||||
raise SynapseError(400, "Unknown login identifier type")
|
||||
|
||||
username = identifier_dict.get("user")
|
||||
if not username:
|
||||
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||
|
||||
if username.startswith("@"):
|
||||
qualified_user_id = username
|
||||
else:
|
||||
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
||||
|
||||
# Check if we've hit the failed ratelimit (but don't update it)
|
||||
if ratelimit:
|
||||
self._failed_login_attempts_ratelimiter.ratelimit(
|
||||
qualified_user_id.lower(), update=False
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._validate_userid_login(username, login_submission)
|
||||
except LoginError:
|
||||
# The user has failed to log in, so we need to update the rate
|
||||
# limiter. Using `can_do_action` avoids us raising a ratelimit
|
||||
# exception and masking the LoginError. The actual ratelimiting
|
||||
# should have happened above.
|
||||
if ratelimit:
|
||||
self._failed_login_attempts_ratelimiter.can_do_action(
|
||||
qualified_user_id.lower()
|
||||
)
|
||||
raise
|
||||
|
||||
async def _validate_userid_login(
|
||||
self, username: str, login_submission: Dict[str, Any],
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
|
||||
"""Helper for validate_login
|
||||
|
||||
Handles login, once we've mapped 3pids onto userids
|
||||
|
||||
Args:
|
||||
username: the username, from the identifier dict
|
||||
login_submission: the whole of the login submission
|
||||
(including 'type' and other relevant fields)
|
||||
Returns:
|
||||
|
@ -843,38 +991,18 @@ class AuthHandler(BaseHandler):
|
|||
SynapseError if there was a problem with the request
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
|
||||
if username.startswith("@"):
|
||||
qualified_user_id = username
|
||||
else:
|
||||
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
||||
|
||||
login_type = login_submission.get("type")
|
||||
# we already checked that we have a valid login type
|
||||
assert isinstance(login_type, str)
|
||||
|
||||
known_login_type = False
|
||||
|
||||
# special case to check for "password" for the check_password interface
|
||||
# for the auth providers
|
||||
password = login_submission.get("password")
|
||||
|
||||
if login_type == LoginType.PASSWORD:
|
||||
if not self._password_enabled:
|
||||
raise SynapseError(400, "Password login has been disabled.")
|
||||
if not password:
|
||||
raise SynapseError(400, "Missing parameter: password")
|
||||
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
|
||||
known_login_type = True
|
||||
is_valid = await provider.check_password(qualified_user_id, password)
|
||||
if is_valid:
|
||||
return qualified_user_id, None
|
||||
|
||||
if not hasattr(provider, "get_supported_login_types") or not hasattr(
|
||||
provider, "check_auth"
|
||||
):
|
||||
# this password provider doesn't understand custom login types
|
||||
continue
|
||||
|
||||
supported_login_types = provider.get_supported_login_types()
|
||||
if login_type not in supported_login_types:
|
||||
# this password provider doesn't understand this login type
|
||||
|
@ -899,15 +1027,17 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
result = await provider.check_auth(username, login_type, login_dict)
|
||||
if result:
|
||||
if isinstance(result, str):
|
||||
result = (result, None)
|
||||
return result
|
||||
|
||||
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
||||
known_login_type = True
|
||||
|
||||
# we've already checked that there is a (valid) password field
|
||||
password = login_submission["password"]
|
||||
assert isinstance(password, str)
|
||||
|
||||
canonical_user_id = await self._check_local_password(
|
||||
qualified_user_id, password # type: ignore
|
||||
qualified_user_id, password
|
||||
)
|
||||
|
||||
if canonical_user_id:
|
||||
|
@ -938,19 +1068,9 @@ class AuthHandler(BaseHandler):
|
|||
unsuccessful, `user_id` and `callback` are both `None`.
|
||||
"""
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "check_3pid_auth"):
|
||||
# This function is able to return a deferred that either
|
||||
# resolves None, meaning authentication failure, or upon
|
||||
# success, to a str (which is the user_id) or a tuple of
|
||||
# (user_id, callback_func), where callback_func should be run
|
||||
# after we've finished everything else
|
||||
result = await provider.check_3pid_auth(medium, address, password)
|
||||
if result:
|
||||
# Check if the return value is a str or a tuple
|
||||
if isinstance(result, str):
|
||||
# If it's a str, set callback function to None
|
||||
result = (result, None)
|
||||
return result
|
||||
result = await provider.check_3pid_auth(medium, address, password)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return None, None
|
||||
|
||||
|
@ -1008,16 +1128,11 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
# see if any of our auth providers want to know about this
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "on_logged_out"):
|
||||
# This might return an awaitable, if it does block the log out
|
||||
# until it completes.
|
||||
result = provider.on_logged_out(
|
||||
user_id=user_info.user_id,
|
||||
device_id=user_info.device_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
await provider.on_logged_out(
|
||||
user_id=user_info.user_id,
|
||||
device_id=user_info.device_id,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
# delete pushers associated with this access token
|
||||
if user_info.token_id is not None:
|
||||
|
@ -1046,11 +1161,10 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
# see if any of our auth providers want to know about this
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "on_logged_out"):
|
||||
for token, token_id, device_id in tokens_and_devices:
|
||||
await provider.on_logged_out(
|
||||
user_id=user_id, device_id=device_id, access_token=token
|
||||
)
|
||||
for token, token_id, device_id in tokens_and_devices:
|
||||
await provider.on_logged_out(
|
||||
user_id=user_id, device_id=device_id, access_token=token
|
||||
)
|
||||
|
||||
# delete pushers associated with the access tokens
|
||||
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||
|
@ -1374,3 +1488,127 @@ class MacaroonGenerator:
|
|||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
return macaroon
|
||||
|
||||
|
||||
class PasswordProvider:
|
||||
"""Wrapper for a password auth provider module
|
||||
|
||||
This class abstracts out all of the backwards-compatibility hacks for
|
||||
password providers, to provide a consistent interface.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
|
||||
try:
|
||||
pp = module(config=config, account_handler=module_api)
|
||||
except Exception as e:
|
||||
logger.error("Error while initializing %r: %s", module, e)
|
||||
raise
|
||||
return cls(pp, module_api)
|
||||
|
||||
def __init__(self, pp, module_api: ModuleApi):
|
||||
self._pp = pp
|
||||
self._module_api = module_api
|
||||
|
||||
self._supported_login_types = {}
|
||||
|
||||
# grandfather in check_password support
|
||||
if hasattr(self._pp, "check_password"):
|
||||
self._supported_login_types[LoginType.PASSWORD] = ("password",)
|
||||
|
||||
g = getattr(self._pp, "get_supported_login_types", None)
|
||||
if g:
|
||||
self._supported_login_types.update(g())
|
||||
|
||||
def __str__(self):
|
||||
return str(self._pp)
|
||||
|
||||
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
||||
"""Get the login types supported by this password provider
|
||||
|
||||
Returns a map from a login type identifier (such as m.login.password) to an
|
||||
iterable giving the fields which must be provided by the user in the submission
|
||||
to the /login API.
|
||||
|
||||
This wrapper adds m.login.password to the list if the underlying password
|
||||
provider supports the check_password() api.
|
||||
"""
|
||||
return self._supported_login_types
|
||||
|
||||
async def check_auth(
|
||||
self, username: str, login_type: str, login_dict: JsonDict
|
||||
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||
"""Check if the user has presented valid login credentials
|
||||
|
||||
This wrapper also calls check_password() if the underlying password provider
|
||||
supports the check_password() api and the login type is m.login.password.
|
||||
|
||||
Args:
|
||||
username: user id presented by the client. Either an MXID or an unqualified
|
||||
username.
|
||||
|
||||
login_type: the login type being attempted - one of the types returned by
|
||||
get_supported_login_types()
|
||||
|
||||
login_dict: the dictionary of login secrets passed by the client.
|
||||
|
||||
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
|
||||
user, and `callback` is an optional callback which will be called with the
|
||||
result from the /login call (including access_token, device_id, etc.)
|
||||
"""
|
||||
# first grandfather in a call to check_password
|
||||
if login_type == LoginType.PASSWORD:
|
||||
g = getattr(self._pp, "check_password", None)
|
||||
if g:
|
||||
qualified_user_id = self._module_api.get_qualified_user_id(username)
|
||||
is_valid = await self._pp.check_password(
|
||||
qualified_user_id, login_dict["password"]
|
||||
)
|
||||
if is_valid:
|
||||
return qualified_user_id, None
|
||||
|
||||
g = getattr(self._pp, "check_auth", None)
|
||||
if not g:
|
||||
return None
|
||||
result = await g(username, login_type, login_dict)
|
||||
|
||||
# Check if the return value is a str or a tuple
|
||||
if isinstance(result, str):
|
||||
# If it's a str, set callback function to None
|
||||
return result, None
|
||||
|
||||
return result
|
||||
|
||||
async def check_3pid_auth(
|
||||
self, medium: str, address: str, password: str
|
||||
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||
g = getattr(self._pp, "check_3pid_auth", None)
|
||||
if not g:
|
||||
return None
|
||||
|
||||
# This function is able to return a deferred that either
|
||||
# resolves None, meaning authentication failure, or upon
|
||||
# success, to a str (which is the user_id) or a tuple of
|
||||
# (user_id, callback_func), where callback_func should be run
|
||||
# after we've finished everything else
|
||||
result = await g(medium, address, password)
|
||||
|
||||
# Check if the return value is a str or a tuple
|
||||
if isinstance(result, str):
|
||||
# If it's a str, set callback function to None
|
||||
return result, None
|
||||
|
||||
return result
|
||||
|
||||
async def on_logged_out(
|
||||
self, user_id: str, device_id: Optional[str], access_token: str
|
||||
) -> None:
|
||||
g = getattr(self._pp, "on_logged_out", None)
|
||||
if not g:
|
||||
return
|
||||
|
||||
# This might return an awaitable, if it does block the log out
|
||||
# until it completes.
|
||||
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
|
|
|
@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler):
|
|||
raise SynapseError(500, "An error was encountered when sending the email")
|
||||
|
||||
token_expires = (
|
||||
self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
|
||||
self.hs.get_clock().time_msec()
|
||||
+ self.hs.config.email_validation_token_lifetime
|
||||
)
|
||||
|
||||
await self.store.start_or_continue_validation_session(
|
||||
|
|
|
@ -39,7 +39,7 @@ from synapse.handlers._base import BaseHandler
|
|||
from synapse.handlers.sso import MappingException, UserAttributes
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.types import JsonDict, map_username_to_mxid_localpart
|
||||
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||
from synapse.util import json_decoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -898,13 +898,39 @@ class OidcHandler(BaseHandler):
|
|||
|
||||
return UserAttributes(**attributes)
|
||||
|
||||
async def grandfather_existing_users() -> Optional[str]:
|
||||
if self._allow_existing_users:
|
||||
# If allowing existing users we want to generate a single localpart
|
||||
# and attempt to match it.
|
||||
attributes = await oidc_response_to_user_attributes(failures=0)
|
||||
|
||||
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
||||
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||
if users:
|
||||
# If an existing matrix ID is returned, then use it.
|
||||
if len(users) == 1:
|
||||
previously_registered_user_id = next(iter(users))
|
||||
elif user_id in users:
|
||||
previously_registered_user_id = user_id
|
||||
else:
|
||||
# Do not attempt to continue generating Matrix IDs.
|
||||
raise MappingException(
|
||||
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
|
||||
user_id, users
|
||||
)
|
||||
)
|
||||
|
||||
return previously_registered_user_id
|
||||
|
||||
return None
|
||||
|
||||
return await self._sso_handler.get_mxid_from_sso(
|
||||
self._auth_provider_id,
|
||||
remote_user_id,
|
||||
user_agent,
|
||||
ip_address,
|
||||
oidc_response_to_user_attributes,
|
||||
self._allow_existing_users,
|
||||
grandfather_existing_users,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -299,17 +299,22 @@ class PaginationHandler:
|
|||
"""
|
||||
return self._purges_by_id.get(purge_id)
|
||||
|
||||
async def purge_room(self, room_id: str) -> None:
|
||||
"""Purge the given room from the database"""
|
||||
async def purge_room(self, room_id: str, force: bool = False) -> None:
|
||||
"""Purge the given room from the database.
|
||||
|
||||
Args:
|
||||
room_id: room to be purged
|
||||
force: set true to skip checking for joined users.
|
||||
"""
|
||||
with await self.pagination_lock.write(room_id):
|
||||
# check we know about the room
|
||||
await self.store.get_room_version_id(room_id)
|
||||
|
||||
# first check that we have no users in this room
|
||||
joined = await self.store.is_host_joined(room_id, self._server_name)
|
||||
|
||||
if joined:
|
||||
raise SynapseError(400, "Users are still joined to this room")
|
||||
if not force:
|
||||
joined = await self.store.is_host_joined(room_id, self._server_name)
|
||||
if joined:
|
||||
raise SynapseError(400, "Users are still joined to this room")
|
||||
|
||||
await self.storage.purge_events.purge_room(room_id)
|
||||
|
||||
|
|
|
@ -366,7 +366,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
# later on.
|
||||
content = dict(content)
|
||||
|
||||
if not self.allow_per_room_profiles or requester.shadow_banned:
|
||||
# allow the server notices mxid to set room-level profile
|
||||
is_requester_server_notices_user = (
|
||||
self._server_notices_mxid is not None
|
||||
and requester.user.to_string() == self._server_notices_mxid
|
||||
)
|
||||
|
||||
if (
|
||||
not self.allow_per_room_profiles and not is_requester_server_notices_user
|
||||
) or requester.shadow_banned:
|
||||
# Strip profile data, knowing that new profile data will be added to the
|
||||
# event's content in event_creation_handler.create_event() using the target's
|
||||
# global profile.
|
||||
|
|
|
@ -265,10 +265,10 @@ class SamlHandler(BaseHandler):
|
|||
return UserAttributes(
|
||||
localpart=result.get("mxid_localpart"),
|
||||
display_name=result.get("displayname"),
|
||||
emails=result.get("emails"),
|
||||
emails=result.get("emails", []),
|
||||
)
|
||||
|
||||
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
||||
async def grandfather_existing_users() -> Optional[str]:
|
||||
# backwards-compatibility hack: see if there is an existing user with a
|
||||
# suitable mapping from the uid
|
||||
if (
|
||||
|
@ -290,17 +290,18 @@ class SamlHandler(BaseHandler):
|
|||
if users:
|
||||
registered_user_id = list(users.keys())[0]
|
||||
logger.info("Grandfathering mapping to %s", registered_user_id)
|
||||
await self.store.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id
|
||||
)
|
||||
return registered_user_id
|
||||
|
||||
return None
|
||||
|
||||
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
||||
return await self._sso_handler.get_mxid_from_sso(
|
||||
self._auth_provider_id,
|
||||
remote_user_id,
|
||||
user_agent,
|
||||
ip_address,
|
||||
saml_response_to_remapped_user_attributes,
|
||||
grandfather_existing_users,
|
||||
)
|
||||
|
||||
def expire_sessions(self):
|
||||
|
|
|
@ -116,7 +116,7 @@ class SsoHandler(BaseHandler):
|
|||
user_agent: str,
|
||||
ip_address: str,
|
||||
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
||||
allow_existing_users: bool = False,
|
||||
grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
|
||||
) -> str:
|
||||
"""
|
||||
Given an SSO ID, retrieve the user ID for it and possibly register the user.
|
||||
|
@ -125,6 +125,10 @@ class SsoHandler(BaseHandler):
|
|||
if it has that matrix ID is returned regardless of the current mapping
|
||||
logic.
|
||||
|
||||
If a callable is provided for grandfathering users, it is called and can
|
||||
potentially return a matrix ID to use. If it does, the SSO ID is linked to
|
||||
this matrix ID for subsequent calls.
|
||||
|
||||
The mapping function is called (potentially multiple times) to generate
|
||||
a localpart for the user.
|
||||
|
||||
|
@ -132,17 +136,6 @@ class SsoHandler(BaseHandler):
|
|||
given user-agent and IP address and the SSO ID is linked to this matrix
|
||||
ID for subsequent calls.
|
||||
|
||||
If allow_existing_users is true the mapping function is only called once
|
||||
and results in:
|
||||
|
||||
1. The use of a previously registered matrix ID. In this case, the
|
||||
SSO ID is linked to the matrix ID. (Note it is possible that
|
||||
other SSO IDs are linked to the same matrix ID.)
|
||||
2. An unused localpart, in which case the user is registered (as
|
||||
discussed above).
|
||||
3. An error if the generated localpart matches multiple pre-existing
|
||||
matrix IDs. Generally this should not happen.
|
||||
|
||||
Args:
|
||||
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||
"oidc" or "saml".
|
||||
|
@ -152,8 +145,9 @@ class SsoHandler(BaseHandler):
|
|||
sso_to_matrix_id_mapper: A callable to generate the user attributes.
|
||||
The only parameter is an integer which represents the amount of
|
||||
times the returned mxid localpart mapping has failed.
|
||||
allow_existing_users: True if the localpart returned from the
|
||||
mapping provider can be linked to an existing matrix ID.
|
||||
grandfather_existing_users: A callable which can return an previously
|
||||
existing matrix ID. The SSO ID is then linked to the returned
|
||||
matrix ID.
|
||||
|
||||
Returns:
|
||||
The user ID associated with the SSO response.
|
||||
|
@ -171,6 +165,16 @@ class SsoHandler(BaseHandler):
|
|||
if previously_registered_user_id:
|
||||
return previously_registered_user_id
|
||||
|
||||
# Check for grandfathering of users.
|
||||
if grandfather_existing_users:
|
||||
previously_registered_user_id = await grandfather_existing_users()
|
||||
if previously_registered_user_id:
|
||||
# Future logins should also match this user ID.
|
||||
await self.store.record_user_external_id(
|
||||
auth_provider_id, remote_user_id, previously_registered_user_id
|
||||
)
|
||||
return previously_registered_user_id
|
||||
|
||||
# Otherwise, generate a new user.
|
||||
for i in range(self._MAP_USERNAME_RETRIES):
|
||||
try:
|
||||
|
@ -194,33 +198,7 @@ class SsoHandler(BaseHandler):
|
|||
|
||||
# Check if this mxid already exists
|
||||
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
||||
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||
# Note, if allow_existing_users is true then the loop is guaranteed
|
||||
# to end on the first iteration: either by matching an existing user,
|
||||
# raising an error, or registering a new user. See the docstring for
|
||||
# more in-depth an explanation.
|
||||
if users and allow_existing_users:
|
||||
# If an existing matrix ID is returned, then use it.
|
||||
if len(users) == 1:
|
||||
previously_registered_user_id = next(iter(users))
|
||||
elif user_id in users:
|
||||
previously_registered_user_id = user_id
|
||||
else:
|
||||
# Do not attempt to continue generating Matrix IDs.
|
||||
raise MappingException(
|
||||
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
|
||||
user_id, users
|
||||
)
|
||||
)
|
||||
|
||||
# Future logins should also match this user ID.
|
||||
await self.store.record_user_external_id(
|
||||
auth_provider_id, remote_user_id, previously_registered_user_id
|
||||
)
|
||||
|
||||
return previously_registered_user_id
|
||||
|
||||
elif not users:
|
||||
if not await self.store.get_users_by_id_case_insensitive(user_id):
|
||||
# This mxid is free
|
||||
break
|
||||
else:
|
||||
|
|
|
@ -25,7 +25,7 @@ from io import BytesIO
|
|||
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
|
||||
|
||||
import jinja2
|
||||
from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
|
||||
from canonicaljson import iterencode_canonical_json
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer, interfaces
|
||||
|
@ -94,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
|||
pass
|
||||
else:
|
||||
respond_with_json(
|
||||
request,
|
||||
error_code,
|
||||
error_dict,
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
request, error_code, error_dict, send_cors=True,
|
||||
)
|
||||
|
||||
|
||||
|
@ -290,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource):
|
|||
code,
|
||||
response_object,
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
canonical_json=self.canonical_json,
|
||||
)
|
||||
|
||||
|
@ -587,7 +582,6 @@ def respond_with_json(
|
|||
code: int,
|
||||
json_object: Any,
|
||||
send_cors: bool = False,
|
||||
pretty_print: bool = False,
|
||||
canonical_json: bool = True,
|
||||
):
|
||||
"""Sends encoded JSON in response to the given request.
|
||||
|
@ -598,8 +592,6 @@ def respond_with_json(
|
|||
json_object: The object to serialize to JSON.
|
||||
send_cors: Whether to send Cross-Origin Resource Sharing headers
|
||||
https://fetch.spec.whatwg.org/#http-cors-protocol
|
||||
pretty_print: Whether to include indentation and line-breaks in the
|
||||
resulting JSON bytes.
|
||||
canonical_json: Whether to use the canonicaljson algorithm when encoding
|
||||
the JSON bytes.
|
||||
|
||||
|
@ -615,13 +607,10 @@ def respond_with_json(
|
|||
)
|
||||
return None
|
||||
|
||||
if pretty_print:
|
||||
encoder = iterencode_pretty_printed_json
|
||||
if canonical_json:
|
||||
encoder = iterencode_canonical_json
|
||||
else:
|
||||
if canonical_json:
|
||||
encoder = iterencode_canonical_json
|
||||
else:
|
||||
encoder = _encode_json_bytes
|
||||
encoder = _encode_json_bytes
|
||||
|
||||
request.setResponseCode(code)
|
||||
request.setHeader(b"Content-Type", b"application/json")
|
||||
|
@ -685,7 +674,7 @@ def set_cors_headers(request: Request):
|
|||
)
|
||||
request.setHeader(
|
||||
b"Access-Control-Allow-Headers",
|
||||
b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
|
||||
b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date",
|
||||
)
|
||||
|
||||
|
||||
|
@ -759,11 +748,3 @@ def finish_request(request: Request):
|
|||
request.finish()
|
||||
except RuntimeError as e:
|
||||
logger.info("Connection disconnected before response was written: %r", e)
|
||||
|
||||
|
||||
def _request_user_agent_is_curl(request: Request) -> bool:
|
||||
user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
|
||||
for user_agent in user_agents:
|
||||
if b"curl" in user_agent:
|
||||
return True
|
||||
return False
|
||||
|
|
|
@ -75,6 +75,7 @@ class HttpPusher:
|
|||
self.failing_since = pusherdict["failing_since"]
|
||||
self.timed_call = None
|
||||
self._is_processing = False
|
||||
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
|
||||
|
||||
# This is the highest stream ordering we know it's safe to process.
|
||||
# When new events arrive, we'll be given a window of new events: we
|
||||
|
@ -140,7 +141,11 @@ class HttpPusher:
|
|||
async def _update_badge(self):
|
||||
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
|
||||
# to be largely redundant. perhaps we can remove it.
|
||||
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
|
||||
badge = await push_tools.get_badge_count(
|
||||
self.hs.get_datastore(),
|
||||
self.user_id,
|
||||
group_by_room=self._group_unread_count_by_room,
|
||||
)
|
||||
await self._send_badge(badge)
|
||||
|
||||
def on_timer(self):
|
||||
|
@ -287,7 +292,11 @@ class HttpPusher:
|
|||
return True
|
||||
|
||||
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
|
||||
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
|
||||
badge = await push_tools.get_badge_count(
|
||||
self.hs.get_datastore(),
|
||||
self.user_id,
|
||||
group_by_room=self._group_unread_count_by_room,
|
||||
)
|
||||
|
||||
event = await self.store.get_event(push_action["event_id"], allow_none=True)
|
||||
if event is None:
|
||||
|
|
|
@ -12,12 +12,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
|
||||
from synapse.storage import Storage
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
|
||||
async def get_badge_count(store, user_id):
|
||||
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
|
||||
invites = await store.get_invited_rooms_for_local_user(user_id)
|
||||
joins = await store.get_rooms_for_user(user_id)
|
||||
|
||||
|
@ -34,9 +34,15 @@ async def get_badge_count(store, user_id):
|
|||
room_id, user_id, last_unread_event_id
|
||||
)
|
||||
)
|
||||
# return one badge count per conversation, as count per
|
||||
# message is so noisy as to be almost useless
|
||||
badge += 1 if notifs["notify_count"] else 0
|
||||
if notifs["notify_count"] == 0:
|
||||
continue
|
||||
|
||||
if group_by_room:
|
||||
# return one badge count per conversation
|
||||
badge += 1
|
||||
else:
|
||||
# increment the badge count by the number of unread messages in the room
|
||||
badge += notifs["notify_count"]
|
||||
return badge
|
||||
|
||||
|
||||
|
|
|
@ -12,9 +12,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
|
@ -52,16 +53,23 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
|||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(
|
||||
requester, room_id, user_id, remote_room_hosts, content
|
||||
):
|
||||
async def _serialize_payload( # type: ignore
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
user_id: str,
|
||||
remote_room_hosts: List[str],
|
||||
content: JsonDict,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
requester(Requester)
|
||||
room_id (str)
|
||||
user_id (str)
|
||||
remote_room_hosts (list[str]): Servers to try and join via
|
||||
content(dict): The event content to use for the join event
|
||||
requester: The user making the request according to the access token
|
||||
room_id: The ID of the room.
|
||||
user_id: The ID of the user.
|
||||
remote_room_hosts: Servers to try and join via
|
||||
content: The event content to use for the join event
|
||||
|
||||
Returns:
|
||||
A dict representing the payload of the request.
|
||||
"""
|
||||
return {
|
||||
"requester": requester.serialize(),
|
||||
|
@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
|||
"content": content,
|
||||
}
|
||||
|
||||
async def _handle_request(self, request, room_id, user_id):
|
||||
async def _handle_request( # type: ignore
|
||||
self, request: Request, room_id: str, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
remote_room_hosts = content["remote_room_hosts"]
|
||||
|
@ -118,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
|||
txn_id: Optional[str],
|
||||
requester: Requester,
|
||||
content: JsonDict,
|
||||
):
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
invite_event_id: ID of the invite to be rejected
|
||||
txn_id: optional transaction ID supplied by the client
|
||||
requester: user making the rejection request, according to the access token
|
||||
content: additional content to include in the rejection event.
|
||||
invite_event_id: The ID of the invite to be rejected.
|
||||
txn_id: Optional transaction ID supplied by the client
|
||||
requester: User making the rejection request, according to the access token
|
||||
content: Additional content to include in the rejection event.
|
||||
Normally an empty dict.
|
||||
|
||||
Returns:
|
||||
A dict representing the payload of the request.
|
||||
"""
|
||||
return {
|
||||
"txn_id": txn_id,
|
||||
|
@ -133,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
|||
"content": content,
|
||||
}
|
||||
|
||||
async def _handle_request(self, request, invite_event_id):
|
||||
async def _handle_request( # type: ignore
|
||||
self, request: Request, invite_event_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
txn_id = content["txn_id"]
|
||||
|
@ -174,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
|||
self.distributor = hs.get_distributor()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(room_id, user_id, change):
|
||||
async def _serialize_payload( # type: ignore
|
||||
room_id: str, user_id: str, change: str
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
room_id (str)
|
||||
user_id (str)
|
||||
change (str): "left"
|
||||
room_id: The ID of the room.
|
||||
user_id: The ID of the user.
|
||||
change: "left"
|
||||
|
||||
Returns:
|
||||
A dict representing the payload of the request.
|
||||
"""
|
||||
assert change == "left"
|
||||
|
||||
return {}
|
||||
|
||||
def _handle_request(self, request, room_id, user_id, change):
|
||||
def _handle_request( # type: ignore
|
||||
self, request: Request, room_id: str, user_id: str, change: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
logger.info("user membership change: %s in %s", user_id, room_id)
|
||||
|
||||
user = UserID.from_string(user_id)
|
||||
|
|
|
@ -70,14 +70,18 @@ class ShutdownRoomRestServlet(RestServlet):
|
|||
|
||||
|
||||
class DeleteRoomRestServlet(RestServlet):
|
||||
"""Delete a room from server. It is a combination and improvement of
|
||||
shut down and purge room.
|
||||
"""Delete a room from server.
|
||||
|
||||
It is a combination and improvement of shutdown and purge room.
|
||||
|
||||
Shuts down a room by removing all local users from the room.
|
||||
Blocking all future invites and joins to the room is optional.
|
||||
|
||||
If desired any local aliases will be repointed to a new room
|
||||
created by `new_room_user_id` and kicked users will be auto
|
||||
created by `new_room_user_id` and kicked users will be auto-
|
||||
joined to the new room.
|
||||
It will remove all trace of a room from the database.
|
||||
|
||||
If 'purge' is true, it will remove all traces of a room from the database.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
|
||||
|
@ -110,6 +114,14 @@ class DeleteRoomRestServlet(RestServlet):
|
|||
Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
force_purge = content.get("force_purge", False)
|
||||
if not isinstance(force_purge, bool):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Param 'force_purge' must be a boolean, if given",
|
||||
Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
ret = await self.room_shutdown_handler.shutdown_room(
|
||||
room_id=room_id,
|
||||
new_room_user_id=content.get("new_room_user_id"),
|
||||
|
@ -121,7 +133,7 @@ class DeleteRoomRestServlet(RestServlet):
|
|||
|
||||
# Purge room
|
||||
if purge:
|
||||
await self.pagination_handler.purge_room(room_id)
|
||||
await self.pagination_handler.purge_room(room_id, force=force_purge)
|
||||
|
||||
return (200, ret)
|
||||
|
||||
|
|
|
@ -19,10 +19,6 @@ from typing import Awaitable, Callable, Dict, Optional
|
|||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.handlers.auth import (
|
||||
convert_client_dict_legacy_fields_to_identifier,
|
||||
login_id_phone_to_thirdparty,
|
||||
)
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
|
@ -33,7 +29,6 @@ from synapse.http.site import SynapseRequest
|
|||
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||
from synapse.rest.well_known import WellKnownBuilder
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -78,11 +73,6 @@ class LoginRestServlet(RestServlet):
|
|||
rate_hz=self.hs.config.rc_login_account.per_second,
|
||||
burst_count=self.hs.config.rc_login_account.burst_count,
|
||||
)
|
||||
self._failed_attempts_ratelimiter = Ratelimiter(
|
||||
clock=hs.get_clock(),
|
||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
)
|
||||
|
||||
def on_GET(self, request: SynapseRequest):
|
||||
flows = []
|
||||
|
@ -140,27 +130,31 @@ class LoginRestServlet(RestServlet):
|
|||
result["well_known"] = well_known_data
|
||||
return 200, result
|
||||
|
||||
def _get_qualified_user_id(self, identifier):
|
||||
if identifier["type"] != "m.id.user":
|
||||
raise SynapseError(400, "Unknown login identifier type")
|
||||
if "user" not in identifier:
|
||||
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||
|
||||
if identifier["user"].startswith("@"):
|
||||
return identifier["user"]
|
||||
else:
|
||||
return UserID(identifier["user"], self.hs.hostname).to_string()
|
||||
|
||||
async def _do_appservice_login(
|
||||
self, login_submission: JsonDict, appservice: ApplicationService
|
||||
):
|
||||
logger.info(
|
||||
"Got appservice login request with identifier: %r",
|
||||
login_submission.get("identifier"),
|
||||
)
|
||||
identifier = login_submission.get("identifier")
|
||||
logger.info("Got appservice login request with identifier: %r", identifier)
|
||||
|
||||
identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
|
||||
qualified_user_id = self._get_qualified_user_id(identifier)
|
||||
if not isinstance(identifier, dict):
|
||||
raise SynapseError(
|
||||
400, "Invalid identifier in login submission", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
# this login flow only supports identifiers of type "m.id.user".
|
||||
if identifier.get("type") != "m.id.user":
|
||||
raise SynapseError(
|
||||
400, "Unknown login identifier type", Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
user = identifier.get("user")
|
||||
if not isinstance(user, str):
|
||||
raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
|
||||
|
||||
if user.startswith("@"):
|
||||
qualified_user_id = user
|
||||
else:
|
||||
qualified_user_id = UserID(user, self.hs.hostname).to_string()
|
||||
|
||||
if not appservice.is_interested_in_user(qualified_user_id):
|
||||
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
|
||||
|
@ -186,91 +180,9 @@ class LoginRestServlet(RestServlet):
|
|||
login_submission.get("address"),
|
||||
login_submission.get("user"),
|
||||
)
|
||||
identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
|
||||
|
||||
# convert phone type identifiers to generic threepids
|
||||
if identifier["type"] == "m.id.phone":
|
||||
identifier = login_id_phone_to_thirdparty(identifier)
|
||||
|
||||
# convert threepid identifiers to user IDs
|
||||
if identifier["type"] == "m.id.thirdparty":
|
||||
address = identifier.get("address")
|
||||
medium = identifier.get("medium")
|
||||
|
||||
if medium is None or address is None:
|
||||
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||
|
||||
# For emails, canonicalise the address.
|
||||
# We store all email addresses canonicalised in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
if medium == "email":
|
||||
try:
|
||||
address = canonicalise_email(address)
|
||||
except ValueError as e:
|
||||
raise SynapseError(400, str(e))
|
||||
|
||||
# We also apply account rate limiting using the 3PID as a key, as
|
||||
# otherwise using 3PID bypasses the ratelimiting based on user ID.
|
||||
self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
|
||||
|
||||
# Check for login providers that support 3pid login types
|
||||
(
|
||||
canonical_user_id,
|
||||
callback_3pid,
|
||||
) = await self.auth_handler.check_password_provider_3pid(
|
||||
medium, address, login_submission["password"]
|
||||
)
|
||||
if canonical_user_id:
|
||||
# Authentication through password provider and 3pid succeeded
|
||||
|
||||
result = await self._complete_login(
|
||||
canonical_user_id, login_submission, callback_3pid
|
||||
)
|
||||
return result
|
||||
|
||||
# No password providers were able to handle this 3pid
|
||||
# Check local store
|
||||
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
|
||||
medium, address
|
||||
)
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"unknown 3pid identifier medium %s, address %r", medium, address
|
||||
)
|
||||
# We mark that we've failed to log in here, as
|
||||
# `check_password_provider_3pid` might have returned `None` due
|
||||
# to an incorrect password, rather than the account not
|
||||
# existing.
|
||||
#
|
||||
# If it returned None but the 3PID was bound then we won't hit
|
||||
# this code path, which is fine as then the per-user ratelimit
|
||||
# will kick in below.
|
||||
self._failed_attempts_ratelimiter.can_do_action((medium, address))
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
identifier = {"type": "m.id.user", "user": user_id}
|
||||
|
||||
# by this point, the identifier should be an m.id.user: if it's anything
|
||||
# else, we haven't understood it.
|
||||
qualified_user_id = self._get_qualified_user_id(identifier)
|
||||
|
||||
# Check if we've hit the failed ratelimit (but don't update it)
|
||||
self._failed_attempts_ratelimiter.ratelimit(
|
||||
qualified_user_id.lower(), update=False
|
||||
canonical_user_id, callback = await self.auth_handler.validate_login(
|
||||
login_submission, ratelimit=True
|
||||
)
|
||||
|
||||
try:
|
||||
canonical_user_id, callback = await self.auth_handler.validate_login(
|
||||
identifier["user"], login_submission
|
||||
)
|
||||
except LoginError:
|
||||
# The user has failed to log in, so we need to update the rate
|
||||
# limiter. Using `can_do_action` avoids us raising a ratelimit
|
||||
# exception and masking the LoginError. The actual ratelimiting
|
||||
# should have happened above.
|
||||
self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
|
||||
raise
|
||||
|
||||
result = await self._complete_login(
|
||||
canonical_user_id, login_submission, callback
|
||||
)
|
||||
|
|
|
@ -115,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
|||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
||||
|
@ -387,7 +387,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
|||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
|
@ -466,7 +466,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
|||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
|
||||
|
|
|
@ -135,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
|||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
|
@ -214,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
|||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(
|
||||
|
|
|
@ -66,7 +66,7 @@ class LocalKey(Resource):
|
|||
|
||||
def __init__(self, hs):
|
||||
self.config = hs.config
|
||||
self.clock = hs.clock
|
||||
self.clock = hs.get_clock()
|
||||
self.update_response_body(self.clock.time_msec())
|
||||
Resource.__init__(self)
|
||||
|
||||
|
|
|
@ -147,7 +147,8 @@ def cache_in_self(builder: T) -> T:
|
|||
"@cache_in_self can only be used on functions starting with `get_`"
|
||||
)
|
||||
|
||||
depname = builder.__name__[len("get_") :]
|
||||
# get_attr -> _attr
|
||||
depname = builder.__name__[len("get") :]
|
||||
|
||||
building = [False]
|
||||
|
||||
|
@ -235,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
self._instance_id = random_string(5)
|
||||
self._instance_name = config.worker_name or "master"
|
||||
|
||||
self.clock = Clock(reactor)
|
||||
self.distributor = Distributor()
|
||||
|
||||
self.registration_ratelimiter = Ratelimiter(
|
||||
clock=self.clock,
|
||||
rate_hz=config.rc_registration.per_second,
|
||||
burst_count=config.rc_registration.burst_count,
|
||||
)
|
||||
|
||||
self.version_string = version_string
|
||||
|
||||
self.datastores = None # type: Optional[Databases]
|
||||
|
@ -301,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
def is_mine_id(self, string: str) -> bool:
|
||||
return string.split(":", 1)[1] == self.hostname
|
||||
|
||||
@cache_in_self
|
||||
def get_clock(self) -> Clock:
|
||||
return self.clock
|
||||
return Clock(self._reactor)
|
||||
|
||||
def get_datastore(self) -> DataStore:
|
||||
if not self.datastores:
|
||||
|
@ -319,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
def get_config(self) -> HomeServerConfig:
|
||||
return self.config
|
||||
|
||||
@cache_in_self
|
||||
def get_distributor(self) -> Distributor:
|
||||
return self.distributor
|
||||
return Distributor()
|
||||
|
||||
@cache_in_self
|
||||
def get_registration_ratelimiter(self) -> Ratelimiter:
|
||||
return self.registration_ratelimiter
|
||||
return Ratelimiter(
|
||||
clock=self.get_clock(),
|
||||
rate_hz=self.config.rc_registration.per_second,
|
||||
burst_count=self.config.rc_registration.burst_count,
|
||||
)
|
||||
|
||||
@cache_in_self
|
||||
def get_federation_client(self) -> FederationClient:
|
||||
|
@ -687,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
|
||||
@cache_in_self
|
||||
def get_federation_ratelimiter(self) -> FederationRateLimiter:
|
||||
return FederationRateLimiter(self.clock, config=self.config.rc_federation)
|
||||
return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation)
|
||||
|
||||
@cache_in_self
|
||||
def get_module_api(self) -> ModuleApi:
|
||||
|
|
|
@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
|||
for table in (
|
||||
"event_auth",
|
||||
"event_edges",
|
||||
"event_json",
|
||||
"event_push_actions_staging",
|
||||
"event_reference_hashes",
|
||||
"event_relations",
|
||||
|
@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
|||
"destination_rooms",
|
||||
"event_backward_extremities",
|
||||
"event_forward_extremities",
|
||||
"event_json",
|
||||
"event_push_actions",
|
||||
"event_search",
|
||||
"events",
|
||||
|
|
|
@ -20,14 +20,14 @@
|
|||
*/
|
||||
|
||||
-- add new index that includes method to local media
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('local_media_repository_thumbnails_method_idx', '{}');
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(5807, 'local_media_repository_thumbnails_method_idx', '{}');
|
||||
|
||||
-- add new index that includes method to remote media
|
||||
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
||||
('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
|
||||
(5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
|
||||
|
||||
-- drop old index
|
||||
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
||||
('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
|
||||
(5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
|
||||
|
||||
|
|
|
@ -28,5 +28,5 @@
|
|||
-- functionality as the old one. This effectively restarts the background job
|
||||
-- from the beginning, without running it twice in a row, supporting both
|
||||
-- upgrade usecases.
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('populate_stats_process_rooms_2', '{}');
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(5812, 'populate_stats_process_rooms_2', '{}');
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('users_have_local_media', '{}');
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(5822, 'users_have_local_media', '{}');
|
||||
|
|
|
@ -13,5 +13,5 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('e2e_cross_signing_keys_idx', '{}');
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(5823, 'e2e_cross_signing_keys_idx', '{}');
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- this index is essentially redundant. The only time it was ever used was when purging
|
||||
-- rooms - and Synapse 1.24 will change that.
|
||||
|
||||
DROP INDEX IF EXISTS event_json_room_id;
|
|
@ -52,7 +52,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.fail("some_user was not in %s" % macaroon.inspect())
|
||||
|
||||
def test_macaroon_caveats(self):
|
||||
self.hs.clock.now = 5000
|
||||
self.hs.get_clock().now = 5000
|
||||
|
||||
token = self.macaroon_generator.generate_access_token("a_user")
|
||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||
|
@ -78,7 +78,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_short_term_login_token_gives_user_id(self):
|
||||
self.hs.clock.now = 1000
|
||||
self.hs.get_clock().now = 1000
|
||||
|
||||
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
|
||||
user_id = yield defer.ensureDeferred(
|
||||
|
@ -87,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.assertEqual("a_user", user_id)
|
||||
|
||||
# when we advance the clock, the token should be rejected
|
||||
self.hs.clock.now = 6000
|
||||
self.hs.get_clock().now = 6000
|
||||
with self.assertRaises(synapse.api.errors.AuthError):
|
||||
yield defer.ensureDeferred(
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
|
|
|
@ -23,7 +23,7 @@ import pymacaroons
|
|||
from twisted.python.failure import Failure
|
||||
from twisted.web._newclient import ResponseDone
|
||||
|
||||
from synapse.handlers.oidc_handler import OidcError, OidcHandler, OidcMappingProvider
|
||||
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.types import UserID
|
||||
|
||||
|
@ -127,13 +127,8 @@ async def get_json(url):
|
|||
|
||||
|
||||
class OidcHandlerTestCase(HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
self.http_client = Mock(spec=["get_json"])
|
||||
self.http_client.get_json.side_effect = get_json
|
||||
self.http_client.user_agent = "Synapse Test"
|
||||
|
||||
config = self.default_config()
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["public_baseurl"] = BASE_URL
|
||||
oidc_config = {
|
||||
"enabled": True,
|
||||
|
@ -149,19 +144,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
oidc_config.update(config.get("oidc_config", {}))
|
||||
config["oidc_config"] = oidc_config
|
||||
|
||||
hs = self.setup_test_homeserver(
|
||||
http_client=self.http_client,
|
||||
proxied_http_client=self.http_client,
|
||||
config=config,
|
||||
)
|
||||
return config
|
||||
|
||||
self.handler = OidcHandler(hs)
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
self.http_client = Mock(spec=["get_json"])
|
||||
self.http_client.get_json.side_effect = get_json
|
||||
self.http_client.user_agent = "Synapse Test"
|
||||
|
||||
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
|
||||
|
||||
self.handler = hs.get_oidc_handler()
|
||||
sso_handler = hs.get_sso_handler()
|
||||
# Mock the render error method.
|
||||
self.render_error = Mock(return_value=None)
|
||||
self.handler._sso_handler.render_error = self.render_error
|
||||
sso_handler.render_error = self.render_error
|
||||
|
||||
# Reduce the number of attempts when generating MXIDs.
|
||||
self.handler._sso_handler._MAP_USERNAME_RETRIES = 3
|
||||
sso_handler._MAP_USERNAME_RETRIES = 3
|
||||
|
||||
return hs
|
||||
|
||||
|
@ -731,6 +731,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
|
||||
# Subsequent calls should map to the same mxid.
|
||||
mxid = self.get_success(
|
||||
self.handler._map_userinfo_to_user(
|
||||
userinfo, token, "user-agent", "10.10.10.10"
|
||||
)
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
|
||||
# Note that a second SSO user can be mapped to the same Matrix ID. (This
|
||||
# requires a unique sub, but something that maps to the same matrix ID,
|
||||
# in this case we'll just use the same username. A more realistic example
|
||||
|
@ -832,7 +840,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
# test_user is already taken, so test_user1 gets registered instead.
|
||||
self.assertEqual(mxid, "@test_user1:test")
|
||||
|
||||
# Register all of the potential users for a particular username.
|
||||
# Register all of the potential mxids for a particular OIDC username.
|
||||
self.get_success(
|
||||
store.register_user(user_id="@tester:test", password_hash=None)
|
||||
)
|
||||
|
|
|
@ -0,0 +1,580 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the password_auth_provider interface"""
|
||||
|
||||
from typing import Any, Type, Union
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse
|
||||
from synapse.rest.client.v1 import login
|
||||
from synapse.rest.client.v2_alpha import devices
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeChannel
|
||||
from tests.unittest import override_config
|
||||
|
||||
# (possibly experimental) login flows we expect to appear in the list after the normal
|
||||
# ones
|
||||
ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
|
||||
|
||||
# a mock instance which the dummy auth providers delegate to, so we can see what's going
|
||||
# on
|
||||
mock_password_provider = Mock()
|
||||
|
||||
|
||||
class PasswordOnlyAuthProvider:
|
||||
"""A password_provider which only implements `check_password`."""
|
||||
|
||||
@staticmethod
|
||||
def parse_config(self):
|
||||
pass
|
||||
|
||||
def __init__(self, config, account_handler):
|
||||
pass
|
||||
|
||||
def check_password(self, *args):
|
||||
return mock_password_provider.check_password(*args)
|
||||
|
||||
|
||||
class CustomAuthProvider:
|
||||
"""A password_provider which implements a custom login type."""
|
||||
|
||||
@staticmethod
|
||||
def parse_config(self):
|
||||
pass
|
||||
|
||||
def __init__(self, config, account_handler):
|
||||
pass
|
||||
|
||||
def get_supported_login_types(self):
|
||||
return {"test.login_type": ["test_field"]}
|
||||
|
||||
def check_auth(self, *args):
|
||||
return mock_password_provider.check_auth(*args)
|
||||
|
||||
|
||||
class PasswordCustomAuthProvider:
|
||||
"""A password_provider which implements password login via `check_auth`, as well
|
||||
as a custom type."""
|
||||
|
||||
@staticmethod
|
||||
def parse_config(self):
|
||||
pass
|
||||
|
||||
def __init__(self, config, account_handler):
|
||||
pass
|
||||
|
||||
def get_supported_login_types(self):
|
||||
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
|
||||
|
||||
def check_auth(self, *args):
|
||||
return mock_password_provider.check_auth(*args)
|
||||
|
||||
|
||||
def providers_config(*providers: Type[Any]) -> dict:
|
||||
"""Returns a config dict that will enable the given password auth providers"""
|
||||
return {
|
||||
"password_providers": [
|
||||
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
|
||||
for provider in providers
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
devices.register_servlets,
|
||||
]
|
||||
|
||||
def setUp(self):
|
||||
# we use a global mock device, so make sure we are starting with a clean slate
|
||||
mock_password_provider.reset_mock()
|
||||
super().setUp()
|
||||
|
||||
@override_config(providers_config(PasswordOnlyAuthProvider))
|
||||
def test_password_only_auth_provider_login(self):
|
||||
# login flows should only have m.login.password
|
||||
flows = self._get_login_flows()
|
||||
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||
|
||||
# check_password must return an awaitable
|
||||
mock_password_provider.check_password.return_value = defer.succeed(True)
|
||||
channel = self._send_password_login("u", "p")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@u:test", channel.json_body["user_id"])
|
||||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# login with mxid should work too
|
||||
channel = self._send_password_login("@u:bz", "p")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@u:bz", channel.json_body["user_id"])
|
||||
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# try a weird username / pass. Honestly it's unclear what we *expect* to happen
|
||||
# in these cases, but at least we can guard against the API changing
|
||||
# unexpectedly
|
||||
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
|
||||
mock_password_provider.check_password.assert_called_once_with(
|
||||
"@ USER🙂NAME :test", " pASS😢word "
|
||||
)
|
||||
|
||||
@override_config(providers_config(PasswordOnlyAuthProvider))
|
||||
def test_password_only_auth_provider_ui_auth(self):
|
||||
"""UI Auth should delegate correctly to the password provider"""
|
||||
|
||||
# create the user, otherwise access doesn't work
|
||||
module_api = self.hs.get_module_api()
|
||||
self.get_success(module_api.register_user("u"))
|
||||
|
||||
# log in twice, to get two devices
|
||||
mock_password_provider.check_password.return_value = defer.succeed(True)
|
||||
tok1 = self.login("u", "p")
|
||||
self.login("u", "p", device_id="dev2")
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# have the auth provider deny the request to start with
|
||||
mock_password_provider.check_password.return_value = defer.succeed(False)
|
||||
|
||||
# make the initial request which returns a 401
|
||||
session = self._start_delete_device_session(tok1, "dev2")
|
||||
mock_password_provider.check_password.assert_not_called()
|
||||
|
||||
# Make another request providing the UI auth flow.
|
||||
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
|
||||
self.assertEqual(channel.code, 401) # XXX why not a 403?
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# Finally, check the request goes through when we allow it
|
||||
mock_password_provider.check_password.return_value = defer.succeed(True)
|
||||
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
|
||||
self.assertEqual(channel.code, 200)
|
||||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||
|
||||
@override_config(providers_config(PasswordOnlyAuthProvider))
|
||||
def test_local_user_fallback_login(self):
|
||||
"""rejected login should fall back to local db"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
# check_password must return an awaitable
|
||||
mock_password_provider.check_password.return_value = defer.succeed(False)
|
||||
channel = self._send_password_login("u", "p")
|
||||
self.assertEqual(channel.code, 403, channel.result)
|
||||
|
||||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@localuser:test", channel.json_body["user_id"])
|
||||
|
||||
@override_config(providers_config(PasswordOnlyAuthProvider))
|
||||
def test_local_user_fallback_ui_auth(self):
|
||||
"""rejected login should fall back to local db"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
# have the auth provider deny the request
|
||||
mock_password_provider.check_password.return_value = defer.succeed(False)
|
||||
|
||||
# log in twice, to get two devices
|
||||
tok1 = self.login("localuser", "localpass")
|
||||
self.login("localuser", "localpass", device_id="dev2")
|
||||
mock_password_provider.check_password.reset_mock()
|
||||
|
||||
# first delete should give a 401
|
||||
session = self._start_delete_device_session(tok1, "dev2")
|
||||
mock_password_provider.check_password.assert_not_called()
|
||||
|
||||
# Wrong password
|
||||
channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
|
||||
self.assertEqual(channel.code, 401) # XXX why not a 403?
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
mock_password_provider.check_password.assert_called_once_with(
|
||||
"@localuser:test", "xxx"
|
||||
)
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# Right password
|
||||
channel = self._authed_delete_device(
|
||||
tok1, "dev2", session, "localuser", "localpass"
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
mock_password_provider.check_password.assert_called_once_with(
|
||||
"@localuser:test", "localpass"
|
||||
)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
**providers_config(PasswordOnlyAuthProvider),
|
||||
"password_config": {"localdb_enabled": False},
|
||||
}
|
||||
)
|
||||
def test_no_local_user_fallback_login(self):
|
||||
"""localdb_enabled can block login with the local password
|
||||
"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
# check_password must return an awaitable
|
||||
mock_password_provider.check_password.return_value = defer.succeed(False)
|
||||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, 403)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
mock_password_provider.check_password.assert_called_once_with(
|
||||
"@localuser:test", "localpass"
|
||||
)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
**providers_config(PasswordOnlyAuthProvider),
|
||||
"password_config": {"localdb_enabled": False},
|
||||
}
|
||||
)
|
||||
def test_no_local_user_fallback_ui_auth(self):
|
||||
"""localdb_enabled can block ui auth with the local password
|
||||
"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
# allow login via the auth provider
|
||||
mock_password_provider.check_password.return_value = defer.succeed(True)
|
||||
|
||||
# log in twice, to get two devices
|
||||
tok1 = self.login("localuser", "p")
|
||||
self.login("localuser", "p", device_id="dev2")
|
||||
mock_password_provider.check_password.reset_mock()
|
||||
|
||||
# first delete should give a 401
|
||||
channel = self._delete_device(tok1, "dev2")
|
||||
self.assertEqual(channel.code, 401)
|
||||
# m.login.password UIA is permitted because the auth provider allows it,
|
||||
# even though the localdb does not.
|
||||
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
|
||||
session = channel.json_body["session"]
|
||||
mock_password_provider.check_password.assert_not_called()
|
||||
|
||||
# now try deleting with the local password
|
||||
mock_password_provider.check_password.return_value = defer.succeed(False)
|
||||
channel = self._authed_delete_device(
|
||||
tok1, "dev2", session, "localuser", "localpass"
|
||||
)
|
||||
self.assertEqual(channel.code, 401) # XXX why not a 403?
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
mock_password_provider.check_password.assert_called_once_with(
|
||||
"@localuser:test", "localpass"
|
||||
)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
**providers_config(PasswordOnlyAuthProvider),
|
||||
"password_config": {"enabled": False},
|
||||
}
|
||||
)
|
||||
def test_password_auth_disabled(self):
|
||||
"""password auth doesn't work if it's disabled across the board"""
|
||||
# login flows should be empty
|
||||
flows = self._get_login_flows()
|
||||
self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
|
||||
|
||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||
channel = self._send_password_login("u", "p")
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
mock_password_provider.check_password.assert_not_called()
|
||||
|
||||
@override_config(providers_config(CustomAuthProvider))
|
||||
def test_custom_auth_provider_login(self):
|
||||
# login flows should have the custom flow and m.login.password, since we
|
||||
# haven't disabled local password lookup.
|
||||
# (password must come first, because reasons)
|
||||
flows = self._get_login_flows()
|
||||
self.assertEqual(
|
||||
flows,
|
||||
[{"type": "m.login.password"}, {"type": "test.login_type"}]
|
||||
+ ADDITIONAL_LOGIN_FLOWS,
|
||||
)
|
||||
|
||||
# login with missing param should be rejected
|
||||
channel = self._send_login("test.login_type", "u")
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
mock_password_provider.check_auth.assert_not_called()
|
||||
|
||||
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
|
||||
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
"u", "test.login_type", {"test_field": "y"}
|
||||
)
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# try a weird username. Again, it's unclear what we *expect* to happen
|
||||
# in these cases, but at least we can guard against the API changing
|
||||
# unexpectedly
|
||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||
"@ MALFORMED! :bz"
|
||||
)
|
||||
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
|
||||
)
|
||||
|
||||
@override_config(providers_config(CustomAuthProvider))
|
||||
def test_custom_auth_provider_ui_auth(self):
|
||||
# register the user and log in twice, to get two devices
|
||||
self.register_user("localuser", "localpass")
|
||||
tok1 = self.login("localuser", "localpass")
|
||||
self.login("localuser", "localpass", device_id="dev2")
|
||||
|
||||
# make the initial request which returns a 401
|
||||
channel = self._delete_device(tok1, "dev2")
|
||||
self.assertEqual(channel.code, 401)
|
||||
# Ensure that flows are what is expected.
|
||||
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
|
||||
self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# missing param
|
||||
body = {
|
||||
"auth": {
|
||||
"type": "test.login_type",
|
||||
"identifier": {"type": "m.id.user", "user": "localuser"},
|
||||
"session": session,
|
||||
},
|
||||
}
|
||||
|
||||
channel = self._delete_device(tok1, "dev2", body)
|
||||
self.assertEqual(channel.code, 400)
|
||||
# there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
|
||||
# use it...
|
||||
self.assertIn("Missing parameters", channel.json_body["error"])
|
||||
mock_password_provider.check_auth.assert_not_called()
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# right params, but authing as the wrong user
|
||||
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
|
||||
body["auth"]["test_field"] = "foo"
|
||||
channel = self._delete_device(tok1, "dev2", body)
|
||||
self.assertEqual(channel.code, 403)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
"localuser", "test.login_type", {"test_field": "foo"}
|
||||
)
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# and finally, succeed
|
||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||
"@localuser:test"
|
||||
)
|
||||
channel = self._delete_device(tok1, "dev2", body)
|
||||
self.assertEqual(channel.code, 200)
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
"localuser", "test.login_type", {"test_field": "foo"}
|
||||
)
|
||||
|
||||
@override_config(providers_config(CustomAuthProvider))
|
||||
def test_custom_auth_provider_callback(self):
|
||||
callback = Mock(return_value=defer.succeed(None))
|
||||
|
||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||
("@user:bz", callback)
|
||||
)
|
||||
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
"u", "test.login_type", {"test_field": "y"}
|
||||
)
|
||||
|
||||
# check the args to the callback
|
||||
callback.assert_called_once()
|
||||
call_args, call_kwargs = callback.call_args
|
||||
# should be one positional arg
|
||||
self.assertEqual(len(call_args), 1)
|
||||
self.assertEqual(call_args[0]["user_id"], "@user:bz")
|
||||
for p in ["user_id", "access_token", "device_id", "home_server"]:
|
||||
self.assertIn(p, call_args[0])
|
||||
|
||||
@override_config(
|
||||
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
|
||||
)
|
||||
def test_custom_auth_password_disabled(self):
|
||||
"""Test login with a custom auth provider where password login is disabled"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
flows = self._get_login_flows()
|
||||
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||
|
||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
mock_password_provider.check_auth.assert_not_called()
|
||||
|
||||
@override_config(
|
||||
{
|
||||
**providers_config(PasswordCustomAuthProvider),
|
||||
"password_config": {"enabled": False},
|
||||
}
|
||||
)
|
||||
def test_password_custom_auth_password_disabled_login(self):
|
||||
"""log in with a custom auth provider which implements password, but password
|
||||
login is disabled"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
flows = self._get_login_flows()
|
||||
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||
|
||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
mock_password_provider.check_auth.assert_not_called()
|
||||
|
||||
@override_config(
|
||||
{
|
||||
**providers_config(PasswordCustomAuthProvider),
|
||||
"password_config": {"enabled": False},
|
||||
}
|
||||
)
|
||||
def test_password_custom_auth_password_disabled_ui_auth(self):
|
||||
"""UI Auth with a custom auth provider which implements password, but password
|
||||
login is disabled"""
|
||||
# register the user and log in twice via the test login type to get two devices,
|
||||
self.register_user("localuser", "localpass")
|
||||
mock_password_provider.check_auth.return_value = defer.succeed(
|
||||
"@localuser:test"
|
||||
)
|
||||
channel = self._send_login("test.login_type", "localuser", test_field="")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
tok1 = channel.json_body["access_token"]
|
||||
|
||||
channel = self._send_login(
|
||||
"test.login_type", "localuser", test_field="", device_id="dev2"
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# make the initial request which returns a 401
|
||||
channel = self._delete_device(tok1, "dev2")
|
||||
self.assertEqual(channel.code, 401)
|
||||
# Ensure that flows are what is expected. In particular, "password" should *not*
|
||||
# be present.
|
||||
self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
|
||||
session = channel.json_body["session"]
|
||||
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# check that auth with password is rejected
|
||||
body = {
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
"identifier": {"type": "m.id.user", "user": "localuser"},
|
||||
"password": "localpass",
|
||||
"session": session,
|
||||
},
|
||||
}
|
||||
|
||||
channel = self._delete_device(tok1, "dev2", body)
|
||||
self.assertEqual(channel.code, 400)
|
||||
self.assertEqual(
|
||||
"Password login has been disabled.", channel.json_body["error"]
|
||||
)
|
||||
mock_password_provider.check_auth.assert_not_called()
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# successful auth
|
||||
body["auth"]["type"] = "test.login_type"
|
||||
body["auth"]["test_field"] = "x"
|
||||
channel = self._delete_device(tok1, "dev2", body)
|
||||
self.assertEqual(channel.code, 200)
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
"localuser", "test.login_type", {"test_field": "x"}
|
||||
)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
**providers_config(CustomAuthProvider),
|
||||
"password_config": {"localdb_enabled": False},
|
||||
}
|
||||
)
|
||||
def test_custom_auth_no_local_user_fallback(self):
|
||||
"""Test login with a custom auth provider where the local db is disabled"""
|
||||
self.register_user("localuser", "localpass")
|
||||
|
||||
flows = self._get_login_flows()
|
||||
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
|
||||
|
||||
# password login shouldn't work and should be rejected with a 400
|
||||
# ("unknown login type")
|
||||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
|
||||
def _get_login_flows(self) -> JsonDict:
|
||||
_, channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
return channel.json_body["flows"]
|
||||
|
||||
def _send_password_login(self, user: str, password: str) -> FakeChannel:
|
||||
return self._send_login(type="m.login.password", user=user, password=password)
|
||||
|
||||
def _send_login(self, type, user, **params) -> FakeChannel:
|
||||
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
|
||||
_, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
|
||||
return channel
|
||||
|
||||
def _start_delete_device_session(self, access_token, device_id) -> str:
|
||||
"""Make an initial delete device request, and return the UI Auth session ID"""
|
||||
channel = self._delete_device(access_token, device_id)
|
||||
self.assertEqual(channel.code, 401)
|
||||
# Ensure that flows are what is expected.
|
||||
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
|
||||
return channel.json_body["session"]
|
||||
|
||||
def _authed_delete_device(
|
||||
self,
|
||||
access_token: str,
|
||||
device_id: str,
|
||||
session: str,
|
||||
user_id: str,
|
||||
password: str,
|
||||
) -> FakeChannel:
|
||||
"""Make a delete device request, authenticating with the given uid/password"""
|
||||
return self._delete_device(
|
||||
access_token,
|
||||
device_id,
|
||||
{
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
"identifier": {"type": "m.id.user", "user": user_id},
|
||||
"password": password,
|
||||
"session": session,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def _delete_device(
|
||||
self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
|
||||
) -> FakeChannel:
|
||||
"""Delete an individual device."""
|
||||
_, channel = self.make_request(
|
||||
"DELETE", "devices/" + device, body, access_token=access_token
|
||||
)
|
||||
return channel
|
|
@ -0,0 +1,168 @@
|
|||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.handlers.sso import MappingException
|
||||
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
# These are a few constants that are used as config parameters in the tests.
|
||||
BASE_URL = "https://synapse/"
|
||||
|
||||
|
||||
@attr.s
|
||||
class FakeAuthnResponse:
|
||||
ava = attr.ib(type=dict)
|
||||
|
||||
|
||||
class TestMappingProvider:
|
||||
def __init__(self, config, module):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config):
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def get_saml_attributes(config):
|
||||
return {"uid"}, {"displayName"}
|
||||
|
||||
def get_remote_user_id(self, saml_response, client_redirect_url):
|
||||
return saml_response.ava["uid"]
|
||||
|
||||
def saml_response_to_user_attributes(
|
||||
self, saml_response, failures, client_redirect_url
|
||||
):
|
||||
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
|
||||
return {"mxid_localpart": localpart, "displayname": None}
|
||||
|
||||
|
||||
class SamlHandlerTestCase(HomeserverTestCase):
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["public_baseurl"] = BASE_URL
|
||||
saml_config = {
|
||||
"sp_config": {"metadata": {}},
|
||||
# Disable grandfathering.
|
||||
"grandfathered_mxid_source_attribute": None,
|
||||
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
|
||||
}
|
||||
|
||||
# Update this config with what's in the default config so that
|
||||
# override_config works as expected.
|
||||
saml_config.update(config.get("saml2_config", {}))
|
||||
config["saml2_config"] = saml_config
|
||||
|
||||
return config
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = self.setup_test_homeserver()
|
||||
|
||||
self.handler = hs.get_saml_handler()
|
||||
|
||||
# Reduce the number of attempts when generating MXIDs.
|
||||
sso_handler = hs.get_sso_handler()
|
||||
sso_handler._MAP_USERNAME_RETRIES = 3
|
||||
|
||||
return hs
|
||||
|
||||
def test_map_saml_response_to_user(self):
|
||||
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
|
||||
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
||||
# The redirect_url doesn't matter with the default user mapping provider.
|
||||
redirect_url = ""
|
||||
mxid = self.get_success(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
)
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
|
||||
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||
def test_map_saml_response_to_existing_user(self):
|
||||
"""Existing users can log in with SAML account."""
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.register_user(user_id="@test_user:test", password_hash=None)
|
||||
)
|
||||
|
||||
# Map a user via SSO.
|
||||
saml_response = FakeAuthnResponse(
|
||||
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
|
||||
)
|
||||
redirect_url = ""
|
||||
mxid = self.get_success(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
)
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
|
||||
# Subsequent calls should map to the same mxid.
|
||||
mxid = self.get_success(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
)
|
||||
)
|
||||
self.assertEqual(mxid, "@test_user:test")
|
||||
|
||||
def test_map_saml_response_to_invalid_localpart(self):
|
||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
|
||||
redirect_url = ""
|
||||
e = self.get_failure(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
),
|
||||
MappingException,
|
||||
)
|
||||
self.assertEqual(str(e.value), "localpart is invalid: föö")
|
||||
|
||||
def test_map_saml_response_to_user_retries(self):
|
||||
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.register_user(user_id="@test_user:test", password_hash=None)
|
||||
)
|
||||
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
|
||||
redirect_url = ""
|
||||
mxid = self.get_success(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
)
|
||||
)
|
||||
# test_user is already taken, so test_user1 gets registered instead.
|
||||
self.assertEqual(mxid, "@test_user1:test")
|
||||
|
||||
# Register all of the potential mxids for a particular SAML username.
|
||||
self.get_success(
|
||||
store.register_user(user_id="@tester:test", password_hash=None)
|
||||
)
|
||||
for i in range(1, 3):
|
||||
self.get_success(
|
||||
store.register_user(user_id="@tester%d:test" % i, password_hash=None)
|
||||
)
|
||||
|
||||
# Now attempt to map to a username, this will fail since all potential usernames are taken.
|
||||
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
|
||||
e = self.get_failure(
|
||||
self.handler._map_saml_response_to_user(
|
||||
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||
),
|
||||
MappingException,
|
||||
)
|
||||
self.assertEqual(
|
||||
str(e.value), "Unable to generate a Matrix ID from the SSO response"
|
||||
)
|
|
@ -12,7 +12,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
|
@ -20,8 +19,9 @@ from twisted.internet.defer import Deferred
|
|||
import synapse.rest.admin
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.rest.client.v1 import login, room
|
||||
from synapse.rest.client.v2_alpha import receipts
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
|
||||
class HTTPPusherTests(HomeserverTestCase):
|
||||
|
@ -29,6 +29,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
room.register_servlets,
|
||||
login.register_servlets,
|
||||
receipts.register_servlets,
|
||||
]
|
||||
user_id = True
|
||||
hijack_auth = False
|
||||
|
@ -499,3 +500,161 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
|
||||
# check that this is low-priority
|
||||
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
|
||||
|
||||
def test_push_unread_count_group_by_room(self):
|
||||
"""
|
||||
The HTTP pusher will group unread count by number of unread rooms.
|
||||
"""
|
||||
# Carry out common push count tests and setup
|
||||
self._test_push_unread_count()
|
||||
|
||||
# Carry out our option-value specific test
|
||||
#
|
||||
# This push should still only contain an unread count of 1 (for 1 unread room)
|
||||
self.assertEqual(
|
||||
self.push_attempts[5][2]["notification"]["counts"]["unread"], 1
|
||||
)
|
||||
|
||||
@override_config({"push": {"group_unread_count_by_room": False}})
|
||||
def test_push_unread_count_message_count(self):
|
||||
"""
|
||||
The HTTP pusher will send the total unread message count.
|
||||
"""
|
||||
# Carry out common push count tests and setup
|
||||
self._test_push_unread_count()
|
||||
|
||||
# Carry out our option-value specific test
|
||||
#
|
||||
# We're counting every unread message, so there should now be 4 since the
|
||||
# last read receipt
|
||||
self.assertEqual(
|
||||
self.push_attempts[5][2]["notification"]["counts"]["unread"], 4
|
||||
)
|
||||
|
||||
def _test_push_unread_count(self):
|
||||
"""
|
||||
Tests that the correct unread count appears in sent push notifications
|
||||
|
||||
Note that:
|
||||
* Sending messages will cause push notifications to go out to relevant users
|
||||
* Sending a read receipt will cause a "badge update" notification to go out to
|
||||
the user that sent the receipt
|
||||
"""
|
||||
# Register the user who gets notified
|
||||
user_id = self.register_user("user", "pass")
|
||||
access_token = self.login("user", "pass")
|
||||
|
||||
# Register the user who sends the message
|
||||
other_user_id = self.register_user("other_user", "pass")
|
||||
other_access_token = self.login("other_user", "pass")
|
||||
|
||||
# Create a room (as other_user)
|
||||
room_id = self.helper.create_room_as(other_user_id, tok=other_access_token)
|
||||
|
||||
# The user to get notified joins
|
||||
self.helper.join(room=room_id, user=user_id, tok=access_token)
|
||||
|
||||
# Register the pusher
|
||||
user_tuple = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
app_id="m.http",
|
||||
app_display_name="HTTP Push Notifications",
|
||||
device_display_name="pushy push",
|
||||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data={"url": "example.com"},
|
||||
)
|
||||
)
|
||||
|
||||
# Send a message
|
||||
response = self.helper.send(
|
||||
room_id, body="Hello there!", tok=other_access_token
|
||||
)
|
||||
# To get an unread count, the user who is getting notified has to have a read
|
||||
# position in the room. We'll set the read position to this event in a moment
|
||||
first_message_event_id = response["event_id"]
|
||||
|
||||
# Advance time a bit (so the pusher will register something has happened) and
|
||||
# make the push succeed
|
||||
self.push_attempts[0][0].callback({})
|
||||
self.pump()
|
||||
|
||||
# Check our push made it
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
self.assertEqual(self.push_attempts[0][1], "example.com")
|
||||
|
||||
# Check that the unread count for the room is 0
|
||||
#
|
||||
# The unread count is zero as the user has no read receipt in the room yet
|
||||
self.assertEqual(
|
||||
self.push_attempts[0][2]["notification"]["counts"]["unread"], 0
|
||||
)
|
||||
|
||||
# Now set the user's read receipt position to the first event
|
||||
#
|
||||
# This will actually trigger a new notification to be sent out so that
|
||||
# even if the user does not receive another message, their unread
|
||||
# count goes down
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
|
||||
{},
|
||||
access_token=access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# Advance time and make the push succeed
|
||||
self.push_attempts[1][0].callback({})
|
||||
self.pump()
|
||||
|
||||
# Unread count is still zero as we've read the only message in the room
|
||||
self.assertEqual(len(self.push_attempts), 2)
|
||||
self.assertEqual(
|
||||
self.push_attempts[1][2]["notification"]["counts"]["unread"], 0
|
||||
)
|
||||
|
||||
# Send another message
|
||||
self.helper.send(
|
||||
room_id, body="How's the weather today?", tok=other_access_token
|
||||
)
|
||||
|
||||
# Advance time and make the push succeed
|
||||
self.push_attempts[2][0].callback({})
|
||||
self.pump()
|
||||
|
||||
# This push should contain an unread count of 1 as there's now been one
|
||||
# message since our last read receipt
|
||||
self.assertEqual(len(self.push_attempts), 3)
|
||||
self.assertEqual(
|
||||
self.push_attempts[2][2]["notification"]["counts"]["unread"], 1
|
||||
)
|
||||
|
||||
# Since we're grouping by room, sending more messages shouldn't increase the
|
||||
# unread count, as they're all being sent in the same room
|
||||
self.helper.send(room_id, body="Hello?", tok=other_access_token)
|
||||
|
||||
# Advance time and make the push succeed
|
||||
self.pump()
|
||||
self.push_attempts[3][0].callback({})
|
||||
|
||||
self.helper.send(room_id, body="Hello??", tok=other_access_token)
|
||||
|
||||
# Advance time and make the push succeed
|
||||
self.pump()
|
||||
self.push_attempts[4][0].callback({})
|
||||
|
||||
self.helper.send(room_id, body="HELLO???", tok=other_access_token)
|
||||
|
||||
# Advance time and make the push succeed
|
||||
self.pump()
|
||||
self.push_attempts[5][0].callback({})
|
||||
|
||||
self.assertEqual(len(self.push_attempts), 6)
|
||||
|
|
|
@ -78,7 +78,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
|
||||
|
||||
self.test_handler = self._build_replication_data_handler()
|
||||
self.worker_hs.replication_data_handler = self.test_handler
|
||||
self.worker_hs._replication_data_handler = self.test_handler
|
||||
|
||||
repl_handler = ReplicationCommandHandler(self.worker_hs)
|
||||
self.client = ClientReplicationStreamProtocol(
|
||||
|
|
|
@ -192,7 +192,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
|
|||
self.handler = hs.get_device_handler()
|
||||
self.media_repo = hs.get_media_repository_resource()
|
||||
self.server_name = hs.hostname
|
||||
self.clock = hs.clock
|
||||
|
||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||
self.admin_user_tok = self.login("admin", "pass")
|
||||
|
|
|
@ -33,12 +33,15 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
hs = self.setup_test_homeserver(
|
||||
"red", http_client=None, federation_client=Mock()
|
||||
)
|
||||
presence_handler = Mock()
|
||||
presence_handler.set_state.return_value = defer.succeed(None)
|
||||
|
||||
hs.presence_handler = Mock()
|
||||
hs.presence_handler.set_state.return_value = defer.succeed(None)
|
||||
hs = self.setup_test_homeserver(
|
||||
"red",
|
||||
http_client=None,
|
||||
federation_client=Mock(),
|
||||
presence_handler=presence_handler,
|
||||
)
|
||||
|
||||
return hs
|
||||
|
||||
|
@ -55,7 +58,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(self.hs.presence_handler.set_state.call_count, 1)
|
||||
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
|
||||
|
||||
def test_put_presence_disabled(self):
|
||||
"""
|
||||
|
@ -70,4 +73,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(self.hs.presence_handler.set_state.call_count, 0)
|
||||
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
|
||||
|
|
|
@ -41,14 +41,37 @@ class RestHelper:
|
|||
auth_user_id = attr.ib()
|
||||
|
||||
def create_room_as(
|
||||
self, room_creator=None, is_public=True, tok=None, expect_code=200,
|
||||
):
|
||||
self,
|
||||
room_creator: str = None,
|
||||
is_public: bool = True,
|
||||
room_version: str = None,
|
||||
tok: str = None,
|
||||
expect_code: int = 200,
|
||||
) -> str:
|
||||
"""
|
||||
Create a room.
|
||||
|
||||
Args:
|
||||
room_creator: The user ID to create the room with.
|
||||
is_public: If True, the `visibility` parameter will be set to the
|
||||
default (public). Otherwise, the `visibility` parameter will be set
|
||||
to "private".
|
||||
room_version: The room version to create the room as. Defaults to Synapse's
|
||||
default room version.
|
||||
tok: The access token to use in the request.
|
||||
expect_code: The expected HTTP response code.
|
||||
|
||||
Returns:
|
||||
The ID of the newly created room.
|
||||
"""
|
||||
temp_id = self.auth_user_id
|
||||
self.auth_user_id = room_creator
|
||||
path = "/_matrix/client/r0/createRoom"
|
||||
content = {}
|
||||
if not is_public:
|
||||
content["visibility"] = "private"
|
||||
if room_version:
|
||||
content["room_version"] = room_version
|
||||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
|
|
|
@ -38,11 +38,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
|
|||
return succeed(True)
|
||||
|
||||
|
||||
class DummyPasswordChecker(UserInteractiveAuthChecker):
|
||||
def check_auth(self, authdict, clientip):
|
||||
return succeed(authdict["identifier"]["user"])
|
||||
|
||||
|
||||
class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [
|
||||
|
@ -162,9 +157,6 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
auth_handler = hs.get_auth_handler()
|
||||
auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs)
|
||||
|
||||
self.user_pass = "pass"
|
||||
self.user = self.register_user("test", self.user_pass)
|
||||
self.user_tok = self.login("test", self.user_pass)
|
||||
|
@ -234,6 +226,31 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
def test_grandfathered_identifier(self):
|
||||
"""Check behaviour without "identifier" dict
|
||||
|
||||
Synapse used to require clients to submit a "user" field for m.login.password
|
||||
UIA - check that still works.
|
||||
"""
|
||||
|
||||
device_id = self.get_device_ids()[0]
|
||||
channel = self.delete_device(device_id, 401)
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Make another request providing the UI auth flow.
|
||||
self.delete_device(
|
||||
device_id,
|
||||
200,
|
||||
{
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
"user": self.user,
|
||||
"password": self.user_pass,
|
||||
"session": session,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_can_change_body(self):
|
||||
"""
|
||||
The client dict can be modified during the user interactive authentication session.
|
||||
|
|
|
@ -569,7 +569,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
|||
tok = self.login("kermit", "monkey")
|
||||
# We need to manually add an email address otherwise the handler will do
|
||||
# nothing.
|
||||
now = self.hs.clock.time_msec()
|
||||
now = self.hs.get_clock().time_msec()
|
||||
self.get_success(
|
||||
self.store.user_add_threepid(
|
||||
user_id=user_id,
|
||||
|
@ -587,7 +587,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
# We need to manually add an email address otherwise the handler will do
|
||||
# nothing.
|
||||
now = self.hs.clock.time_msec()
|
||||
now = self.hs.get_clock().time_msec()
|
||||
self.get_success(
|
||||
self.store.user_add_threepid(
|
||||
user_id=user_id,
|
||||
|
@ -646,7 +646,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.hs.config.account_validity.startup_job_max_delta = self.max_delta
|
||||
|
||||
now_ms = self.hs.clock.time_msec()
|
||||
now_ms = self.hs.get_clock().time_msec()
|
||||
self.get_success(self.store._set_expiration_date_when_missing())
|
||||
|
||||
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
|
||||
|
|
|
@ -271,7 +271,7 @@ def setup_test_homeserver(
|
|||
|
||||
# Install @cache_in_self attributes
|
||||
for key, val in kwargs.items():
|
||||
setattr(hs, key, val)
|
||||
setattr(hs, "_" + key, val)
|
||||
|
||||
# Mock TLS
|
||||
hs.tls_server_context_factory = Mock()
|
||||
|
|
Loading…
Reference in New Issue