Merge branch 'release-v1.24.0' of github.com:matrix-org/synapse into matrix-org-hotfixes

michaelkaye/remove_warning
Patrick Cloke 2020-12-02 08:40:21 -05:00
commit 16744644f6
59 changed files with 1604 additions and 383 deletions

View File

@ -6,7 +6,7 @@
set -ex
apt-get update
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
export LANG="C.UTF-8"

1
changelog.d/8565.misc Normal file
View File

@ -0,0 +1 @@
Simplify the way the `HomeServer` object caches its internal attributes.

1
changelog.d/8799.bugfix Normal file
View File

@ -0,0 +1 @@
Allow per-room profiles to be used for the server notice user.

1
changelog.d/8800.misc Normal file
View File

@ -0,0 +1 @@
Add additional error checking for OpenID Connect and SAML mapping providers.

1
changelog.d/8804.feature Normal file
View File

@ -0,0 +1 @@
Allow Date header through CORS. Contributed by Nicolas Chamo.

1
changelog.d/8809.misc Normal file
View File

@ -0,0 +1 @@
Remove unnecessary function arguments and add typing to several membership replication classes.

1
changelog.d/8819.misc Normal file
View File

@ -0,0 +1 @@
Add tests for `password_auth_provider`s.

1
changelog.d/8820.feature Normal file
View File

@ -0,0 +1 @@
Add a config option, `push.group_by_unread_count`, which controls whether unread message counts in push notifications are defined as "the number of rooms with unread messages" or "total unread messages".

1
changelog.d/8833.removal Normal file
View File

