Compare commits
4 Commits
4d9496559d
...
8388384a64
| Author | SHA1 | Date |
|---|---|---|
|
|
8388384a64 | |
|
|
c21bdc813f | |
|
|
d3ed93504b | |
|
|
edb3d3f827 |
|
|
@ -6,7 +6,7 @@
|
||||||
set -ex
|
set -ex
|
||||||
|
|
||||||
apt-get update
|
apt-get update
|
||||||
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
|
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
|
||||||
|
|
||||||
export LANG="C.UTF-8"
|
export LANG="C.UTF-8"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
Add additional error checking for OpenID Connect and SAML mapping providers.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
Refactor `password_auth_provider` support code.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
Allow for specifying a room version when creating a room in unit tests via `RestHelper.create_room_as`.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
Add support for re-trying generation of a localpart for OpenID Connect mapping providers.
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||||
# Copyright 2017 Vector Creations Ltd
|
# Copyright 2017 Vector Creations Ltd
|
||||||
|
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
|
@ -25,6 +26,7 @@ from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
|
|
@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
|
||||||
# better way to break the loop
|
# better way to break the loop
|
||||||
account_handler = ModuleApi(hs, self)
|
account_handler = ModuleApi(hs, self)
|
||||||
|
|
||||||
self.password_providers = []
|
self.password_providers = [
|
||||||
for module, config in hs.config.password_providers:
|
PasswordProvider.load(module, config, account_handler)
|
||||||
try:
|
for module, config in hs.config.password_providers
|
||||||
self.password_providers.append(
|
]
|
||||||
module(config=config, account_handler=account_handler)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error while initializing %r: %s", module, e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.info("Extra password_providers: %r", self.password_providers)
|
logger.info("Extra password_providers: %s", self.password_providers)
|
||||||
|
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
@ -853,6 +850,8 @@ class AuthHandler(BaseHandler):
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
login_type = login_submission.get("type")
|
login_type = login_submission.get("type")
|
||||||
|
if not isinstance(login_type, str):
|
||||||
|
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
|
||||||
|
|
||||||
# ideally, we wouldn't be checking the identifier unless we know we have a login
|
# ideally, we wouldn't be checking the identifier unless we know we have a login
|
||||||
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
|
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
|
||||||
|
|
@ -998,24 +997,12 @@ class AuthHandler(BaseHandler):
|
||||||
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
||||||
|
|
||||||
login_type = login_submission.get("type")
|
login_type = login_submission.get("type")
|
||||||
|
# we already checked that we have a valid login type
|
||||||
|
assert isinstance(login_type, str)
|
||||||
|
|
||||||
known_login_type = False
|
known_login_type = False
|
||||||
|
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
|
|
||||||
known_login_type = True
|
|
||||||
# we've already checked that there is a (valid) password field
|
|
||||||
is_valid = await provider.check_password(
|
|
||||||
qualified_user_id, login_submission["password"]
|
|
||||||
)
|
|
||||||
if is_valid:
|
|
||||||
return qualified_user_id, None
|
|
||||||
|
|
||||||
if not hasattr(provider, "get_supported_login_types") or not hasattr(
|
|
||||||
provider, "check_auth"
|
|
||||||
):
|
|
||||||
# this password provider doesn't understand custom login types
|
|
||||||
continue
|
|
||||||
|
|
||||||
supported_login_types = provider.get_supported_login_types()
|
supported_login_types = provider.get_supported_login_types()
|
||||||
if login_type not in supported_login_types:
|
if login_type not in supported_login_types:
|
||||||
# this password provider doesn't understand this login type
|
# this password provider doesn't understand this login type
|
||||||
|
|
@ -1040,8 +1027,6 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
result = await provider.check_auth(username, login_type, login_dict)
|
result = await provider.check_auth(username, login_type, login_dict)
|
||||||
if result:
|
if result:
|
||||||
if isinstance(result, str):
|
|
||||||
result = (result, None)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
||||||
|
|
@ -1083,19 +1068,9 @@ class AuthHandler(BaseHandler):
|
||||||
unsuccessful, `user_id` and `callback` are both `None`.
|
unsuccessful, `user_id` and `callback` are both `None`.
|
||||||
"""
|
"""
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "check_3pid_auth"):
|
result = await provider.check_3pid_auth(medium, address, password)
|
||||||
# This function is able to return a deferred that either
|
if result:
|
||||||
# resolves None, meaning authentication failure, or upon
|
return result
|
||||||
# success, to a str (which is the user_id) or a tuple of
|
|
||||||
# (user_id, callback_func), where callback_func should be run
|
|
||||||
# after we've finished everything else
|
|
||||||
result = await provider.check_3pid_auth(medium, address, password)
|
|
||||||
if result:
|
|
||||||
# Check if the return value is a str or a tuple
|
|
||||||
if isinstance(result, str):
|
|
||||||
# If it's a str, set callback function to None
|
|
||||||
result = (result, None)
|
|
||||||
return result
|
|
||||||
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
@ -1153,16 +1128,11 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# see if any of our auth providers want to know about this
|
# see if any of our auth providers want to know about this
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "on_logged_out"):
|
await provider.on_logged_out(
|
||||||
# This might return an awaitable, if it does block the log out
|
user_id=user_info.user_id,
|
||||||
# until it completes.
|
device_id=user_info.device_id,
|
||||||
result = provider.on_logged_out(
|
access_token=access_token,
|
||||||
user_id=user_info.user_id,
|
)
|
||||||
device_id=user_info.device_id,
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
|
||||||
if inspect.isawaitable(result):
|
|
||||||
await result
|
|
||||||
|
|
||||||
# delete pushers associated with this access token
|
# delete pushers associated with this access token
|
||||||
if user_info.token_id is not None:
|
if user_info.token_id is not None:
|
||||||
|
|
@ -1191,11 +1161,10 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# see if any of our auth providers want to know about this
|
# see if any of our auth providers want to know about this
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "on_logged_out"):
|
for token, token_id, device_id in tokens_and_devices:
|
||||||
for token, token_id, device_id in tokens_and_devices:
|
await provider.on_logged_out(
|
||||||
await provider.on_logged_out(
|
user_id=user_id, device_id=device_id, access_token=token
|
||||||
user_id=user_id, device_id=device_id, access_token=token
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# delete pushers associated with the access tokens
|
# delete pushers associated with the access tokens
|
||||||
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||||
|
|
@ -1519,3 +1488,127 @@ class MacaroonGenerator:
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
return macaroon
|
return macaroon
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordProvider:
|
||||||
|
"""Wrapper for a password auth provider module
|
||||||
|
|
||||||
|
This class abstracts out all of the backwards-compatibility hacks for
|
||||||
|
password providers, to provide a consistent interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
|
||||||
|
try:
|
||||||
|
pp = module(config=config, account_handler=module_api)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error while initializing %r: %s", module, e)
|
||||||
|
raise
|
||||||
|
return cls(pp, module_api)
|
||||||
|
|
||||||
|
def __init__(self, pp, module_api: ModuleApi):
|
||||||
|
self._pp = pp
|
||||||
|
self._module_api = module_api
|
||||||
|
|
||||||
|
self._supported_login_types = {}
|
||||||
|
|
||||||
|
# grandfather in check_password support
|
||||||
|
if hasattr(self._pp, "check_password"):
|
||||||
|
self._supported_login_types[LoginType.PASSWORD] = ("password",)
|
||||||
|
|
||||||
|
g = getattr(self._pp, "get_supported_login_types", None)
|
||||||
|
if g:
|
||||||
|
self._supported_login_types.update(g())
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return str(self._pp)
|
||||||
|
|
||||||
|
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
||||||
|
"""Get the login types supported by this password provider
|
||||||
|
|
||||||
|
Returns a map from a login type identifier (such as m.login.password) to an
|
||||||
|
iterable giving the fields which must be provided by the user in the submission
|
||||||
|
to the /login API.
|
||||||
|
|
||||||
|
This wrapper adds m.login.password to the list if the underlying password
|
||||||
|
provider supports the check_password() api.
|
||||||
|
"""
|
||||||
|
return self._supported_login_types
|
||||||
|
|
||||||
|
async def check_auth(
|
||||||
|
self, username: str, login_type: str, login_dict: JsonDict
|
||||||
|
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||||
|
"""Check if the user has presented valid login credentials
|
||||||
|
|
||||||
|
This wrapper also calls check_password() if the underlying password provider
|
||||||
|
supports the check_password() api and the login type is m.login.password.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: user id presented by the client. Either an MXID or an unqualified
|
||||||
|
username.
|
||||||
|
|
||||||
|
login_type: the login type being attempted - one of the types returned by
|
||||||
|
get_supported_login_types()
|
||||||
|
|
||||||
|
login_dict: the dictionary of login secrets passed by the client.
|
||||||
|
|
||||||
|
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
|
||||||
|
user, and `callback` is an optional callback which will be called with the
|
||||||
|
result from the /login call (including access_token, device_id, etc.)
|
||||||
|
"""
|
||||||
|
# first grandfather in a call to check_password
|
||||||
|
if login_type == LoginType.PASSWORD:
|
||||||
|
g = getattr(self._pp, "check_password", None)
|
||||||
|
if g:
|
||||||
|
qualified_user_id = self._module_api.get_qualified_user_id(username)
|
||||||
|
is_valid = await self._pp.check_password(
|
||||||
|
qualified_user_id, login_dict["password"]
|
||||||
|
)
|
||||||
|
if is_valid:
|
||||||
|
return qualified_user_id, None
|
||||||
|
|
||||||
|
g = getattr(self._pp, "check_auth", None)
|
||||||
|
if not g:
|
||||||
|
return None
|
||||||
|
result = await g(username, login_type, login_dict)
|
||||||
|
|
||||||
|
# Check if the return value is a str or a tuple
|
||||||
|
if isinstance(result, str):
|
||||||
|
# If it's a str, set callback function to None
|
||||||
|
return result, None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def check_3pid_auth(
|
||||||
|
self, medium: str, address: str, password: str
|
||||||
|
) -> Optional[Tuple[str, Optional[Callable]]]:
|
||||||
|
g = getattr(self._pp, "check_3pid_auth", None)
|
||||||
|
if not g:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# This function is able to return a deferred that either
|
||||||
|
# resolves None, meaning authentication failure, or upon
|
||||||
|
# success, to a str (which is the user_id) or a tuple of
|
||||||
|
# (user_id, callback_func), where callback_func should be run
|
||||||
|
# after we've finished everything else
|
||||||
|
result = await g(medium, address, password)
|
||||||
|
|
||||||
|
# Check if the return value is a str or a tuple
|
||||||
|
if isinstance(result, str):
|
||||||
|
# If it's a str, set callback function to None
|
||||||
|
return result, None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def on_logged_out(
|
||||||
|
self, user_id: str, device_id: Optional[str], access_token: str
|
||||||
|
) -> None:
|
||||||
|
g = getattr(self._pp, "on_logged_out", None)
|
||||||
|
if not g:
|
||||||
|
return
|
||||||
|
|
||||||
|
# This might return an awaitable, if it does block the log out
|
||||||
|
# until it completes.
|
||||||
|
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
await result
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ from synapse.handlers._base import BaseHandler
|
||||||
from synapse.handlers.sso import MappingException, UserAttributes
|
from synapse.handlers.sso import MappingException, UserAttributes
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.types import JsonDict, map_username_to_mxid_localpart
|
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -898,13 +898,39 @@ class OidcHandler(BaseHandler):
|
||||||
|
|
||||||
return UserAttributes(**attributes)
|
return UserAttributes(**attributes)
|
||||||
|
|
||||||
|
async def grandfather_existing_users() -> Optional[str]:
|
||||||
|
if self._allow_existing_users:
|
||||||
|
# If allowing existing users we want to generate a single localpart
|
||||||
|
# and attempt to match it.
|
||||||
|
attributes = await oidc_response_to_user_attributes(failures=0)
|
||||||
|
|
||||||
|
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
||||||
|
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
|
if users:
|
||||||
|
# If an existing matrix ID is returned, then use it.
|
||||||
|
if len(users) == 1:
|
||||||
|
previously_registered_user_id = next(iter(users))
|
||||||
|
elif user_id in users:
|
||||||
|
previously_registered_user_id = user_id
|
||||||
|
else:
|
||||||
|
# Do not attempt to continue generating Matrix IDs.
|
||||||
|
raise MappingException(
|
||||||
|
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
|
||||||
|
user_id, users
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return previously_registered_user_id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
return await self._sso_handler.get_mxid_from_sso(
|
return await self._sso_handler.get_mxid_from_sso(
|
||||||
self._auth_provider_id,
|
self._auth_provider_id,
|
||||||
remote_user_id,
|
remote_user_id,
|
||||||
user_agent,
|
user_agent,
|
||||||
ip_address,
|
ip_address,
|
||||||
oidc_response_to_user_attributes,
|
oidc_response_to_user_attributes,
|
||||||
self._allow_existing_users,
|
grandfather_existing_users,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -265,10 +265,10 @@ class SamlHandler(BaseHandler):
|
||||||
return UserAttributes(
|
return UserAttributes(
|
||||||
localpart=result.get("mxid_localpart"),
|
localpart=result.get("mxid_localpart"),
|
||||||
display_name=result.get("displayname"),
|
display_name=result.get("displayname"),
|
||||||
emails=result.get("emails"),
|
emails=result.get("emails", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
async def grandfather_existing_users() -> Optional[str]:
|
||||||
# backwards-compatibility hack: see if there is an existing user with a
|
# backwards-compatibility hack: see if there is an existing user with a
|
||||||
# suitable mapping from the uid
|
# suitable mapping from the uid
|
||||||
if (
|
if (
|
||||||
|
|
@ -290,17 +290,18 @@ class SamlHandler(BaseHandler):
|
||||||
if users:
|
if users:
|
||||||
registered_user_id = list(users.keys())[0]
|
registered_user_id = list(users.keys())[0]
|
||||||
logger.info("Grandfathering mapping to %s", registered_user_id)
|
logger.info("Grandfathering mapping to %s", registered_user_id)
|
||||||
await self.store.record_user_external_id(
|
|
||||||
self._auth_provider_id, remote_user_id, registered_user_id
|
|
||||||
)
|
|
||||||
return registered_user_id
|
return registered_user_id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
||||||
return await self._sso_handler.get_mxid_from_sso(
|
return await self._sso_handler.get_mxid_from_sso(
|
||||||
self._auth_provider_id,
|
self._auth_provider_id,
|
||||||
remote_user_id,
|
remote_user_id,
|
||||||
user_agent,
|
user_agent,
|
||||||
ip_address,
|
ip_address,
|
||||||
saml_response_to_remapped_user_attributes,
|
saml_response_to_remapped_user_attributes,
|
||||||
|
grandfather_existing_users,
|
||||||
)
|
)
|
||||||
|
|
||||||
def expire_sessions(self):
|
def expire_sessions(self):
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,7 @@ class SsoHandler(BaseHandler):
|
||||||
user_agent: str,
|
user_agent: str,
|
||||||
ip_address: str,
|
ip_address: str,
|
||||||
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
||||||
allow_existing_users: bool = False,
|
grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Given an SSO ID, retrieve the user ID for it and possibly register the user.
|
Given an SSO ID, retrieve the user ID for it and possibly register the user.
|
||||||
|
|
@ -125,6 +125,10 @@ class SsoHandler(BaseHandler):
|
||||||
if it has that matrix ID is returned regardless of the current mapping
|
if it has that matrix ID is returned regardless of the current mapping
|
||||||
logic.
|
logic.
|
||||||
|
|
||||||
|
If a callable is provided for grandfathering users, it is called and can
|
||||||
|
potentially return a matrix ID to use. If it does, the SSO ID is linked to
|
||||||
|
this matrix ID for subsequent calls.
|
||||||
|
|
||||||
The mapping function is called (potentially multiple times) to generate
|
The mapping function is called (potentially multiple times) to generate
|
||||||
a localpart for the user.
|
a localpart for the user.
|
||||||
|
|
||||||
|
|
@ -132,17 +136,6 @@ class SsoHandler(BaseHandler):
|
||||||
given user-agent and IP address and the SSO ID is linked to this matrix
|
given user-agent and IP address and the SSO ID is linked to this matrix
|
||||||
ID for subsequent calls.
|
ID for subsequent calls.
|
||||||
|
|
||||||
If allow_existing_users is true the mapping function is only called once
|
|
||||||
and results in:
|
|
||||||
|
|
||||||
1. The use of a previously registered matrix ID. In this case, the
|
|
||||||
SSO ID is linked to the matrix ID. (Note it is possible that
|
|
||||||
other SSO IDs are linked to the same matrix ID.)
|
|
||||||
2. An unused localpart, in which case the user is registered (as
|
|
||||||
discussed above).
|
|
||||||
3. An error if the generated localpart matches multiple pre-existing
|
|
||||||
matrix IDs. Generally this should not happen.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||||
"oidc" or "saml".
|
"oidc" or "saml".
|
||||||
|
|
@ -152,8 +145,9 @@ class SsoHandler(BaseHandler):
|
||||||
sso_to_matrix_id_mapper: A callable to generate the user attributes.
|
sso_to_matrix_id_mapper: A callable to generate the user attributes.
|
||||||
The only parameter is an integer which represents the amount of
|
The only parameter is an integer which represents the amount of
|
||||||
times the returned mxid localpart mapping has failed.
|
times the returned mxid localpart mapping has failed.
|
||||||
allow_existing_users: True if the localpart returned from the
|
grandfather_existing_users: A callable which can return an previously
|
||||||
mapping provider can be linked to an existing matrix ID.
|
existing matrix ID. The SSO ID is then linked to the returned
|
||||||
|
matrix ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The user ID associated with the SSO response.
|
The user ID associated with the SSO response.
|
||||||
|
|
@ -171,6 +165,16 @@ class SsoHandler(BaseHandler):
|
||||||
if previously_registered_user_id:
|
if previously_registered_user_id:
|
||||||
return previously_registered_user_id
|
return previously_registered_user_id
|
||||||
|
|
||||||
|
# Check for grandfathering of users.
|
||||||
|
if grandfather_existing_users:
|
||||||
|
previously_registered_user_id = await grandfather_existing_users()
|
||||||
|
if previously_registered_user_id:
|
||||||
|
# Future logins should also match this user ID.
|
||||||
|
await self.store.record_user_external_id(
|
||||||
|
auth_provider_id, remote_user_id, previously_registered_user_id
|
||||||
|
)
|
||||||
|
return previously_registered_user_id
|
||||||
|
|
||||||
# Otherwise, generate a new user.
|
# Otherwise, generate a new user.
|
||||||
for i in range(self._MAP_USERNAME_RETRIES):
|
for i in range(self._MAP_USERNAME_RETRIES):
|
||||||
try:
|
try:
|
||||||
|
|
@ -194,33 +198,7 @@ class SsoHandler(BaseHandler):
|
||||||
|
|
||||||
# Check if this mxid already exists
|
# Check if this mxid already exists
|
||||||
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
||||||
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
if not await self.store.get_users_by_id_case_insensitive(user_id):
|
||||||
# Note, if allow_existing_users is true then the loop is guaranteed
|
|
||||||
# to end on the first iteration: either by matching an existing user,
|
|
||||||
# raising an error, or registering a new user. See the docstring for
|
|
||||||
# more in-depth an explanation.
|
|
||||||
if users and allow_existing_users:
|
|
||||||
# If an existing matrix ID is returned, then use it.
|
|
||||||
if len(users) == 1:
|
|
||||||
previously_registered_user_id = next(iter(users))
|
|
||||||
elif user_id in users:
|
|
||||||
previously_registered_user_id = user_id
|
|
||||||
else:
|
|
||||||
# Do not attempt to continue generating Matrix IDs.
|
|
||||||
raise MappingException(
|
|
||||||
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
|
|
||||||
user_id, users
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Future logins should also match this user ID.
|
|
||||||
await self.store.record_user_external_id(
|
|
||||||
auth_provider_id, remote_user_id, previously_registered_user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return previously_registered_user_id
|
|
||||||
|
|
||||||
elif not users:
|
|
||||||
# This mxid is free
|
# This mxid is free
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ import pymacaroons
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.web._newclient import ResponseDone
|
from twisted.web._newclient import ResponseDone
|
||||||
|
|
||||||
from synapse.handlers.oidc_handler import OidcError, OidcHandler, OidcMappingProvider
|
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
|
||||||
from synapse.handlers.sso import MappingException
|
from synapse.handlers.sso import MappingException
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
|
@ -127,13 +127,8 @@ async def get_json(url):
|
||||||
|
|
||||||
|
|
||||||
class OidcHandlerTestCase(HomeserverTestCase):
|
class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def default_config(self):
|
||||||
|
config = super().default_config()
|
||||||
self.http_client = Mock(spec=["get_json"])
|
|
||||||
self.http_client.get_json.side_effect = get_json
|
|
||||||
self.http_client.user_agent = "Synapse Test"
|
|
||||||
|
|
||||||
config = self.default_config()
|
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = BASE_URL
|
||||||
oidc_config = {
|
oidc_config = {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
|
|
@ -149,19 +144,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
oidc_config.update(config.get("oidc_config", {}))
|
oidc_config.update(config.get("oidc_config", {}))
|
||||||
config["oidc_config"] = oidc_config
|
config["oidc_config"] = oidc_config
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
return config
|
||||||
http_client=self.http_client,
|
|
||||||
proxied_http_client=self.http_client,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.handler = OidcHandler(hs)
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
|
self.http_client = Mock(spec=["get_json"])
|
||||||
|
self.http_client.get_json.side_effect = get_json
|
||||||
|
self.http_client.user_agent = "Synapse Test"
|
||||||
|
|
||||||
|
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
|
||||||
|
|
||||||
|
self.handler = hs.get_oidc_handler()
|
||||||
|
sso_handler = hs.get_sso_handler()
|
||||||
# Mock the render error method.
|
# Mock the render error method.
|
||||||
self.render_error = Mock(return_value=None)
|
self.render_error = Mock(return_value=None)
|
||||||
self.handler._sso_handler.render_error = self.render_error
|
sso_handler.render_error = self.render_error
|
||||||
|
|
||||||
# Reduce the number of attempts when generating MXIDs.
|
# Reduce the number of attempts when generating MXIDs.
|
||||||
self.handler._sso_handler._MAP_USERNAME_RETRIES = 3
|
sso_handler._MAP_USERNAME_RETRIES = 3
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
|
@ -731,6 +731,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(mxid, "@test_user:test")
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
|
# Subsequent calls should map to the same mxid.
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_userinfo_to_user(
|
||||||
|
userinfo, token, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
# Note that a second SSO user can be mapped to the same Matrix ID. (This
|
# Note that a second SSO user can be mapped to the same Matrix ID. (This
|
||||||
# requires a unique sub, but something that maps to the same matrix ID,
|
# requires a unique sub, but something that maps to the same matrix ID,
|
||||||
# in this case we'll just use the same username. A more realistic example
|
# in this case we'll just use the same username. A more realistic example
|
||||||
|
|
@ -832,7 +840,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
# test_user is already taken, so test_user1 gets registered instead.
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
self.assertEqual(mxid, "@test_user1:test")
|
self.assertEqual(mxid, "@test_user1:test")
|
||||||
|
|
||||||
# Register all of the potential users for a particular username.
|
# Register all of the potential mxids for a particular OIDC username.
|
||||||
self.get_success(
|
self.get_success(
|
||||||
store.register_user(user_id="@tester:test", password_hash=None)
|
store.register_user(user_id="@tester:test", password_hash=None)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -266,8 +266,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
# first delete should give a 401
|
# first delete should give a 401
|
||||||
channel = self._delete_device(tok1, "dev2")
|
channel = self._delete_device(tok1, "dev2")
|
||||||
self.assertEqual(channel.code, 401)
|
self.assertEqual(channel.code, 401)
|
||||||
# there are no valid flows here!
|
# m.login.password UIA is permitted because the auth provider allows it,
|
||||||
self.assertEqual(channel.json_body["flows"], [])
|
# even though the localdb does not.
|
||||||
|
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
|
||||||
session = channel.json_body["session"]
|
session = channel.json_body["session"]
|
||||||
mock_password_provider.check_password.assert_not_called()
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,168 @@
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
|
from synapse.handlers.sso import MappingException
|
||||||
|
|
||||||
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
|
# These are a few constants that are used as config parameters in the tests.
|
||||||
|
BASE_URL = "https://synapse/"
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class FakeAuthnResponse:
|
||||||
|
ava = attr.ib(type=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMappingProvider:
|
||||||
|
def __init__(self, config, module):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_config(config):
|
||||||
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_saml_attributes(config):
|
||||||
|
return {"uid"}, {"displayName"}
|
||||||
|
|
||||||
|
def get_remote_user_id(self, saml_response, client_redirect_url):
|
||||||
|
return saml_response.ava["uid"]
|
||||||
|
|
||||||
|
def saml_response_to_user_attributes(
|
||||||
|
self, saml_response, failures, client_redirect_url
|
||||||
|
):
|
||||||
|
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
|
||||||
|
return {"mxid_localpart": localpart, "displayname": None}
|
||||||
|
|
||||||
|
|
||||||
|
class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
def default_config(self):
|
||||||
|
config = super().default_config()
|
||||||
|
config["public_baseurl"] = BASE_URL
|
||||||
|
saml_config = {
|
||||||
|
"sp_config": {"metadata": {}},
|
||||||
|
# Disable grandfathering.
|
||||||
|
"grandfathered_mxid_source_attribute": None,
|
||||||
|
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update this config with what's in the default config so that
|
||||||
|
# override_config works as expected.
|
||||||
|
saml_config.update(config.get("saml2_config", {}))
|
||||||
|
config["saml2_config"] = saml_config
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def make_homeserver(self, reactor, clock):
|
||||||
|
hs = self.setup_test_homeserver()
|
||||||
|
|
||||||
|
self.handler = hs.get_saml_handler()
|
||||||
|
|
||||||
|
# Reduce the number of attempts when generating MXIDs.
|
||||||
|
sso_handler = hs.get_sso_handler()
|
||||||
|
sso_handler._MAP_USERNAME_RETRIES = 3
|
||||||
|
|
||||||
|
return hs
|
||||||
|
|
||||||
|
def test_map_saml_response_to_user(self):
|
||||||
|
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
|
||||||
|
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
||||||
|
# The redirect_url doesn't matter with the default user mapping provider.
|
||||||
|
redirect_url = ""
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
|
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||||
|
def test_map_saml_response_to_existing_user(self):
|
||||||
|
"""Existing users can log in with SAML account."""
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
self.get_success(
|
||||||
|
store.register_user(user_id="@test_user:test", password_hash=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map a user via SSO.
|
||||||
|
saml_response = FakeAuthnResponse(
|
||||||
|
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
|
||||||
|
)
|
||||||
|
redirect_url = ""
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
|
# Subsequent calls should map to the same mxid.
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
|
def test_map_saml_response_to_invalid_localpart(self):
|
||||||
|
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||||
|
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
|
||||||
|
redirect_url = ""
|
||||||
|
e = self.get_failure(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
),
|
||||||
|
MappingException,
|
||||||
|
)
|
||||||
|
self.assertEqual(str(e.value), "localpart is invalid: föö")
|
||||||
|
|
||||||
|
def test_map_saml_response_to_user_retries(self):
|
||||||
|
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
self.get_success(
|
||||||
|
store.register_user(user_id="@test_user:test", password_hash=None)
|
||||||
|
)
|
||||||
|
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
|
||||||
|
redirect_url = ""
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
|
self.assertEqual(mxid, "@test_user1:test")
|
||||||
|
|
||||||
|
# Register all of the potential mxids for a particular SAML username.
|
||||||
|
self.get_success(
|
||||||
|
store.register_user(user_id="@tester:test", password_hash=None)
|
||||||
|
)
|
||||||
|
for i in range(1, 3):
|
||||||
|
self.get_success(
|
||||||
|
store.register_user(user_id="@tester%d:test" % i, password_hash=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now attempt to map to a username, this will fail since all potential usernames are taken.
|
||||||
|
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
|
||||||
|
e = self.get_failure(
|
||||||
|
self.handler._map_saml_response_to_user(
|
||||||
|
saml_response, redirect_url, "user-agent", "10.10.10.10"
|
||||||
|
),
|
||||||
|
MappingException,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
str(e.value), "Unable to generate a Matrix ID from the SSO response"
|
||||||
|
)
|
||||||
|
|
@ -41,14 +41,37 @@ class RestHelper:
|
||||||
auth_user_id = attr.ib()
|
auth_user_id = attr.ib()
|
||||||
|
|
||||||
def create_room_as(
|
def create_room_as(
|
||||||
self, room_creator=None, is_public=True, tok=None, expect_code=200,
|
self,
|
||||||
):
|
room_creator: str = None,
|
||||||
|
is_public: bool = True,
|
||||||
|
room_version: str = None,
|
||||||
|
tok: str = None,
|
||||||
|
expect_code: int = 200,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Create a room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_creator: The user ID to create the room with.
|
||||||
|
is_public: If True, the `visibility` parameter will be set to the
|
||||||
|
default (public). Otherwise, the `visibility` parameter will be set
|
||||||
|
to "private".
|
||||||
|
room_version: The room version to create the room as. Defaults to Synapse's
|
||||||
|
default room version.
|
||||||
|
tok: The access token to use in the request.
|
||||||
|
expect_code: The expected HTTP response code.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The ID of the newly created room.
|
||||||
|
"""
|
||||||
temp_id = self.auth_user_id
|
temp_id = self.auth_user_id
|
||||||
self.auth_user_id = room_creator
|
self.auth_user_id = room_creator
|
||||||
path = "/_matrix/client/r0/createRoom"
|
path = "/_matrix/client/r0/createRoom"
|
||||||
content = {}
|
content = {}
|
||||||
if not is_public:
|
if not is_public:
|
||||||
content["visibility"] = "private"
|
content["visibility"] = "private"
|
||||||
|
if room_version:
|
||||||
|
content["room_version"] = room_version
|
||||||
if tok:
|
if tok:
|
||||||
path = path + "?access_token=%s" % tok
|
path = path + "?access_token=%s" % tok
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue