Push login completion down into SsoHandler (#8941)

This is another part of my work towards fixing #8876. It moves some of the logic currently in the SAML and OIDC handlers - in particular the call to `AuthHandler.complete_sso_login` down into the `SsoHandler`.
pull/8964/head
Richard van der Hoff 2020-12-16 20:01:53 +00:00 committed by GitHub
parent 44b7d4c6d6
commit e1b8e37f93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 86 deletions

1
changelog.d/8941.feature Normal file
View File

@ -0,0 +1 @@
Add support for allowing users to pick their own user ID during a single-sign-on login.

View File

@ -115,8 +115,6 @@ class OidcHandler(BaseHandler):
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._server_name = hs.config.server_name # type: str self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key self._macaroon_secret_key = hs.config.macaroon_secret_key
@ -689,33 +687,14 @@ class OidcHandler(BaseHandler):
# otherwise, it's a login # otherwise, it's a login
# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)
# Call the mapper to register/login the user # Call the mapper to register/login the user
try: try:
user_id = await self._map_userinfo_to_user( await self._complete_oidc_login(
userinfo, token, user_agent, ip_address userinfo, token, request, client_redirect_url
) )
except MappingException as e: except MappingException as e:
logger.exception("Could not map user") logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e)) self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Mapping providers might not have get_extra_attributes: only call this
# method if it exists.
extra_attributes = None
get_extra_attributes = getattr(
self._user_mapping_provider, "get_extra_attributes", None
)
if get_extra_attributes:
extra_attributes = await get_extra_attributes(userinfo, token)
# and finally complete the login
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url, extra_attributes
)
def _generate_oidc_session_token( def _generate_oidc_session_token(
self, self,
@ -838,10 +817,14 @@ class OidcHandler(BaseHandler):
now = self.clock.time_msec() now = self.clock.time_msec()
return now < expiry return now < expiry
async def _map_userinfo_to_user( async def _complete_oidc_login(
self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str self,
) -> str: userinfo: UserInfo,
"""Maps a UserInfo object to a mxid. token: Token,
request: SynapseRequest,
client_redirect_url: str,
) -> None:
"""Given a UserInfo response, complete the login flow
UserInfo should have a claim that uniquely identifies users. This claim UserInfo should have a claim that uniquely identifies users. This claim
is usually `sub`, but can be configured with `oidc_config.subject_claim`. is usually `sub`, but can be configured with `oidc_config.subject_claim`.
@ -853,17 +836,16 @@ class OidcHandler(BaseHandler):
If a user already exists with the mxid we've mapped and allow_existing_users If a user already exists with the mxid we've mapped and allow_existing_users
is disabled, raise an exception. is disabled, raise an exception.
Otherwise, render a redirect back to the client_redirect_url with a loginToken.
Args: Args:
userinfo: an object representing the user userinfo: an object representing the user
token: a dict with the tokens obtained from the provider token: a dict with the tokens obtained from the provider
user_agent: The user agent of the client making the request. request: The request to respond to
ip_address: The IP address of the client making the request. client_redirect_url: The redirect URL passed in by the client.
Raises: Raises:
MappingException: if there was an error while mapping some properties MappingException: if there was an error while mapping some properties
Returns:
The mxid of the user
""" """
try: try:
remote_user_id = self._remote_id_from_userinfo(userinfo) remote_user_id = self._remote_id_from_userinfo(userinfo)
@ -931,13 +913,23 @@ class OidcHandler(BaseHandler):
return None return None
return await self._sso_handler.get_mxid_from_sso( # Mapping providers might not have get_extra_attributes: only call this
# method if it exists.
extra_attributes = None
get_extra_attributes = getattr(
self._user_mapping_provider, "get_extra_attributes", None
)
if get_extra_attributes:
extra_attributes = await get_extra_attributes(userinfo, token)
await self._sso_handler.complete_sso_login_request(
self._auth_provider_id, self._auth_provider_id,
remote_user_id, remote_user_id,
user_agent, request,
ip_address, client_redirect_url,
oidc_response_to_user_attributes, oidc_response_to_user_attributes,
grandfather_existing_users, grandfather_existing_users,
extra_attributes,
) )
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:

View File

@ -58,8 +58,6 @@ class SamlHandler(BaseHandler):
super().__init__(hs) super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid self._saml_idp_entityid = hs.config.saml2_idp_entityid
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._saml2_session_lifetime = hs.config.saml2_session_lifetime self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = ( self._grandfathered_mxid_source_attribute = (
@ -229,40 +227,29 @@ class SamlHandler(BaseHandler):
) )
return return
# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)
# Call the mapper to register/login the user # Call the mapper to register/login the user
try: try:
user_id = await self._map_saml_response_to_user( await self._complete_saml_login(saml2_auth, request, relay_state)
saml2_auth, relay_state, user_agent, ip_address
)
except MappingException as e: except MappingException as e:
logger.exception("Could not map user") logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e)) self._sso_handler.render_error(request, "mapping_error", str(e))
return
await self._auth_handler.complete_sso_login(user_id, request, relay_state) async def _complete_saml_login(
async def _map_saml_response_to_user(
self, self,
saml2_auth: saml2.response.AuthnResponse, saml2_auth: saml2.response.AuthnResponse,
request: SynapseRequest,
client_redirect_url: str, client_redirect_url: str,
user_agent: str, ) -> None:
ip_address: str,
) -> str:
""" """
Given a SAML response, retrieve the user ID for it and possibly register the user. Given a SAML response, complete the login flow
Retrieves the remote user ID, registers the user if necessary, and serves
a redirect back to the client with a login-token.
Args: Args:
saml2_auth: The parsed SAML2 response. saml2_auth: The parsed SAML2 response.
request: The request to respond to
client_redirect_url: The redirect URL passed in by the client. client_redirect_url: The redirect URL passed in by the client.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
The user ID associated with this response.
Raises: Raises:
MappingException if there was a problem mapping the response to a user. MappingException if there was a problem mapping the response to a user.
@ -318,11 +305,11 @@ class SamlHandler(BaseHandler):
return None return None
return await self._sso_handler.get_mxid_from_sso( await self._sso_handler.complete_sso_login_request(
self._auth_provider_id, self._auth_provider_id,
remote_user_id, remote_user_id,
user_agent, request,
ip_address, client_redirect_url,
saml_response_to_remapped_user_attributes, saml_response_to_remapped_user_attributes,
grandfather_existing_users, grandfather_existing_users,
) )

View File

@ -21,7 +21,8 @@ from twisted.web.http import Request
from synapse.api.errors import RedirectException from synapse.api.errors import RedirectException
from synapse.http.server import respond_with_html from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING: if TYPE_CHECKING:
@ -119,15 +120,16 @@ class SsoHandler:
# No match. # No match.
return None return None
async def get_mxid_from_sso( async def complete_sso_login_request(
self, self,
auth_provider_id: str, auth_provider_id: str,
remote_user_id: str, remote_user_id: str,
user_agent: str, request: SynapseRequest,
ip_address: str, client_redirect_url: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]], grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
) -> str: extra_login_attributes: Optional[JsonDict] = None,
) -> None:
""" """
Given an SSO ID, retrieve the user ID for it and possibly register the user. Given an SSO ID, retrieve the user ID for it and possibly register the user.
@ -146,12 +148,18 @@ class SsoHandler:
given user-agent and IP address and the SSO ID is linked to this matrix given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls. ID for subsequent calls.
Finally, we generate a redirect to the supplied redirect uri, with a login token
Args: Args:
auth_provider_id: A unique identifier for this SSO provider, e.g. auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml". "oidc" or "saml".
remote_user_id: The unique identifier from the SSO provider. remote_user_id: The unique identifier from the SSO provider.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request. request: The request to respond to
client_redirect_url: The redirect URL passed in by the client.
sso_to_matrix_id_mapper: A callable to generate the user attributes. sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the amount of The only parameter is an integer which represents the amount of
times the returned mxid localpart mapping has failed. times the returned mxid localpart mapping has failed.
@ -163,12 +171,13 @@ class SsoHandler:
to the user. to the user.
RedirectException to redirect to an additional page (e.g. RedirectException to redirect to an additional page (e.g.
to prompt the user for more information). to prompt the user for more information).
grandfather_existing_users: A callable which can return an previously grandfather_existing_users: A callable which can return an previously
existing matrix ID. The SSO ID is then linked to the returned existing matrix ID. The SSO ID is then linked to the returned
matrix ID. matrix ID.
Returns: extra_login_attributes: An optional dictionary of extra
The user ID associated with the SSO response. attributes to be provided to the client in the login response.
Raises: Raises:
MappingException if there was a problem mapping the response to a user. MappingException if there was a problem mapping the response to a user.
@ -181,28 +190,33 @@ class SsoHandler:
# interstitial pages. # interstitial pages.
with await self._mapping_lock.queue(auth_provider_id): with await self._mapping_lock.queue(auth_provider_id):
# first of all, check if we already have a mapping for this user # first of all, check if we already have a mapping for this user
previously_registered_user_id = await self.get_sso_user_by_remote_user_id( user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id, auth_provider_id, remote_user_id,
) )
if previously_registered_user_id:
return previously_registered_user_id
# Check for grandfathering of users. # Check for grandfathering of users.
if grandfather_existing_users: if not user_id and grandfather_existing_users:
previously_registered_user_id = await grandfather_existing_users() user_id = await grandfather_existing_users()
if previously_registered_user_id: if user_id:
# Future logins should also match this user ID. # Future logins should also match this user ID.
await self._store.record_user_external_id( await self._store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id auth_provider_id, remote_user_id, user_id
) )
return previously_registered_user_id
# Otherwise, generate a new user. # Otherwise, generate a new user.
if not user_id:
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper) attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
user_id = await self._register_mapped_user( user_id = await self._register_mapped_user(
attributes, auth_provider_id, remote_user_id, user_agent, ip_address, attributes,
auth_provider_id,
remote_user_id,
request.get_user_agent(""),
request.getClientIP(),
)
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url, extra_login_attributes
) )
return user_id
async def _call_attribute_mapper( async def _call_attribute_mapper(
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],

View File

@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri" "@test_user:test", request, "redirect_uri", None
) )
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "" "@test_user:test", request, "", None
) )
# Subsequent calls should map to the same mxid. # Subsequent calls should map to the same mxid.
@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
self.handler._handle_authn_response(request, saml_response, "") self.handler._handle_authn_response(request, saml_response, "")
) )
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "" "@test_user:test", request, "", None
) )
def test_map_saml_response_to_invalid_localpart(self): def test_map_saml_response_to_invalid_localpart(self):
@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead. # test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", request, "" "@test_user1:test", request, "", None
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()