@ -0,0 +1 @@
Disable pretty printing JSON responses for curl. Users who want pretty-printed output should use [jq](https://stedolan.github.io/jq/) in combination with curl. Contributed by @tulir.

1
changelog.d/8835.bugfix Normal file
View File

@ -0,0 +1 @@
Fix minor long-standing bug in login, where we would offer the `password` login type if a custom auth provider supported it, even if password login was disabled.

1
changelog.d/8843.feature Normal file
View File

@ -0,0 +1 @@
Add `force_purge` option to delete-room admin api.

1
changelog.d/8845.misc Normal file
View File

@ -0,0 +1 @@
Drop redundant database index on `event_json`.

1
changelog.d/8847.misc Normal file
View File

@ -0,0 +1 @@
Simplify `uk.half-shot.msc2778.login.application_service` login handler.

1
changelog.d/8848.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug which caused Synapse to require unspecified parameters during user-interactive authentication.

1
changelog.d/8849.misc Normal file
View File

@ -0,0 +1 @@
Refactor `password_auth_provider` support code.

1
changelog.d/8850.misc Normal file
View File

@ -0,0 +1 @@
Add missing `ordering` to background database updates.

1
changelog.d/8851.misc Normal file
View File

@ -0,0 +1 @@
Simplify the way the `HomeServer` object caches its internal attributes.

1
changelog.d/8854.misc Normal file
View File

@ -0,0 +1 @@
Allow for specifying a room version when creating a room in unit tests via `RestHelper.create_room_as`.

1
changelog.d/8855.feature Normal file
View File

@ -0,0 +1 @@
Add support for re-trying generation of a localpart for OpenID Connect mapping providers.

View File

@ -382,7 +382,7 @@ the new room. Users on other servers will be unaffected.
The API is:
```json
```
POST /_synapse/admin/v1/rooms/<room_id>/delete
```
@ -439,6 +439,10 @@ The following JSON body parameters are available:
future attempts to join the room. Defaults to `false`.
* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database.
Defaults to `true`.
* `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it
will force a purge to go ahead even if there are local users still in the room. Do not
use this unless a regular `purge` operation fails, as it could leave those users'
clients in a confused state.
The JSON body must not be empty. The body must be at least `{}`.

View File

@ -26,6 +26,7 @@ Password auth provider classes must provide the following methods:
It should perform any appropriate sanity checks on the provided
configuration, and return an object which is then passed into
`__init__`.
This method should have the `@staticmethod` decoration.

View File

@ -2271,6 +2271,16 @@ push:
#
#include_content: false
# When a push notification is received, an unread count is also sent.
# This number can either be calculated as the number of unread messages
# for the user, or the number of *rooms* the user has unread messages in.
#
# The default value is "true", meaning push clients will see the number of
# rooms with unread messages in them. Uncomment to instead send the number
# of unread messages.
#
#group_unread_count_by_room: false
# Spam checkers are third-party modules that can block specific actions
# of local users, such as creating rooms and registering undesirable

View File

@ -80,6 +80,7 @@ files =
synapse/util/metrics.py,
tests/replication,
tests/test_utils,
tests/handlers/test_password_providers.py,
tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_stream_change_cache.py

View File

@ -23,6 +23,9 @@ class PushConfig(Config):
def read_config(self, config, **kwargs):
push_config = config.get("push") or {}
self.push_include_content = push_config.get("include_content", True)
self.push_group_unread_count_by_room = push_config.get(
"group_unread_count_by_room", True
)
pusher_instances = config.get("pusher_instances") or []
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
@ -68,4 +71,14 @@ class PushConfig(Config):
# include the event ID and room ID in push notification payloads.
#
#include_content: false
# When a push notification is received, an unread count is also sent.
# This number can either be calculated as the number of unread messages
# for the user, or the number of *rooms* the user has unread messages in.
#
# The default value is "true", meaning push clients will see the number of
# rooms with unread messages in them. Uncomment to instead send the number
# of unread messages.
#
#group_unread_count_by_room: false
"""

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -25,6 +26,7 @@ from typing import (
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
# better way to break the loop
account_handler = ModuleApi(hs, self)
self.password_providers = []
for module, config in hs.config.password_providers:
try:
self.password_providers.append(
module(config=config, account_handler=account_handler)
)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
self.password_providers = [
PasswordProvider.load(module, config, account_handler)
for module, config in hs.config.password_providers
]
logger.info("Extra password_providers: %r", self.password_providers)
logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
@ -205,15 +202,23 @@ class AuthHandler(BaseHandler):
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
if self._password_enabled:
if hs.config.password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
for t in provider.get_supported_login_types().keys():
if t not in login_types:
login_types.append(t)
if not self._password_enabled:
login_types.remove(LoginType.PASSWORD)
self._supported_login_types = login_types
# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
@ -230,6 +235,13 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
# Ratelimitier for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
@ -642,14 +654,8 @@ class AuthHandler(BaseHandler):
res = await checker.check_auth(authdict, clientip=clientip)
return res
# build a v1-login-style dict out of the authdict and fall back to the
# v1 code
user_id = authdict.get("user")
if user_id is None:
raise SynapseError(400, "", Codes.MISSING_PARAM)
(canonical_id, callback) = await self.validate_login(user_id, authdict)
# fall back to the v1 login flow
canonical_id, _ = await self.validate_login(authdict)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@ -824,15 +830,157 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
async def validate_login(
self, username: str, login_submission: Dict[str, Any]
self, login_submission: Dict[str, Any], ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
m.login.password auth types.
Also used by the user-interactive auth flow to validate auth types which don't
have an explicit UIA handler, including m.password.auth.
Args:
username: username supplied by the user
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
ratelimit: whether to apply the failed_login_attempt ratelimiter
Returns:
A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
Raises:
StoreError if there was a problem accessing the database
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
login_type = login_submission.get("type")
if not isinstance(login_type, str):
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
# ideally, we wouldn't be checking the identifier unless we know we have a login
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
#
# But the auth providers' check_auth interface requires a username, so in
# practice we can only support login methods which we can map to a username
# anyway.
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
if not isinstance(password, str):
raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
# map old-school login fields into new-school "identifier" fields.
identifier_dict = convert_client_dict_legacy_fields_to_identifier(
login_submission
)
# convert phone type identifiers to generic threepids
if identifier_dict["type"] == "m.id.phone":
identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
# convert threepid identifiers to user IDs
if identifier_dict["type"] == "m.id.thirdparty":
address = identifier_dict.get("address")
medium = identifier_dict.get("medium")
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
# For emails, canonicalise the address.
# We store all email addresses canonicalised in the DB.
# (See add_threepid in synapse/handlers/auth.py)
if medium == "email":
try:
address = canonicalise_email(address)
except ValueError as e:
raise SynapseError(400, str(e))
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
if ratelimit:
self._failed_login_attempts_ratelimiter.ratelimit(
(medium, address), update=False
)
# Check for login providers that support 3pid login types
if login_type == LoginType.PASSWORD:
# we've already checked that there is a (valid) password field
assert isinstance(password, str)
(
canonical_user_id,
callback_3pid,
) = await self.check_password_provider_3pid(medium, address, password)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
return canonical_user_id, callback_3pid
# No password providers were able to handle this 3pid
# Check local store
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address
)
if not user_id:
logger.warning(
"unknown 3pid identifier medium %s, address %r", medium, address
)
# We mark that we've failed to log in here, as
# `check_password_provider_3pid` might have returned `None` due
# to an incorrect password, rather than the account not
# existing.
#
# If it returned None but the 3PID was bound then we won't hit
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
if ratelimit:
self._failed_login_attempts_ratelimiter.can_do_action(
(medium, address)
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier_dict = {"type": "m.id.user", "user": user_id}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
if identifier_dict["type"] != "m.id.user":
raise SynapseError(400, "Unknown login identifier type")
username = identifier_dict.get("user")
if not username:
raise SynapseError(400, "User identifier is missing 'user' key")
if username.startswith("@"):
qualified_user_id = username
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
# Check if we've hit the failed ratelimit (but don't update it)
if ratelimit:
self._failed_login_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(), update=False
)
try:
return await self._validate_userid_login(username, login_submission)
except LoginError:
# The user has failed to log in, so we need to update the rate
# limiter. Using `can_do_action` avoids us raising a ratelimit
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
if ratelimit:
self._failed_login_attempts_ratelimiter.can_do_action(
qualified_user_id.lower()
)
raise
async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any],
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
Args:
username: the username, from the identifier dict
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
@ -843,38 +991,18 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
if username.startswith("@"):
qualified_user_id = username
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
login_type = login_submission.get("type")
# we already checked that we have a valid login type
assert isinstance(login_type, str)
known_login_type = False
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
if not password:
raise SynapseError(400, "Missing parameter: password")
for provider in self.password_providers:
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
is_valid = await provider.check_password(qualified_user_id, password)
if is_valid:
return qualified_user_id, None
if not hasattr(provider, "get_supported_login_types") or not hasattr(
provider, "check_auth"
):
# this password provider doesn't understand custom login types
continue
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
@ -899,15 +1027,17 @@ class AuthHandler(BaseHandler):
result = await provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
# we've already checked that there is a (valid) password field
password = login_submission["password"]
assert isinstance(password, str)
canonical_user_id = await self._check_local_password(
qualified_user_id, password # type: ignore
qualified_user_id, password
)
if canonical_user_id:
@ -938,19 +1068,9 @@ class AuthHandler(BaseHandler):
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
if hasattr(provider, "check_3pid_auth"):
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
result = (result, None)
return result
result = await provider.check_3pid_auth(medium, address, password)
if result:
return result
return None, None
@ -1008,16 +1128,11 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
if inspect.isawaitable(result):
await result
await provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
# delete pushers associated with this access token
if user_info.token_id is not None:
@ -1046,11 +1161,10 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
for token, token_id, device_id in tokens_and_devices:
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@ -1374,3 +1488,127 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
class PasswordProvider:
"""Wrapper for a password auth provider module
This class abstracts out all of the backwards-compatibility hacks for
password providers, to provide a consistent interface.
"""
@classmethod
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
return cls(pp, module_api)
def __init__(self, pp, module_api: ModuleApi):
self._pp = pp
self._module_api = module_api
self._supported_login_types = {}
# grandfather in check_password support
if hasattr(self._pp, "check_password"):
self._supported_login_types[LoginType.PASSWORD] = ("password",)
g = getattr(self._pp, "get_supported_login_types", None)
if g:
self._supported_login_types.update(g())
def __str__(self):
return str(self._pp)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
Returns a map from a login type identifier (such as m.login.password) to an
iterable giving the fields which must be provided by the user in the submission
to the /login API.
This wrapper adds m.login.password to the list if the underlying password
provider supports the check_password() api.
"""
return self._supported_login_types
async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
"""Check if the user has presented valid login credentials
This wrapper also calls check_password() if the underlying password provider
supports the check_password() api and the login type is m.login.password.
Args:
username: user id presented by the client. Either an MXID or an unqualified
username.
login_type: the login type being attempted - one of the types returned by
get_supported_login_types()
login_dict: the dictionary of login secrets passed by the client.
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
user, and `callback` is an optional callback which will be called with the
result from the /login call (including access_token, device_id, etc.)
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
g = getattr(self._pp, "check_password", None)
if g:
qualified_user_id = self._module_api.get_qualified_user_id(username)
is_valid = await self._pp.check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None
g = getattr(self._pp, "check_auth", None)
if not g:
return None
result = await g(username, login_type, login_dict)
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
return result
async def check_3pid_auth(
self, medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
g = getattr(self._pp, "check_3pid_auth", None)
if not g:
return None
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await g(medium, address, password)
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
return result
async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
g = getattr(self._pp, "on_logged_out", None)
if not g:
return
# This might return an awaitable, if it does block the log out
# until it completes.
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
if inspect.isawaitable(result):
await result

View File

@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "An error was encountered when sending the email")
token_expires = (
self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
self.hs.get_clock().time_msec()
+ self.hs.config.email_validation_token_lifetime
)
await self.store.start_or_continue_validation_session(

View File

@ -39,7 +39,7 @@ from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, map_username_to_mxid_localpart
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@ -898,13 +898,39 @@ class OidcHandler(BaseHandler):
return UserAttributes(**attributes)
async def grandfather_existing_users() -> Optional[str]:
if self._allow_existing_users:
# If allowing existing users we want to generate a single localpart
# and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0)
user_id = UserID(attributes.localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
previously_registered_user_id = next(iter(users))
elif user_id in users:
previously_registered_user_id = user_id
else:
# Do not attempt to continue generating Matrix IDs.
raise MappingException(
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
user_id, users
)
)
return previously_registered_user_id
return None
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
oidc_response_to_user_attributes,
self._allow_existing_users,
grandfather_existing_users,
)

View File

@ -299,17 +299,22 @@ class PaginationHandler:
"""
return self._purges_by_id.get(purge_id)
async def purge_room(self, room_id: str) -> None:
"""Purge the given room from the database"""
async def purge_room(self, room_id: str, force: bool = False) -> None:
"""Purge the given room from the database.
Args:
room_id: room to be purged
force: set true to skip checking for joined users.
"""
with await self.pagination_lock.write(room_id):
# check we know about the room
await self.store.get_room_version_id(room_id)
# first check that we have no users in this room
joined = await self.store.is_host_joined(room_id, self._server_name)
if joined:
raise SynapseError(400, "Users are still joined to this room")
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
if joined:
raise SynapseError(400, "Users are still joined to this room")
await self.storage.purge_events.purge_room(room_id)

View File

@ -366,7 +366,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# later on.
content = dict(content)
if not self.allow_per_room_profiles or requester.shadow_banned:
# allow the server notices mxid to set room-level profile
is_requester_server_notices_user = (
self._server_notices_mxid is not None
and requester.user.to_string() == self._server_notices_mxid
)
if (
not self.allow_per_room_profiles and not is_requester_server_notices_user
) or requester.shadow_banned:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.

View File

@ -265,10 +265,10 @@ class SamlHandler(BaseHandler):
return UserAttributes(
localpart=result.get("mxid_localpart"),
display_name=result.get("displayname"),
emails=result.get("emails"),
emails=result.get("emails", []),
)
with (await self._mapping_lock.queue(self._auth_provider_id)):
async def grandfather_existing_users() -> Optional[str]:
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@ -290,17 +290,18 @@ class SamlHandler(BaseHandler):
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
return None
with (await self._mapping_lock.queue(self._auth_provider_id)):
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
saml_response_to_remapped_user_attributes,
grandfather_existing_users,
)
def expire_sessions(self):

