Push login completion down into SsoHandler (#8941)
This is another part of my work towards fixing #8876. It moves some of the logic currently in the SAML and OIDC handlers - in particular the call to `AuthHandler.complete_sso_login` down into the `SsoHandler`.pull/8964/head
parent
44b7d4c6d6
commit
e1b8e37f93
|
@ -0,0 +1 @@
|
||||||
|
Add support for allowing users to pick their own user ID during a single-sign-on login.
|
|
@ -115,8 +115,6 @@ class OidcHandler(BaseHandler):
|
||||||
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
|
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
|
||||||
|
|
||||||
self._http_client = hs.get_proxied_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
|
||||||
self._registration_handler = hs.get_registration_handler()
|
|
||||||
self._server_name = hs.config.server_name # type: str
|
self._server_name = hs.config.server_name # type: str
|
||||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||||
|
|
||||||
|
@ -689,33 +687,14 @@ class OidcHandler(BaseHandler):
|
||||||
|
|
||||||
# otherwise, it's a login
|
# otherwise, it's a login
|
||||||
|
|
||||||
# Pull out the user-agent and IP from the request.
|
|
||||||
user_agent = request.get_user_agent("")
|
|
||||||
ip_address = self.hs.get_ip_from_request(request)
|
|
||||||
|
|
||||||
# Call the mapper to register/login the user
|
# Call the mapper to register/login the user
|
||||||
try:
|
try:
|
||||||
user_id = await self._map_userinfo_to_user(
|
await self._complete_oidc_login(
|
||||||
userinfo, token, user_agent, ip_address
|
userinfo, token, request, client_redirect_url
|
||||||
)
|
)
|
||||||
except MappingException as e:
|
except MappingException as e:
|
||||||
logger.exception("Could not map user")
|
logger.exception("Could not map user")
|
||||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||||
return
|
|
||||||
|
|
||||||
# Mapping providers might not have get_extra_attributes: only call this
|
|
||||||
# method if it exists.
|
|
||||||
extra_attributes = None
|
|
||||||
get_extra_attributes = getattr(
|
|
||||||
self._user_mapping_provider, "get_extra_attributes", None
|
|
||||||
)
|
|
||||||
if get_extra_attributes:
|
|
||||||
extra_attributes = await get_extra_attributes(userinfo, token)
|
|
||||||
|
|
||||||
# and finally complete the login
|
|
||||||
await self._auth_handler.complete_sso_login(
|
|
||||||
user_id, request, client_redirect_url, extra_attributes
|
|
||||||
)
|
|
||||||
|
|
||||||
def _generate_oidc_session_token(
|
def _generate_oidc_session_token(
|
||||||
self,
|
self,
|
||||||
|
@ -838,10 +817,14 @@ class OidcHandler(BaseHandler):
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
return now < expiry
|
return now < expiry
|
||||||
|
|
||||||
async def _map_userinfo_to_user(
|
async def _complete_oidc_login(
|
||||||
self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
|
self,
|
||||||
) -> str:
|
userinfo: UserInfo,
|
||||||
"""Maps a UserInfo object to a mxid.
|
token: Token,
|
||||||
|
request: SynapseRequest,
|
||||||
|
client_redirect_url: str,
|
||||||
|
) -> None:
|
||||||
|
"""Given a UserInfo response, complete the login flow
|
||||||
|
|
||||||
UserInfo should have a claim that uniquely identifies users. This claim
|
UserInfo should have a claim that uniquely identifies users. This claim
|
||||||
is usually `sub`, but can be configured with `oidc_config.subject_claim`.
|
is usually `sub`, but can be configured with `oidc_config.subject_claim`.
|
||||||
|
@ -853,17 +836,16 @@ class OidcHandler(BaseHandler):
|
||||||
If a user already exists with the mxid we've mapped and allow_existing_users
|
If a user already exists with the mxid we've mapped and allow_existing_users
|
||||||
is disabled, raise an exception.
|
is disabled, raise an exception.
|
||||||
|
|
||||||
|
Otherwise, render a redirect back to the client_redirect_url with a loginToken.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
userinfo: an object representing the user
|
userinfo: an object representing the user
|
||||||
token: a dict with the tokens obtained from the provider
|
token: a dict with the tokens obtained from the provider
|
||||||
user_agent: The user agent of the client making the request.
|
request: The request to respond to
|
||||||
ip_address: The IP address of the client making the request.
|
client_redirect_url: The redirect URL passed in by the client.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
MappingException: if there was an error while mapping some properties
|
MappingException: if there was an error while mapping some properties
|
||||||
|
|
||||||
Returns:
|
|
||||||
The mxid of the user
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
remote_user_id = self._remote_id_from_userinfo(userinfo)
|
remote_user_id = self._remote_id_from_userinfo(userinfo)
|
||||||
|
@ -931,13 +913,23 @@ class OidcHandler(BaseHandler):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return await self._sso_handler.get_mxid_from_sso(
|
# Mapping providers might not have get_extra_attributes: only call this
|
||||||
|
# method if it exists.
|
||||||
|
extra_attributes = None
|
||||||
|
get_extra_attributes = getattr(
|
||||||
|
self._user_mapping_provider, "get_extra_attributes", None
|
||||||
|
)
|
||||||
|
if get_extra_attributes:
|
||||||
|
extra_attributes = await get_extra_attributes(userinfo, token)
|
||||||
|
|
||||||
|
await self._sso_handler.complete_sso_login_request(
|
||||||
self._auth_provider_id,
|
self._auth_provider_id,
|
||||||
remote_user_id,
|
remote_user_id,
|
||||||
user_agent,
|
request,
|
||||||
ip_address,
|
client_redirect_url,
|
||||||
oidc_response_to_user_attributes,
|
oidc_response_to_user_attributes,
|
||||||
grandfather_existing_users,
|
grandfather_existing_users,
|
||||||
|
extra_attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
|
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
|
||||||
|
|
|
@ -58,8 +58,6 @@ class SamlHandler(BaseHandler):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||||
self._saml_idp_entityid = hs.config.saml2_idp_entityid
|
self._saml_idp_entityid = hs.config.saml2_idp_entityid
|
||||||
self._auth_handler = hs.get_auth_handler()
|
|
||||||
self._registration_handler = hs.get_registration_handler()
|
|
||||||
|
|
||||||
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
|
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
|
||||||
self._grandfathered_mxid_source_attribute = (
|
self._grandfathered_mxid_source_attribute = (
|
||||||
|
@ -229,40 +227,29 @@ class SamlHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Pull out the user-agent and IP from the request.
|
|
||||||
user_agent = request.get_user_agent("")
|
|
||||||
ip_address = self.hs.get_ip_from_request(request)
|
|
||||||
|
|
||||||
# Call the mapper to register/login the user
|
# Call the mapper to register/login the user
|
||||||
try:
|
try:
|
||||||
user_id = await self._map_saml_response_to_user(
|
await self._complete_saml_login(saml2_auth, request, relay_state)
|
||||||
saml2_auth, relay_state, user_agent, ip_address
|
|
||||||
)
|
|
||||||
except MappingException as e:
|
except MappingException as e:
|
||||||
logger.exception("Could not map user")
|
logger.exception("Could not map user")
|
||||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||||
return
|
|
||||||
|
|
||||||
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
async def _complete_saml_login(
|
||||||
|
|
||||||
async def _map_saml_response_to_user(
|
|
||||||
self,
|
self,
|
||||||
saml2_auth: saml2.response.AuthnResponse,
|
saml2_auth: saml2.response.AuthnResponse,
|
||||||
|
request: SynapseRequest,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
user_agent: str,
|
) -> None:
|
||||||
ip_address: str,
|
|
||||||
) -> str:
|
|
||||||
"""
|
"""
|
||||||
Given a SAML response, retrieve the user ID for it and possibly register the user.
|
Given a SAML response, complete the login flow
|
||||||
|
|
||||||
|
Retrieves the remote user ID, registers the user if necessary, and serves
|
||||||
|
a redirect back to the client with a login-token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
saml2_auth: The parsed SAML2 response.
|
saml2_auth: The parsed SAML2 response.
|
||||||
|
request: The request to respond to
|
||||||
client_redirect_url: The redirect URL passed in by the client.
|
client_redirect_url: The redirect URL passed in by the client.
|
||||||
user_agent: The user agent of the client making the request.
|
|
||||||
ip_address: The IP address of the client making the request.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The user ID associated with this response.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
MappingException if there was a problem mapping the response to a user.
|
MappingException if there was a problem mapping the response to a user.
|
||||||
|
@ -318,11 +305,11 @@ class SamlHandler(BaseHandler):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return await self._sso_handler.get_mxid_from_sso(
|
await self._sso_handler.complete_sso_login_request(
|
||||||
self._auth_provider_id,
|
self._auth_provider_id,
|
||||||
remote_user_id,
|
remote_user_id,
|
||||||
user_agent,
|
request,
|
||||||
ip_address,
|
client_redirect_url,
|
||||||
saml_response_to_remapped_user_attributes,
|
saml_response_to_remapped_user_attributes,
|
||||||
grandfather_existing_users,
|
grandfather_existing_users,
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,7 +21,8 @@ from twisted.web.http import Request
|
||||||
|
|
||||||
from synapse.api.errors import RedirectException
|
from synapse.api.errors import RedirectException
|
||||||
from synapse.http.server import respond_with_html
|
from synapse.http.server import respond_with_html
|
||||||
from synapse.types import UserID, contains_invalid_mxid_characters
|
from synapse.http.site import SynapseRequest
|
||||||
|
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -119,15 +120,16 @@ class SsoHandler:
|
||||||
# No match.
|
# No match.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_mxid_from_sso(
|
async def complete_sso_login_request(
|
||||||
self,
|
self,
|
||||||
auth_provider_id: str,
|
auth_provider_id: str,
|
||||||
remote_user_id: str,
|
remote_user_id: str,
|
||||||
user_agent: str,
|
request: SynapseRequest,
|
||||||
ip_address: str,
|
client_redirect_url: str,
|
||||||
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
||||||
grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
|
grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
|
||||||
) -> str:
|
extra_login_attributes: Optional[JsonDict] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Given an SSO ID, retrieve the user ID for it and possibly register the user.
|
Given an SSO ID, retrieve the user ID for it and possibly register the user.
|
||||||
|
|
||||||
|
@ -146,12 +148,18 @@ class SsoHandler:
|
||||||
given user-agent and IP address and the SSO ID is linked to this matrix
|
given user-agent and IP address and the SSO ID is linked to this matrix
|
||||||
ID for subsequent calls.
|
ID for subsequent calls.
|
||||||
|
|
||||||
|
Finally, we generate a redirect to the supplied redirect uri, with a login token
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||||
"oidc" or "saml".
|
"oidc" or "saml".
|
||||||
|
|
||||||
remote_user_id: The unique identifier from the SSO provider.
|
remote_user_id: The unique identifier from the SSO provider.
|
||||||
user_agent: The user agent of the client making the request.
|
|
||||||
ip_address: The IP address of the client making the request.
|
request: The request to respond to
|
||||||
|
|
||||||
|
client_redirect_url: The redirect URL passed in by the client.
|
||||||
|
|
||||||
sso_to_matrix_id_mapper: A callable to generate the user attributes.
|
sso_to_matrix_id_mapper: A callable to generate the user attributes.
|
||||||
The only parameter is an integer which represents the amount of
|
The only parameter is an integer which represents the amount of
|
||||||
times the returned mxid localpart mapping has failed.
|
times the returned mxid localpart mapping has failed.
|
||||||
|
@ -163,12 +171,13 @@ class SsoHandler:
|
||||||
to the user.
|
to the user.
|
||||||
RedirectException to redirect to an additional page (e.g.
|
RedirectException to redirect to an additional page (e.g.
|
||||||
to prompt the user for more information).
|
to prompt the user for more information).
|
||||||
|
|
||||||
grandfather_existing_users: A callable which can return an previously
|
grandfather_existing_users: A callable which can return an previously
|
||||||
existing matrix ID. The SSO ID is then linked to the returned
|
existing matrix ID. The SSO ID is then linked to the returned
|
||||||
matrix ID.
|
matrix ID.
|
||||||
|
|
||||||
Returns:
|
extra_login_attributes: An optional dictionary of extra
|
||||||
The user ID associated with the SSO response.
|
attributes to be provided to the client in the login response.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
MappingException if there was a problem mapping the response to a user.
|
MappingException if there was a problem mapping the response to a user.
|
||||||
|
@ -181,28 +190,33 @@ class SsoHandler:
|
||||||
# interstitial pages.
|
# interstitial pages.
|
||||||
with await self._mapping_lock.queue(auth_provider_id):
|
with await self._mapping_lock.queue(auth_provider_id):
|
||||||
# first of all, check if we already have a mapping for this user
|
# first of all, check if we already have a mapping for this user
|
||||||
previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
|
user_id = await self.get_sso_user_by_remote_user_id(
|
||||||
auth_provider_id, remote_user_id,
|
auth_provider_id, remote_user_id,
|
||||||
)
|
)
|
||||||
if previously_registered_user_id:
|
|
||||||
return previously_registered_user_id
|
|
||||||
|
|
||||||
# Check for grandfathering of users.
|
# Check for grandfathering of users.
|
||||||
if grandfather_existing_users:
|
if not user_id and grandfather_existing_users:
|
||||||
previously_registered_user_id = await grandfather_existing_users()
|
user_id = await grandfather_existing_users()
|
||||||
if previously_registered_user_id:
|
if user_id:
|
||||||
# Future logins should also match this user ID.
|
# Future logins should also match this user ID.
|
||||||
await self._store.record_user_external_id(
|
await self._store.record_user_external_id(
|
||||||
auth_provider_id, remote_user_id, previously_registered_user_id
|
auth_provider_id, remote_user_id, user_id
|
||||||
)
|
)
|
||||||
return previously_registered_user_id
|
|
||||||
|
|
||||||
# Otherwise, generate a new user.
|
# Otherwise, generate a new user.
|
||||||
|
if not user_id:
|
||||||
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
|
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
|
||||||
user_id = await self._register_mapped_user(
|
user_id = await self._register_mapped_user(
|
||||||
attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
|
attributes,
|
||||||
|
auth_provider_id,
|
||||||
|
remote_user_id,
|
||||||
|
request.get_user_agent(""),
|
||||||
|
request.getClientIP(),
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._auth_handler.complete_sso_login(
|
||||||
|
user_id, request, client_redirect_url, extra_login_attributes
|
||||||
)
|
)
|
||||||
return user_id
|
|
||||||
|
|
||||||
async def _call_attribute_mapper(
|
async def _call_attribute_mapper(
|
||||||
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
||||||
|
|
|
@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", request, "redirect_uri"
|
"@test_user:test", request, "redirect_uri", None
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||||
|
@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", request, ""
|
"@test_user:test", request, "", None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Subsequent calls should map to the same mxid.
|
# Subsequent calls should map to the same mxid.
|
||||||
|
@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
self.handler._handle_authn_response(request, saml_response, "")
|
self.handler._handle_authn_response(request, saml_response, "")
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", request, ""
|
"@test_user:test", request, "", None
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_saml_response_to_invalid_localpart(self):
|
def test_map_saml_response_to_invalid_localpart(self):
|
||||||
|
@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# test_user is already taken, so test_user1 gets registered instead.
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user1:test", request, ""
|
"@test_user1:test", request, "", None
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue