Compare commits

...

4 Commits

Author SHA1 Message Date
Patrick Cloke 8388384a64
Fix a regression when grandfathering SAML users. (#8855)
This was broken in #8801 when abstracting code shared with OIDC.

After this change both SAML and OIDC have a concept of
grandfathering users, but with different implementations.
2020-12-02 07:45:42 -05:00
Patrick Cloke c21bdc813f
Add basic SAML tests for mapping users. (#8800) 2020-12-02 07:09:21 -05:00
Richard van der Hoff d3ed93504b
Create a `PasswordProvider` wrapper object (#8849)
The idea here is to abstract out all the conditional code which tests which
methods a given password provider has, to provide a consistent interface.
2020-12-02 10:38:50 +00:00
Andrew Morgan edb3d3f827
Allow specifying room version in 'RestHelper.create_room_as' and add typing (#8854)
This PR adds a `room_version` argument to the `RestHelper`'s `create_room_as` function for tests. I plan to use this for testing knocking, which currently uses an unstable room version.
2020-12-02 10:38:18 +00:00
13 changed files with 427 additions and 125 deletions

View File

@ -6,7 +6,7 @@
set -ex
apt-get update
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
export LANG="C.UTF-8"

1
changelog.d/8800.misc Normal file
View File

@ -0,0 +1 @@
Add additional error checking for OpenID Connect and SAML mapping providers.

1
changelog.d/8849.misc Normal file
View File

@ -0,0 +1 @@
Refactor `password_auth_provider` support code.

1
changelog.d/8854.misc Normal file
View File

@ -0,0 +1 @@
Allow for specifying a room version when creating a room in unit tests via `RestHelper.create_room_as`.

1
changelog.d/8855.feature Normal file
View File

@ -0,0 +1 @@
Add support for re-trying generation of a localpart for OpenID Connect mapping providers.

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -25,6 +26,7 @@ from typing import (
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
# better way to break the loop
account_handler = ModuleApi(hs, self)
self.password_providers = []
for module, config in hs.config.password_providers:
try:
self.password_providers.append(
module(config=config, account_handler=account_handler)
)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
self.password_providers = [
PasswordProvider.load(module, config, account_handler)
for module, config in hs.config.password_providers
]
logger.info("Extra password_providers: %r", self.password_providers)
logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
@ -853,6 +850,8 @@ class AuthHandler(BaseHandler):
LoginError if there was an authentication problem.
"""
login_type = login_submission.get("type")
if not isinstance(login_type, str):
raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
# ideally, we wouldn't be checking the identifier unless we know we have a login
# method which uses it (https://github.com/matrix-org/synapse/issues/8836)
@ -998,24 +997,12 @@ class AuthHandler(BaseHandler):
qualified_user_id = UserID(username, self.hs.hostname).to_string()
login_type = login_submission.get("type")
# we already checked that we have a valid login type
assert isinstance(login_type, str)
known_login_type = False
for provider in self.password_providers:
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
# we've already checked that there is a (valid) password field
is_valid = await provider.check_password(
qualified_user_id, login_submission["password"]
)
if is_valid:
return qualified_user_id, None
if not hasattr(provider, "get_supported_login_types") or not hasattr(
provider, "check_auth"
):
# this password provider doesn't understand custom login types
continue
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
@ -1040,8 +1027,6 @@ class AuthHandler(BaseHandler):
result = await provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
@ -1083,19 +1068,9 @@ class AuthHandler(BaseHandler):
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
if hasattr(provider, "check_3pid_auth"):
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
result = (result, None)
return result
result = await provider.check_3pid_auth(medium, address, password)
if result:
return result
return None, None
@ -1153,16 +1128,11 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
if inspect.isawaitable(result):
await result
await provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
# delete pushers associated with this access token
if user_info.token_id is not None:
@ -1191,11 +1161,10 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
for token, token_id, device_id in tokens_and_devices:
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@ -1519,3 +1488,127 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
class PasswordProvider:
"""Wrapper for a password auth provider module
This class abstracts out all of the backwards-compatibility hacks for
password providers, to provide a consistent interface.
"""
@classmethod
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
return cls(pp, module_api)
def __init__(self, pp, module_api: ModuleApi):
self._pp = pp
self._module_api = module_api
self._supported_login_types = {}
# grandfather in check_password support
if hasattr(self._pp, "check_password"):
self._supported_login_types[LoginType.PASSWORD] = ("password",)
g = getattr(self._pp, "get_supported_login_types", None)
if g:
self._supported_login_types.update(g())
def __str__(self):
return str(self._pp)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
Returns a map from a login type identifier (such as m.login.password) to an
iterable giving the fields which must be provided by the user in the submission
to the /login API.
This wrapper adds m.login.password to the list if the underlying password
provider supports the check_password() api.
"""
return self._supported_login_types
async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
"""Check if the user has presented valid login credentials
This wrapper also calls check_password() if the underlying password provider
supports the check_password() api and the login type is m.login.password.
Args:
username: user id presented by the client. Either an MXID or an unqualified
username.
login_type: the login type being attempted - one of the types returned by
get_supported_login_types()
login_dict: the dictionary of login secrets passed by the client.
Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
user, and `callback` is an optional callback which will be called with the
result from the /login call (including access_token, device_id, etc.)
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
g = getattr(self._pp, "check_password", None)
if g:
qualified_user_id = self._module_api.get_qualified_user_id(username)
is_valid = await self._pp.check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None
g = getattr(self._pp, "check_auth", None)
if not g:
return None
result = await g(username, login_type, login_dict)
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
return result
async def check_3pid_auth(
self, medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
g = getattr(self._pp, "check_3pid_auth", None)
if not g:
return None
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await g(medium, address, password)
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
return result
async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
g = getattr(self._pp, "on_logged_out", None)
if not g:
return
# This might return an awaitable, if it does block the log out
# until it completes.
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
if inspect.isawaitable(result):
await result

View File

@ -39,7 +39,7 @@ from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, map_username_to_mxid_localpart
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@ -898,13 +898,39 @@ class OidcHandler(BaseHandler):
return UserAttributes(**attributes)
async def grandfather_existing_users() -> Optional[str]:
if self._allow_existing_users:
# If allowing existing users we want to generate a single localpart
# and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0)
user_id = UserID(attributes.localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
previously_registered_user_id = next(iter(users))
elif user_id in users:
previously_registered_user_id = user_id
else:
# Do not attempt to continue generating Matrix IDs.
raise MappingException(
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
user_id, users
)
)
return previously_registered_user_id
return None
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
oidc_response_to_user_attributes,
self._allow_existing_users,
grandfather_existing_users,
)

View File

@ -265,10 +265,10 @@ class SamlHandler(BaseHandler):
return UserAttributes(
localpart=result.get("mxid_localpart"),
display_name=result.get("displayname"),
emails=result.get("emails"),
emails=result.get("emails", []),
)
with (await self._mapping_lock.queue(self._auth_provider_id)):
async def grandfather_existing_users() -> Optional[str]:
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@ -290,17 +290,18 @@ class SamlHandler(BaseHandler):
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
return None
with (await self._mapping_lock.queue(self._auth_provider_id)):
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
saml_response_to_remapped_user_attributes,
grandfather_existing_users,
)
def expire_sessions(self):

View File

@ -116,7 +116,7 @@ class SsoHandler(BaseHandler):
user_agent: str,
ip_address: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
allow_existing_users: bool = False,
grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
) -> str:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@ -125,6 +125,10 @@ class SsoHandler(BaseHandler):
if it has that matrix ID is returned regardless of the current mapping
logic.
If a callable is provided for grandfathering users, it is called and can
potentially return a matrix ID to use. If it does, the SSO ID is linked to
this matrix ID for subsequent calls.
The mapping function is called (potentially multiple times) to generate
a localpart for the user.
@ -132,17 +136,6 @@ class SsoHandler(BaseHandler):
given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls.
If allow_existing_users is true the mapping function is only called once
and results in:
1. The use of a previously registered matrix ID. In this case, the
SSO ID is linked to the matrix ID. (Note it is possible that
other SSO IDs are linked to the same matrix ID.)
2. An unused localpart, in which case the user is registered (as
discussed above).
3. An error if the generated localpart matches multiple pre-existing
matrix IDs. Generally this should not happen.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
@ -152,8 +145,9 @@ class SsoHandler(BaseHandler):
sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the amount of
times the returned mxid localpart mapping has failed.
allow_existing_users: True if the localpart returned from the
mapping provider can be linked to an existing matrix ID.
grandfather_existing_users: A callable which can return an previously
existing matrix ID. The SSO ID is then linked to the returned
matrix ID.
Returns:
The user ID associated with the SSO response.
@ -171,6 +165,16 @@ class SsoHandler(BaseHandler):
if previously_registered_user_id:
return previously_registered_user_id
# Check for grandfathering of users.
if grandfather_existing_users:
previously_registered_user_id = await grandfather_existing_users()
if previously_registered_user_id:
# Future logins should also match this user ID.
await self.store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id
# Otherwise, generate a new user.
for i in range(self._MAP_USERNAME_RETRIES):
try:
@ -194,33 +198,7 @@ class SsoHandler(BaseHandler):
# Check if this mxid already exists
user_id = UserID(attributes.localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
# Note, if allow_existing_users is true then the loop is guaranteed
# to end on the first iteration: either by matching an existing user,
# raising an error, or registering a new user. See the docstring for
# more in-depth an explanation.
if users and allow_existing_users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
previously_registered_user_id = next(iter(users))
elif user_id in users:
previously_registered_user_id = user_id
else:
# Do not attempt to continue generating Matrix IDs.
raise MappingException(
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
user_id, users
)
)
# Future logins should also match this user ID.
await self.store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id
elif not users:
if not await self.store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:

View File

@ -23,7 +23,7 @@ import pymacaroons
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from synapse.handlers.oidc_handler import OidcError, OidcHandler, OidcMappingProvider
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
@ -127,13 +127,8 @@ async def get_json(url):
class OidcHandlerTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = "Synapse Test"
config = self.default_config()
def default_config(self):
config = super().default_config()
config["public_baseurl"] = BASE_URL
oidc_config = {
"enabled": True,
@ -149,19 +144,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
hs = self.setup_test_homeserver(
http_client=self.http_client,
proxied_http_client=self.http_client,
config=config,
)
return config
self.handler = OidcHandler(hs)
def make_homeserver(self, reactor, clock):
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = "Synapse Test"
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
self.handler = hs.get_oidc_handler()
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
self.handler._sso_handler.render_error = self.render_error
sso_handler.render_error = self.render_error
# Reduce the number of attempts when generating MXIDs.
self.handler._sso_handler._MAP_USERNAME_RETRIES = 3
sso_handler._MAP_USERNAME_RETRIES = 3
return hs
@ -731,6 +731,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid.
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
# in this case we'll just use the same username. A more realistic example
@ -832,7 +840,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test")
# Register all of the potential users for a particular username.
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
store.register_user(user_id="@tester:test", password_hash=None)
)

View File

@ -266,8 +266,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# first delete should give a 401
channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401)
# there are no valid flows here!
self.assertEqual(channel.json_body["flows"], [])
# m.login.password UIA is permitted because the auth provider allows it,
# even though the localdb does not.
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
session = channel.json_body["session"]
mock_password_provider.check_password.assert_not_called()

168
tests/handlers/test_saml.py Normal file
View File

@ -0,0 +1,168 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import attr
from synapse.handlers.sso import MappingException
from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests.
BASE_URL = "https://synapse/"
@attr.s
class FakeAuthnResponse:
ava = attr.ib(type=dict)
class TestMappingProvider:
def __init__(self, config, module):
pass
@staticmethod
def parse_config(config):
return