View File

@ -116,7 +116,7 @@ class SsoHandler(BaseHandler):
user_agent: str,
ip_address: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
allow_existing_users: bool = False,
grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
) -> str:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@ -125,6 +125,10 @@ class SsoHandler(BaseHandler):
if it has that matrix ID is returned regardless of the current mapping
logic.
If a callable is provided for grandfathering users, it is called and can
potentially return a matrix ID to use. If it does, the SSO ID is linked to
this matrix ID for subsequent calls.
The mapping function is called (potentially multiple times) to generate
a localpart for the user.
@ -132,17 +136,6 @@ class SsoHandler(BaseHandler):
given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls.
If allow_existing_users is true the mapping function is only called once
and results in:
1. The use of a previously registered matrix ID. In this case, the
SSO ID is linked to the matrix ID. (Note it is possible that
other SSO IDs are linked to the same matrix ID.)
2. An unused localpart, in which case the user is registered (as
discussed above).
3. An error if the generated localpart matches multiple pre-existing
matrix IDs. Generally this should not happen.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
@ -152,8 +145,9 @@ class SsoHandler(BaseHandler):
sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the amount of
times the returned mxid localpart mapping has failed.
allow_existing_users: True if the localpart returned from the
mapping provider can be linked to an existing matrix ID.
grandfather_existing_users: A callable which can return an previously
existing matrix ID. The SSO ID is then linked to the returned
matrix ID.
Returns:
The user ID associated with the SSO response.
@ -171,6 +165,16 @@ class SsoHandler(BaseHandler):
if previously_registered_user_id:
return previously_registered_user_id
# Check for grandfathering of users.
if grandfather_existing_users:
previously_registered_user_id = await grandfather_existing_users()
if previously_registered_user_id:
# Future logins should also match this user ID.
await self.store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id
# Otherwise, generate a new user.
for i in range(self._MAP_USERNAME_RETRIES):
try:
@ -194,33 +198,7 @@ class SsoHandler(BaseHandler):
# Check if this mxid already exists
user_id = UserID(attributes.localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
# Note, if allow_existing_users is true then the loop is guaranteed
# to end on the first iteration: either by matching an existing user,
# raising an error, or registering a new user. See the docstring for
# more in-depth an explanation.
if users and allow_existing_users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
previously_registered_user_id = next(iter(users))
elif user_id in users:
previously_registered_user_id = user_id
else:
# Do not attempt to continue generating Matrix IDs.
raise MappingException(
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
user_id, users
)
)
# Future logins should also match this user ID.
await self.store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id
elif not users:
if not await self.store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:

View File

@ -25,7 +25,7 @@ from io import BytesIO
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
import jinja2
from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
from canonicaljson import iterencode_canonical_json
from zope.interface import implementer
from twisted.internet import defer, interfaces
@ -94,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass
else:
respond_with_json(
request,
error_code,
error_dict,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
request, error_code, error_dict, send_cors=True,
)
@ -290,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource):
code,
response_object,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
)
@ -587,7 +582,6 @@ def respond_with_json(
code: int,
json_object: Any,
send_cors: bool = False,
pretty_print: bool = False,
canonical_json: bool = True,
):
"""Sends encoded JSON in response to the given request.
@ -598,8 +592,6 @@ def respond_with_json(
json_object: The object to serialize to JSON.
send_cors: Whether to send Cross-Origin Resource Sharing headers
https://fetch.spec.whatwg.org/#http-cors-protocol
pretty_print: Whether to include indentation and line-breaks in the
resulting JSON bytes.
canonical_json: Whether to use the canonicaljson algorithm when encoding
the JSON bytes.
@ -615,13 +607,10 @@ def respond_with_json(
)
return None
if pretty_print:
encoder = iterencode_pretty_printed_json
if canonical_json:
encoder = iterencode_canonical_json
else:
if canonical_json:
encoder = iterencode_canonical_json
else:
encoder = _encode_json_bytes
encoder = _encode_json_bytes
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
@ -685,7 +674,7 @@ def set_cors_headers(request: Request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date",
)
@ -759,11 +748,3 @@ def finish_request(request: Request):
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
def _request_user_agent_is_curl(request: Request) -> bool:
user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
for user_agent in user_agents:
if b"curl" in user_agent:
return True
return False

View File

@ -75,6 +75,7 @@ class HttpPusher:
self.failing_since = pusherdict["failing_since"]
self.timed_call = None
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
@ -140,7 +141,11 @@ class HttpPusher:
async def _update_badge(self):
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
badge = await push_tools.get_badge_count(
self.hs.get_datastore(),
self.user_id,
group_by_room=self._group_unread_count_by_room,
)
await self._send_badge(badge)
def on_timer(self):
@ -287,7 +292,11 @@ class HttpPusher:
return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
badge = await push_tools.get_badge_count(
self.hs.get_datastore(),
self.user_id,
group_by_room=self._group_unread_count_by_room,
)
event = await self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:

View File

@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore
async def get_badge_count(store, user_id):
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
@ -34,9 +34,15 @@ async def get_badge_count(store, user_id):
room_id, user_id, last_unread_event_id
)
)
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0
if notifs["notify_count"] == 0:
continue
if group_by_room:
# return one badge count per conversation
badge += 1
else:
# increment the badge count by the number of unread messages in the room
badge += notifs["notify_count"]
return badge

View File

@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.http import Request
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@ -52,16 +53,23 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(
requester, room_id, user_id, remote_room_hosts, content
):
async def _serialize_payload( # type: ignore
requester: Requester,
room_id: str,
user_id: str,
remote_room_hosts: List[str],
content: JsonDict,
) -> JsonDict:
"""
Args:
requester(Requester)
room_id (str)
user_id (str)
remote_room_hosts (list[str]): Servers to try and join via
content(dict): The event content to use for the join event
requester: The user making the request according to the access token
room_id: The ID of the room.
user_id: The ID of the user.
remote_room_hosts: Servers to try and join via
content: The event content to use for the join event
Returns:
A dict representing the payload of the request.
"""
return {
"requester": requester.serialize(),
@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
async def _handle_request(self, request, room_id, user_id):
async def _handle_request( # type: ignore
self, request: Request, room_id: str, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@ -118,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
):
) -> JsonDict:
"""
Args:
invite_event_id: ID of the invite to be rejected
txn_id: optional transaction ID supplied by the client
requester: user making the rejection request, according to the access token
content: additional content to include in the rejection event.
invite_event_id: The ID of the invite to be rejected.
txn_id: Optional transaction ID supplied by the client
requester: User making the rejection request, according to the access token
content: Additional content to include in the rejection event.
Normally an empty dict.
Returns:
A dict representing the payload of the request.
"""
return {
"txn_id": txn_id,
@ -133,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"content": content,
}
async def _handle_request(self, request, invite_event_id):
async def _handle_request( # type: ignore
self, request: Request, invite_event_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
txn_id = content["txn_id"]
@ -174,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor()
@staticmethod
async def _serialize_payload(room_id, user_id, change):
async def _serialize_payload( # type: ignore
room_id: str, user_id: str, change: str
) -> JsonDict:
"""
Args:
room_id (str)
user_id (str)
change (str): "left"
room_id: The ID of the room.
user_id: The ID of the user.
change: "left"
Returns:
A dict representing the payload of the request.
"""
assert change == "left"
return {}
def _handle_request(self, request, room_id, user_id, change):
def _handle_request( # type: ignore
self, request: Request, room_id: str, user_id: str, change: str
) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id)
user = UserID.from_string(user_id)

View File

@ -70,14 +70,18 @@ class ShutdownRoomRestServlet(RestServlet):
class DeleteRoomRestServlet(RestServlet):
"""Delete a room from server. It is a combination and improvement of
shut down and purge room.
"""Delete a room from server.
It is a combination and improvement of shutdown and purge room.
Shuts down a room by removing all local users from the room.
Blocking all future invites and joins to the room is optional.
If desired any local aliases will be repointed to a new room
created by `new_room_user_id` and kicked users will be auto
created by `new_room_user_id` and kicked users will be auto-
joined to the new room.
It will remove all trace of a room from the database.
If 'purge' is true, it will remove all traces of a room from the database.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
@ -110,6 +114,14 @@ class DeleteRoomRestServlet(RestServlet):
Codes.BAD_JSON,
)
force_purge = content.get("force_purge", False)
if not isinstance(force_purge, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Param 'force_purge' must be a boolean, if given",
Codes.BAD_JSON,
)
ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id,
new_room_user_id=content.get("new_room_user_id"),
@ -121,7 +133,7 @@ class DeleteRoomRestServlet(RestServlet):
# Purge room
if purge:
await self.pagination_handler.purge_room(room_id)
await self.pagination_handler.purge_room(room_id, force=force_purge)
return (200, ret)

View File

@ -19,10 +19,6 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
from synapse.handlers.auth import (
convert_client_dict_legacy_fields_to_identifier,
login_id_phone_to_thirdparty,
)
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@ -33,7 +29,6 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
@ -78,11 +73,6 @@ class LoginRestServlet(RestServlet):
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
)
self._failed_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
def on_GET(self, request: SynapseRequest):
flows = []
@ -140,27 +130,31 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
def _get_qualified_user_id(self, identifier):
if identifier["type"] != "m.id.user":
raise SynapseError(400, "Unknown login identifier type")
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
if identifier["user"].startswith("@"):
return identifier["user"]
else:
return UserID(identifier["user"], self.hs.hostname).to_string()
async def _do_appservice_login(
self, login_submission: JsonDict, appservice: ApplicationService
):
logger.info(
"Got appservice login request with identifier: %r",
login_submission.get("identifier"),
)
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
qualified_user_id = self._get_qualified_user_id(identifier)
if not isinstance(identifier, dict):
raise SynapseError(
400, "Invalid identifier in login submission", Codes.INVALID_PARAM
)
# this login flow only supports identifiers of type "m.id.user".
if identifier.get("type") != "m.id.user":
raise SynapseError(
400, "Unknown login identifier type", Codes.INVALID_PARAM
)
user = identifier.get("user")
if not isinstance(user, str):
raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
if user.startswith("@"):
qualified_user_id = user
else:
qualified_user_id = UserID(user, self.hs.hostname).to_string()
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
@ -186,91 +180,9 @@ class LoginRestServlet(RestServlet):
login_submission.get("address"),
login_submission.get("user"),
)
identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
# convert phone type identifiers to generic threepids
if identifier["type"] == "m.id.phone":
identifier = login_id_phone_to_thirdparty(identifier)
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
address = identifier.get("address")
medium = identifier.get("medium")
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
# For emails, canonicalise the address.
# We store all email addresses canonicalised in the DB.
# (See add_threepid in synapse/handlers/auth.py)
if medium == "email":
try:
address = canonicalise_email(address)
except ValueError as e:
raise SynapseError(400, str(e))
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
# Check for login providers that support 3pid login types
(
canonical_user_id,
callback_3pid,
) = await self.auth_handler.check_password_provider_3pid(
medium, address, login_submission["password"]
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
result = await self._complete_login(
canonical_user_id, login_submission, callback_3pid
)
return result
# No password providers were able to handle this 3pid
# Check local store
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address
)
if not user_id:
logger.warning(
"unknown 3pid identifier medium %s, address %r", medium, address
)
# We mark that we've failed to log in here, as
# `check_password_provider_3pid` might have returned `None` due
# to an incorrect password, rather than the account not
# existing.
#
# If it returned None but the 3PID was bound then we won't hit
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
self._failed_attempts_ratelimiter.can_do_action((medium, address))
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {"type": "m.id.user", "user": user_id}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
qualified_user_id = self._get_qualified_user_id(identifier)
# Check if we've hit the failed ratelimit (but don't update it)
self._failed_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(), update=False
canonical_user_id, callback = await self.auth_handler.validate_login(
login_submission, ratelimit=True
)
try:
canonical_user_id, callback = await self.auth_handler.validate_login(
identifier["user"], login_submission
)
except LoginError:
# The user has failed to log in, so we need to update the rate
# limiter. Using `can_do_action` avoids us raising a ratelimit
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
raise
result = await self._complete_login(
canonical_user_id, login_submission, callback
)

View File

@ -115,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@ -387,7 +387,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@ -466,7 +466,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)

View File

@ -135,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@ -214,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(

View File

@ -66,7 +66,7 @@ class LocalKey(Resource):
def __init__(self, hs):
self.config = hs.config
self.clock = hs.clock
self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)

View File

@ -147,7 +147,8 @@ def cache_in_self(builder: T) -> T:
"@cache_in_self can only be used on functions starting with `get_`"
)
depname = builder.__name__[len("get_") :]
# get_attr -> _attr
depname = builder.__name__[len("get") :]
building = [False]
@ -235,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta):
self._instance_id = random_string(5)
self._instance_name = config.worker_name or "master"
self.clock = Clock(reactor)
self.distributor = Distributor()
self.registration_ratelimiter = Ratelimiter(
clock=self.clock,
rate_hz=config.rc_registration.per_second,
burst_count=config.rc_registration.burst_count,
)
self.version_string = version_string
self.datastores = None # type: Optional[Databases]
@ -301,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta):
def is_mine_id(self, string: str) -> bool:
return string.split(":", 1)[1] == self.hostname
@cache_in_self
def get_clock(self) -> Clock:
return self.clock
return Clock(self._reactor)
def get_datastore(self) -> DataStore:
if not self.datastores:
@ -319,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_config(self) -> HomeServerConfig:
return self.config
@cache_in_self
def get_distributor(self) -> Distributor:
return self.distributor
return Distributor()
@cache_in_self
def get_registration_ratelimiter(self) -> Ratelimiter:
return self.registration_ratelimiter
return Ratelimiter(
clock=self.get_clock(),
rate_hz=self.config.rc_registration.per_second,
burst_count=self.config.rc_registration.burst_count,
)
@cache_in_self
def get_federation_client(self) -> FederationClient:
@ -687,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_federation_ratelimiter(self) -> FederationRateLimiter:
return FederationRateLimiter(self.clock, config=self.config.rc_federation)
return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation)
@cache_in_self
def get_module_api(self) -> ModuleApi:

View File

@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
for table in (
"event_auth",
"event_edges",
"event_json",
"event_push_actions_staging",
"event_reference_hashes",
"event_relations",
@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"destination_rooms",
"event_backward_extremities",
"event_forward_extremities",
"event_json",
"event_push_actions",
"event_search",
"events",

View File

@ -20,14 +20,14 @@
*/
-- add new index that includes method to local media
INSERT INTO background_updates (update_name, progress_json) VALUES
('local_media_repository_thumbnails_method_idx', '{}');
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5807, 'local_media_repository_thumbnails_method_idx', '{}');
-- add new index that includes method to remote media
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
(5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
-- drop old index
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
(5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');

View File

@ -28,5 +28,5 @@
-- functionality as the old one. This effectively restarts the background job
-- from the beginning, without running it twice in a row, supporting both
-- upgrade usecases.
INSERT INTO background_updates (update_name, progress_json) VALUES
('populate_stats_process_rooms_2', '{}');
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5812, 'populate_stats_process_rooms_2', '{}');

View File

@ -1,2 +1,2 @@
INSERT INTO background_updates (update_name, progress_json) VALUES
('users_have_local_media', '{}');
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5822, 'users_have_local_media', '{}');

View File

@ -13,5 +13,5 @@
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('e2e_cross_signing_keys_idx', '{}');
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5823, 'e2e_cross_signing_keys_idx', '{}');

View File

@ -0,0 +1,19 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- this index is essentially redundant. The only time it was ever used was when purging
-- rooms - and Synapse 1.24 will change that.
DROP INDEX IF EXISTS event_json_room_id;

View File

@ -52,7 +52,7 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self):
self.hs.clock.now = 5000
self.hs.get_clock().now = 5000
token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@ -78,7 +78,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000
self.hs.get_clock().now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = yield defer.ensureDeferred(
@ -87,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
self.hs.clock.now = 6000
self.hs.get_clock().now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)

View File

@ -23,7 +23,7 @@ import pymacaroons
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from synapse.handlers.oidc_handler import OidcError, OidcHandler, OidcMappingProvider
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
@ -127,13 +127,8 @@ async def get_json(url):
class OidcHandlerTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = "Synapse Test"
config = self.default_config()
def default_config(self):
config = super().default_config()
config["public_baseurl"] = BASE_URL
oidc_config = {
"enabled": True,
@ -149,19 +144,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
hs = self.setup_test_homeserver(
http_client=self.http_client,
proxied_http_client=self.http_client,
config=config,
)
return config
self.handler = OidcHandler(hs)
def make_homeserver(self, reactor, clock):
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = "Synapse Test"
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
self.handler = hs.get_oidc_handler()
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
self.handler._sso_handler.render_error = self.render_error
sso_handler.render_error = self.render_error
# Reduce the number of attempts when generating MXIDs.
self.handler._sso_handler._MAP_USERNAME_RETRIES = 3
sso_handler._MAP_USERNAME_RETRIES = 3
return hs
@ -731,6 +731,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid.
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
# in this case we'll just use the same username. A more realistic example
@ -832,7 +840,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test")
# Register all of the potential users for a particular username.
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
store.register_user(user_id="@tester:test", password_hash=None)
)

View File

@ -0,0 +1,580 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the password_auth_provider interface"""
from typing import Any, Type, Union
from mock import Mock
from twisted.internet import defer
import synapse
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import devices
from synapse.types import JsonDict
from tests import unittest
from tests.server import FakeChannel
from tests.unittest import override_config
# (possibly experimental) login flows we expect to appear in the list after the normal
# ones
ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
# a mock instance which the dummy auth providers delegate to, so we can see what's going
# on
mock_password_provider = Mock()
class PasswordOnlyAuthProvider:
"""A password_provider which only implements `check_password`."""
@staticmethod
def parse_config(self):
pass
def __init__(self, config, account_handler):
pass
def check_password(self, *args):
return mock_password_provider.check_password(*args)
class CustomAuthProvider:
"""A password_provider which implements a custom login type."""
@staticmethod
def parse_config(self):
pass
def __init__(self, config, account_handler):
pass
def get_supported_login_types(self):
return {"test.login_type": ["test_field"]}
def check_auth(self, *args):
return mock_password_provider.check_auth(*args)
class PasswordCustomAuthProvider:
"""A password_provider which implements password login via `check_auth`, as well
as a custom type."""
@staticmethod
def parse_config(self):
pass
def __init__(self, config, account_handler):
pass
def get_supported_login_types(self):
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
def check_auth(self, *args):
return mock_password_provider.check_auth(*args)
def providers_config(*providers: Type[Any]) -> dict:
"""Returns a config dict that will enable the given password auth providers"""
return {
"password_providers": [
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
for provider in providers
]
}
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
devices.register_servlets,
]
def setUp(self):
# we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock()
super().setUp()
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_password_only_auth_provider_login(self):
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
# login with mxid should work too
channel = self._send_password_login("@u:bz", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:bz", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
mock_password_provider.reset_mock()
# try a weird username / pass. Honestly it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with(
"@ USER🙂NAME :test", " pASS😢word "
)
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_password_only_auth_provider_ui_auth(self):
"""UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work
module_api = self.hs.get_module_api()
self.get_success(module_api.register_user("u"))
# log in twice, to get two devices
mock_password_provider.check_password.return_value = defer.succeed(True)
tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock()
# have the auth provider deny the request to start with
mock_password_provider.check_password.return_value = defer.succeed(False)
# make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2")
mock_password_provider.check_password.assert_not_called()
# Make another request providing the UI auth flow.
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 401) # XXX why not a 403?
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
# Finally, check the request goes through when we allow it
mock_password_provider.check_password.return_value = defer.succeed(True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_local_user_fallback_login(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_local_user_fallback_ui_auth(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
# have the auth provider deny the request
mock_password_provider.check_password.return_value = defer.succeed(False)
# log in twice, to get two devices
tok1 = self.login("localuser", "localpass")
self.login("localuser", "localpass", device_id="dev2")
mock_password_provider.check_password.reset_mock()
# first delete should give a 401
session = self._start_delete_device_session(tok1, "dev2")
mock_password_provider.check_password.assert_not_called()
# Wrong password
channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
self.assertEqual(channel.code, 401) # XXX why not a 403?
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "xxx"
)
mock_password_provider.reset_mock()
# Right password
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "localpass"
)
@override_config(
{
**providers_config(PasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_login(self):
"""localdb_enabled can block login with the local password
"""
self.register_user("localuser", "localpass")
# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "localpass"
)
@override_config(
{
**providers_config(PasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_ui_auth(self):
"""localdb_enabled can block ui auth with the local password
"""
self.register_user("localuser", "localpass")
# allow login via the auth provider
mock_password_provider.check_password.return_value = defer.succeed(True)
# log in twice, to get two devices
tok1 = self.login("localuser", "p")
self.login("localuser", "p", device_id="dev2")
mock_password_provider.check_password.reset_mock()
# first delete should give a 401
channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401)
# m.login.password UIA is permitted because the auth provider allows it,
# even though the localdb does not.
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
session = channel.json_body["session"]
mock_password_provider.check_password.assert_not_called()
# now try deleting with the local password
mock_password_provider.check_password.return_value = defer.succeed(False)
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
self.assertEqual(channel.code, 401) # XXX why not a 403?
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "localpass"
)
@override_config(
{
**providers_config(PasswordOnlyAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_auth_disabled(self):
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_password.assert_not_called()
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_login(self):
# login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup.
# (password must come first, because reasons)
flows = self._get_login_flows()
self.assertEqual(
flows,
[{"type": "m.login.password"}, {"type": "test.login_type"}]
+ ADDITIONAL_LOGIN_FLOWS,
)
# login with missing param should be rejected
channel = self._send_login("test.login_type", "u")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
)
mock_password_provider.reset_mock()
# try a weird username. Again, it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed(
"@ MALFORMED! :bz"
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
)
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_ui_auth(self):
# register the user and log in twice, to get two devices
self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass")
self.login("localuser", "localpass", device_id="dev2")
# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401)
# Ensure that flows are what is expected.
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
session = channel.json_body["session"]
# missing param
body = {
"auth": {
"type": "test.login_type",
"identifier": {"type": "m.id.user", "user": "localuser"},
"session": session,
},
}
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 400)
# there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
# use it...
self.assertIn("Missing parameters", channel.json_body["error"])
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "foo"}
)
mock_password_provider.reset_mock()
# and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed(
"@localuser:test"
)
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "foo"}
)
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_callback(self):
callback = Mock(return_value=defer.succeed(None))
mock_password_provider.check_auth.return_value = defer.succeed(
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
)
# check the args to the callback
callback.assert_called_once()
call_args, call_kwargs = callback.call_args
# should be one positional arg
self.assertEqual(len(call_args), 1)
self.assertEqual(call_args[0]["user_id"], "@user:bz")
for p in ["user_id", "access_token", "device_id", "home_server"]:
self.assertIn(p, call_args[0])
@override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
)
def test_custom_auth_password_disabled(self):
"""Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass")
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
{
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login(self):
"""log in with a custom auth provider which implements password, but password
login is disabled"""
self.register_user("localuser", "localpass")
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
{
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_ui_auth(self):
"""UI Auth with a custom auth provider which implements password, but password
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed(
"@localuser:test"
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, 200, channel.result)
tok1 = channel.json_body["access_token"]
channel = self._send_login(
"test.login_type", "localuser", test_field="", device_id="dev2"
)
self.assertEqual(channel.code, 200, channel.result)
# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401)
# Ensure that flows are what is expected. In particular, "password" should *not*
# be present.
self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
session = channel.json_body["session"]
mock_password_provider.reset_mock()
# check that auth with password is rejected
body = {
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "localuser"},
"password": "localpass",
"session": session,
},
}
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 400)
self.assertEqual(
"Password login has been disabled.", channel.json_body["error"]
)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.reset_mock()
# successful auth
body["auth"]["type"] = "test.login_type"
body["auth"]["test_field"] = "x"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "x"}
)
@override_config(
{
**providers_config(CustomAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_custom_auth_no_local_user_fallback(self):
"""Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass")
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
# password login shouldn't work and should be rejected with a 400
# ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
def _get_login_flows(self) -> JsonDict:
_, channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["flows"]
def _send_password_login(self, user: str, password: str) -> FakeChannel:
return self._send_login(type="m.login.password", user=user, password=password)
def _send_login(self, type, user, **params) -> FakeChannel:
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
_, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel
def _start_delete_device_session(self, access_token, device_id) -> str:
"""Make an initial delete device request, and return the UI Auth session ID"""
channel = self._delete_device(access_token, device_id)
self.assertEqual(channel.code, 401)
# Ensure that flows are what is expected.
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
return channel.json_body["session"]
def _authed_delete_device(
self,
access_token: str,
device_id: str,
session: str,
user_id: str,
password: str,
) -> FakeChannel:
"""Make a delete device request, authenticating with the given uid/password"""
return self._delete_device(
access_token,
device_id,
{
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": user_id},
"password": password,
"session": session,
},
},
)
def _delete_device(
self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
) -> FakeChannel:
"""Delete an individual device."""
_, channel = self.make_request(
"DELETE", "devices/" + device, body, access_token=access_token
)
return channel

168
tests/handlers/test_saml.py Normal file
View File

@ -0,0 +1,168 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import attr
from synapse.handlers.sso import MappingException
from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests.
BASE_URL = "https://synapse/"
@attr.s
class FakeAuthnResponse:
ava = attr.ib(type=dict)
class TestMappingProvider:
def __init__(self, config, module):
pass
@staticmethod
def parse_config(config):
return
@staticmethod
def get_saml_attributes(config):
return {"uid"}, {"displayName"}
def get_remote_user_id(self, saml_response, client_redirect_url):
return saml_response.ava["uid"]
def saml_response_to_user_attributes(
self, saml_response, failures, client_redirect_url
):
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
return {"mxid_localpart": localpart, "displayname": None}
class SamlHandlerTestCase(HomeserverTestCase):
def default_config(self):
config = super().default_config()
config["public_baseurl"] = BASE_URL
saml_config = {
"sp_config": {"metadata": {}},
# Disable grandfathering.
"grandfathered_mxid_source_attribute": None,
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# Update this config with what's in the default config so that
# override_config works as expected.
saml_config.update(config.get("saml2_config", {}))
config["saml2_config"] = saml_config
return config
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
self.handler = hs.get_saml_handler()
# Reduce the number of attempts when generating MXIDs.
sso_handler = hs.get_sso_handler()
sso_handler._MAP_USERNAME_RETRIES = 3
return hs
def test_map_saml_response_to_user(self):
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
# The redirect_url doesn't matter with the default user mapping provider.
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self):
"""Existing users can log in with SAML account."""
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
# Map a user via SSO.
saml_response = FakeAuthnResponse(
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
)
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid.
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
def test_map_saml_response_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
redirect_url = ""
e = self.get_failure(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertEqual(str(e.value), "localpart is invalid: föö")
def test_map_saml_response_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
redirect_url = ""
mxid = self.get_success(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
)
# test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test")
# Register all of the potential mxids for a particular SAML username.
self.get_success(
store.register_user(user_id="@tester:test", password_hash=None)
)
for i in range(1, 3):
self.get_success(
store.register_user(user_id="@tester%d:test" % i, password_hash=None)
)
# Now attempt to map to a username, this will fail since all potential usernames are taken.
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
e = self.get_failure(
self.handler._map_saml_response_to_user(
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertEqual(
str(e.value), "Unable to generate a Matrix ID from the SSO response"
)

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet.defer import Deferred
@ -20,8 +19,9 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import receipts
from tests.unittest import HomeserverTestCase
from tests.unittest import HomeserverTestCase, override_config
class HTTPPusherTests(HomeserverTestCase):
@ -29,6 +29,7 @@ class HTTPPusherTests(HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
receipts.register_servlets,
]
user_id = True
hijack_auth = False
@ -499,3 +500,161 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
def test_push_unread_count_group_by_room(self):
"""
The HTTP pusher will group unread count by number of unread rooms.
"""
# Carry out common push count tests and setup
self._test_push_unread_count()
# Carry out our option-value specific test
#
# This push should still only contain an unread count of 1 (for 1 unread room)
self.assertEqual(
self.push_attempts[5][2]["notification"]["counts"]["unread"], 1
)
@override_config({"push": {"group_unread_count_by_room": False}})
def test_push_unread_count_message_count(self):
"""
The HTTP pusher will send the total unread message count.
"""
# Carry out common push count tests and setup
self._test_push_unread_count()
# Carry out our option-value specific test
#
# We're counting every unread message, so there should now be 4 since the
# last read receipt
self.assertEqual(
self.push_attempts[5][2]["notification"]["counts"]["unread"], 4
)
def _test_push_unread_count(self):
"""
Tests that the correct unread count appears in sent push notifications
Note that:
* Sending messages will cause push notifications to go out to relevant users
* Sending a read receipt will cause a "badge update" notification to go out to
the user that sent the receipt
"""
# Register the user who gets notified
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
# Register the user who sends the message
other_user_id = self.register_user("other_user", "pass")
other_access_token = self.login("other_user", "pass")
# Create a room (as other_user)
room_id = self.helper.create_room_as(other_user_id, tok=other_access_token)
# The user to get notified joins
self.helper.join(room=room_id, user=user_id, tok=access_token)
# Register the pusher
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
app_id="m.http",
app_display_name="HTTP Push Notifications",
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "example.com"},
)
)
# Send a message
response = self.helper.send(
room_id, body="Hello there!", tok=other_access_token
)
# To get an unread count, the user who is getting notified has to have a read
# position in the room. We'll set the read position to this event in a moment
first_message_event_id = response["event_id"]
# Advance time a bit (so the pusher will register something has happened) and
# make the push succeed
self.push_attempts[0][0].callback({})
self.pump()
# Check our push made it
self.assertEqual(len(self.push_attempts), 1)
self.assertEqual(self.push_attempts[0][1], "example.com")
# Check that the unread count for the room is 0
#
# The unread count is zero as the user has no read receipt in the room yet
self.assertEqual(
self.push_attempts[0][2]["notification"]["counts"]["unread"], 0
)
# Now set the user's read receipt position to the first event
#
# This will actually trigger a new notification to be sent out so that
# even if the user does not receive another message, their unread
# count goes down
request, channel = self.make_request(
"POST",
"/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
{},
access_token=access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Advance time and make the push succeed
self.push_attempts[1][0].callback({})
self.pump()
# Unread count is still zero as we've read the only message in the room
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(
self.push_attempts[1][2]["notification"]["counts"]["unread"], 0
)
# Send another message
self.helper.send(
room_id, body="How's the weather today?", tok=other_access_token
)
# Advance time and make the push succeed
self.push_attempts[2][0].callback({})
self.pump()
# This push should contain an unread count of 1 as there's now been one
# message since our last read receipt
self.assertEqual(len(self.push_attempts), 3)
self.assertEqual(
self.push_attempts[2][2]["notification"]["counts"]["unread"], 1
)
# Since we're grouping by room, sending more messages shouldn't increase the
# unread count, as they're all being sent in the same room
self.helper.send(room_id, body="Hello?", tok=other_access_token)
# Advance time and make the push succeed
self.pump()
self.push_attempts[3][0].callback({})
self.helper.send(room_id, body="Hello??", tok=other_access_token)
# Advance time and make the push succeed
self.pump()
self.push_attempts[4][0].callback({})
self.helper.send(room_id, body="HELLO???", tok=other_access_token)
# Advance time and make the push succeed
self.pump()
self.push_attempts[5][0].callback({})
self.assertEqual(len(self.push_attempts), 6)

View File

@ -78,7 +78,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
self.test_handler = self._build_replication_data_handler()
self.worker_hs.replication_data_handler = self.test_handler
self.worker_hs._replication_data_handler = self.test_handler
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(

View File

@ -192,7 +192,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_device_handler()
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
self.clock = hs.clock
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")

View File

@ -33,12 +33,15 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
"red", http_client=None, federation_client=Mock()
)
presence_handler = Mock()
presence_handler.set_state.return_value = defer.succeed(None)
hs.presence_handler = Mock()
hs.presence_handler.set_state.return_value = defer.succeed(None)
hs = self.setup_test_homeserver(
"red",
http_client=None,
federation_client=Mock(),
presence_handler=presence_handler,
)
return hs
@ -55,7 +58,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200)
self.assertEqual(self.hs.presence_handler.set_state.call_count, 1)
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
def test_put_presence_disabled(self):
"""
@ -70,4 +73,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200)
self.assertEqual(self.hs.presence_handler.set_state.call_count, 0)
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)

View File

@ -41,14 +41,37 @@ class RestHelper:
auth_user_id = attr.ib()
def create_room_as(
self, room_creator=None, is_public=True, tok=None, expect_code=200,
):
self,
room_creator: str = None,
is_public: bool = True,
room_version: str = None,
tok: str = None,
expect_code: int = 200,
) -> str:
"""
Create a room.
Args:
room_creator: The user ID to create the room with.
is_public: If True, the `visibility` parameter will be set to the
default (public). Otherwise, the `visibility` parameter will be set
to "private".
room_version: The room version to create the room as. Defaults to Synapse's
default room version.
tok: The access token to use in the request.
expect_code: The expected HTTP response code.
Returns:
The ID of the newly created room.
"""
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
content = {}
if not is_public:
content["visibility"] = "private"
if room_version:
content["room_version"] = room_version
if tok:
path = path + "?access_token=%s" % tok

View File

@ -38,11 +38,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
return succeed(True)
class DummyPasswordChecker(UserInteractiveAuthChecker):
def check_auth(self, authdict, clientip):
return succeed(authdict["identifier"]["user"])
class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
@ -162,9 +157,6 @@ class UIAuthTests(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs)
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.user_tok = self.login("test", self.user_pass)
@ -234,6 +226,31 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
)
def test_grandfathered_identifier(self):
"""Check behaviour without "identifier" dict
Synapse used to require clients to submit a "user" field for m.login.password
UIA - check that still works.
"""
device_id = self.get_device_ids()[0]
channel = self.delete_device(device_id, 401)
session = channel.json_body["session"]
# Make another request providing the UI auth flow.
self.delete_device(
device_id,
200,
{
"auth": {
"type": "m.login.password",
"user": self.user,
"password": self.user_pass,
"session": session,
},
},
)
def test_can_change_body(self):
"""
The client dict can be modified during the user interactive authentication session.

View File

@ -569,7 +569,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do
# nothing.
now = self.hs.clock.time_msec()
now = self.hs.get_clock().time_msec()
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
@ -587,7 +587,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# We need to manually add an email address otherwise the handler will do
# nothing.
now = self.hs.clock.time_msec()
now = self.hs.get_clock().time_msec()
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
@ -646,7 +646,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
self.hs.config.account_validity.startup_job_max_delta = self.max_delta
now_ms = self.hs.clock.time_msec()
now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))

View File

@ -271,7 +271,7 @@ def setup_test_homeserver(
# Install @cache_in_self attributes
for key, val in kwargs.items():
setattr(hs, key, val)
setattr(hs, "_" + key, val)
# Mock TLS
hs.tls_server_context_factory = Mock()