Simplify the flow for SSO UIA (#8881)
* SsoHandler: remove inheritance from BaseHandler * Simplify the flow for SSO UIA We don't need to do all the magic for mapping users when we are doing UIA, so let's factor that out.pull/8897/head
parent
025fa06fc7
commit
36ba73f53d
|
@ -0,0 +1 @@
|
||||||
|
Simplify logic for handling user-interactive-auth via single-sign-on servers.
|
1
mypy.ini
1
mypy.ini
|
@ -43,6 +43,7 @@ files =
|
||||||
synapse/handlers/room_member.py,
|
synapse/handlers/room_member.py,
|
||||||
synapse/handlers/room_member_worker.py,
|
synapse/handlers/room_member_worker.py,
|
||||||
synapse/handlers/saml_handler.py,
|
synapse/handlers/saml_handler.py,
|
||||||
|
synapse/handlers/sso.py,
|
||||||
synapse/handlers/sync.py,
|
synapse/handlers/sync.py,
|
||||||
synapse/handlers/ui_auth,
|
synapse/handlers/ui_auth,
|
||||||
synapse/http/client.py,
|
synapse/http/client.py,
|
||||||
|
|
|
@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
|
||||||
class BaseHandler:
|
class BaseHandler:
|
||||||
"""
|
"""
|
||||||
Common base class for the event handlers.
|
Common base class for the event handlers.
|
||||||
|
|
||||||
|
Deprecated: new code should not use this. Instead, Handler classes should define the
|
||||||
|
fields they actually need. The utility methods should either be factored out to
|
||||||
|
standalone helper functions, or to different Handler classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
|
|
@ -36,6 +36,8 @@ import attr
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
|
from twisted.web.http import Request
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
|
@ -1331,15 +1333,14 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def complete_sso_ui_auth(
|
async def complete_sso_ui_auth(
|
||||||
self, registered_user_id: str, session_id: str, request: SynapseRequest,
|
self, registered_user_id: str, session_id: str, request: Request,
|
||||||
):
|
):
|
||||||
"""Having figured out a mxid for this user, complete the HTTP request
|
"""Having figured out a mxid for this user, complete the HTTP request
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
registered_user_id: The registered user ID to complete SSO login for.
|
registered_user_id: The registered user ID to complete SSO login for.
|
||||||
|
session_id: The ID of the user-interactive auth session.
|
||||||
request: The request to complete.
|
request: The request to complete.
|
||||||
client_redirect_url: The URL to which to redirect the user at the end of the
|
|
||||||
process.
|
|
||||||
"""
|
"""
|
||||||
# Mark the stage of the authentication as successful.
|
# Mark the stage of the authentication as successful.
|
||||||
# Save the user who authenticated with SSO, this will be used to ensure
|
# Save the user who authenticated with SSO, this will be used to ensure
|
||||||
|
@ -1355,7 +1356,7 @@ class AuthHandler(BaseHandler):
|
||||||
async def complete_sso_login(
|
async def complete_sso_login(
|
||||||
self,
|
self,
|
||||||
registered_user_id: str,
|
registered_user_id: str,
|
||||||
request: SynapseRequest,
|
request: Request,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
extra_attributes: Optional[JsonDict] = None,
|
extra_attributes: Optional[JsonDict] = None,
|
||||||
):
|
):
|
||||||
|
@ -1383,7 +1384,7 @@ class AuthHandler(BaseHandler):
|
||||||
def _complete_sso_login(
|
def _complete_sso_login(
|
||||||
self,
|
self,
|
||||||
registered_user_id: str,
|
registered_user_id: str,
|
||||||
request: SynapseRequest,
|
request: Request,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
extra_attributes: Optional[JsonDict] = None,
|
extra_attributes: Optional[JsonDict] = None,
|
||||||
):
|
):
|
||||||
|
|
|
@ -674,6 +674,21 @@ class OidcHandler(BaseHandler):
|
||||||
self._sso_handler.render_error(request, "invalid_token", str(e))
|
self._sso_handler.render_error(request, "invalid_token", str(e))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# first check if we're doing a UIA
|
||||||
|
if ui_auth_session_id:
|
||||||
|
try:
|
||||||
|
remote_user_id = self._remote_id_from_userinfo(userinfo)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Could not extract remote user id")
|
||||||
|
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||||
|
return
|
||||||
|
|
||||||
|
return await self._sso_handler.complete_sso_ui_auth_request(
|
||||||
|
self._auth_provider_id, remote_user_id, ui_auth_session_id, request
|
||||||
|
)
|
||||||
|
|
||||||
|
# otherwise, it's a login
|
||||||
|
|
||||||
# Pull out the user-agent and IP from the request.
|
# Pull out the user-agent and IP from the request.
|
||||||
user_agent = request.get_user_agent("")
|
user_agent = request.get_user_agent("")
|
||||||
ip_address = self.hs.get_ip_from_request(request)
|
ip_address = self.hs.get_ip_from_request(request)
|
||||||
|
@ -698,14 +713,9 @@ class OidcHandler(BaseHandler):
|
||||||
extra_attributes = await get_extra_attributes(userinfo, token)
|
extra_attributes = await get_extra_attributes(userinfo, token)
|
||||||
|
|
||||||
# and finally complete the login
|
# and finally complete the login
|
||||||
if ui_auth_session_id:
|
await self._auth_handler.complete_sso_login(
|
||||||
await self._auth_handler.complete_sso_ui_auth(
|
user_id, request, client_redirect_url, extra_attributes
|
||||||
user_id, ui_auth_session_id, request
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
await self._auth_handler.complete_sso_login(
|
|
||||||
user_id, request, client_redirect_url, extra_attributes
|
|
||||||
)
|
|
||||||
|
|
||||||
def _generate_oidc_session_token(
|
def _generate_oidc_session_token(
|
||||||
self,
|
self,
|
||||||
|
@ -856,14 +866,11 @@ class OidcHandler(BaseHandler):
|
||||||
The mxid of the user
|
The mxid of the user
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
|
remote_user_id = self._remote_id_from_userinfo(userinfo)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise MappingException(
|
raise MappingException(
|
||||||
"Failed to extract subject from OIDC response: %s" % (e,)
|
"Failed to extract subject from OIDC response: %s" % (e,)
|
||||||
)
|
)
|
||||||
# Some OIDC providers use integer IDs, but Synapse expects external IDs
|
|
||||||
# to be strings.
|
|
||||||
remote_user_id = str(remote_user_id)
|
|
||||||
|
|
||||||
# Older mapping providers don't accept the `failures` argument, so we
|
# Older mapping providers don't accept the `failures` argument, so we
|
||||||
# try and detect support.
|
# try and detect support.
|
||||||
|
@ -933,6 +940,19 @@ class OidcHandler(BaseHandler):
|
||||||
grandfather_existing_users,
|
grandfather_existing_users,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
|
||||||
|
"""Extract the unique remote id from an OIDC UserInfo block
|
||||||
|
|
||||||
|
Args:
|
||||||
|
userinfo: An object representing the user given by the OIDC provider
|
||||||
|
Returns:
|
||||||
|
remote user id
|
||||||
|
"""
|
||||||
|
remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
|
||||||
|
# Some OIDC providers use integer IDs, but Synapse expects external IDs
|
||||||
|
# to be strings.
|
||||||
|
return str(remote_user_id)
|
||||||
|
|
||||||
|
|
||||||
UserAttributeDict = TypedDict(
|
UserAttributeDict = TypedDict(
|
||||||
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
|
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
|
||||||
|
|
|
@ -183,6 +183,24 @@ class SamlHandler(BaseHandler):
|
||||||
saml2_auth.in_response_to, None
|
saml2_auth.in_response_to, None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# first check if we're doing a UIA
|
||||||
|
if current_session and current_session.ui_auth_session_id:
|
||||||
|
try:
|
||||||
|
remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
|
||||||
|
except MappingException as e:
|
||||||
|
logger.exception("Failed to extract remote user id from SAML response")
|
||||||
|
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||||
|
return
|
||||||
|
|
||||||
|
return await self._sso_handler.complete_sso_ui_auth_request(
|
||||||
|
self._auth_provider_id,
|
||||||
|
remote_user_id,
|
||||||
|
current_session.ui_auth_session_id,
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
|
||||||
|
# otherwise, we're handling a login request.
|
||||||
|
|
||||||
# Ensure that the attributes of the logged in user meet the required
|
# Ensure that the attributes of the logged in user meet the required
|
||||||
# attributes.
|
# attributes.
|
||||||
for requirement in self._saml2_attribute_requirements:
|
for requirement in self._saml2_attribute_requirements:
|
||||||
|
@ -206,14 +224,7 @@ class SamlHandler(BaseHandler):
|
||||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Complete the interactive auth session or the login.
|
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
||||||
if current_session and current_session.ui_auth_session_id:
|
|
||||||
await self._auth_handler.complete_sso_ui_auth(
|
|
||||||
user_id, current_session.ui_auth_session_id, request
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
|
||||||
|
|
||||||
async def _map_saml_response_to_user(
|
async def _map_saml_response_to_user(
|
||||||
self,
|
self,
|
||||||
|
@ -239,16 +250,10 @@ class SamlHandler(BaseHandler):
|
||||||
RedirectException: some mapping providers may raise this if they need
|
RedirectException: some mapping providers may raise this if they need
|
||||||
to redirect to an interstitial page.
|
to redirect to an interstitial page.
|
||||||
"""
|
"""
|
||||||
|
remote_user_id = self._remote_id_from_saml_response(
|
||||||
remote_user_id = self._user_mapping_provider.get_remote_user_id(
|
|
||||||
saml2_auth, client_redirect_url
|
saml2_auth, client_redirect_url
|
||||||
)
|
)
|
||||||
|
|
||||||
if not remote_user_id:
|
|
||||||
raise MappingException(
|
|
||||||
"Failed to extract remote user id from SAML response"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def saml_response_to_remapped_user_attributes(
|
async def saml_response_to_remapped_user_attributes(
|
||||||
failures: int,
|
failures: int,
|
||||||
) -> UserAttributes:
|
) -> UserAttributes:
|
||||||
|
@ -304,6 +309,35 @@ class SamlHandler(BaseHandler):
|
||||||
grandfather_existing_users,
|
grandfather_existing_users,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _remote_id_from_saml_response(
|
||||||
|
self,
|
||||||
|
saml2_auth: saml2.response.AuthnResponse,
|
||||||
|
client_redirect_url: Optional[str],
|
||||||
|
) -> str:
|
||||||
|
"""Extract the unique remote id from a SAML2 AuthnResponse
|
||||||
|
|
||||||
|
Args:
|
||||||
|
saml2_auth: The parsed SAML2 response.
|
||||||
|
client_redirect_url: The redirect URL passed in by the client.
|
||||||
|
Returns:
|
||||||
|
remote user id
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MappingException if there was an error extracting the user id
|
||||||
|
"""
|
||||||
|
# It's not obvious why we need to pass in the redirect URI to the mapping
|
||||||
|
# provider, but we do :/
|
||||||
|
remote_user_id = self._user_mapping_provider.get_remote_user_id(
|
||||||
|
saml2_auth, client_redirect_url
|
||||||
|
)
|
||||||
|
|
||||||
|
if not remote_user_id:
|
||||||
|
raise MappingException(
|
||||||
|
"Failed to extract remote user id from SAML response"
|
||||||
|
)
|
||||||
|
|
||||||
|
return remote_user_id
|
||||||
|
|
||||||
def expire_sessions(self):
|
def expire_sessions(self):
|
||||||
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
||||||
to_expire = set()
|
to_expire = set()
|
||||||
|
|
|
@ -17,8 +17,9 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
from twisted.web.http import Request
|
||||||
|
|
||||||
from synapse.api.errors import RedirectException
|
from synapse.api.errors import RedirectException
|
||||||
from synapse.handlers._base import BaseHandler
|
|
||||||
from synapse.http.server import respond_with_html
|
from synapse.http.server import respond_with_html
|
||||||
from synapse.types import UserID, contains_invalid_mxid_characters
|
from synapse.types import UserID, contains_invalid_mxid_characters
|
||||||
|
|
||||||
|
@ -42,14 +43,16 @@ class UserAttributes:
|
||||||
emails = attr.ib(type=List[str], default=attr.Factory(list))
|
emails = attr.ib(type=List[str], default=attr.Factory(list))
|
||||||
|
|
||||||
|
|
||||||
class SsoHandler(BaseHandler):
|
class SsoHandler:
|
||||||
# The number of attempts to ask the mapping provider for when generating an MXID.
|
# The number of attempts to ask the mapping provider for when generating an MXID.
|
||||||
_MAP_USERNAME_RETRIES = 1000
|
_MAP_USERNAME_RETRIES = 1000
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
self._store = hs.get_datastore()
|
||||||
|
self._server_name = hs.hostname
|
||||||
self._registration_handler = hs.get_registration_handler()
|
self._registration_handler = hs.get_registration_handler()
|
||||||
self._error_template = hs.config.sso_error_template
|
self._error_template = hs.config.sso_error_template
|
||||||
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
def render_error(
|
def render_error(
|
||||||
self, request, error: str, error_description: Optional[str] = None
|
self, request, error: str, error_description: Optional[str] = None
|
||||||
|
@ -95,7 +98,7 @@ class SsoHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if we already have a mapping for this user.
|
# Check if we already have a mapping for this user.
|
||||||
previously_registered_user_id = await self.store.get_user_by_external_id(
|
previously_registered_user_id = await self._store.get_user_by_external_id(
|
||||||
auth_provider_id, remote_user_id,
|
auth_provider_id, remote_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -181,7 +184,7 @@ class SsoHandler(BaseHandler):
|
||||||
previously_registered_user_id = await grandfather_existing_users()
|
previously_registered_user_id = await grandfather_existing_users()
|
||||||
if previously_registered_user_id:
|
if previously_registered_user_id:
|
||||||
# Future logins should also match this user ID.
|
# Future logins should also match this user ID.
|
||||||
await self.store.record_user_external_id(
|
await self._store.record_user_external_id(
|
||||||
auth_provider_id, remote_user_id, previously_registered_user_id
|
auth_provider_id, remote_user_id, previously_registered_user_id
|
||||||
)
|
)
|
||||||
return previously_registered_user_id
|
return previously_registered_user_id
|
||||||
|
@ -214,8 +217,8 @@ class SsoHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if this mxid already exists
|
# Check if this mxid already exists
|
||||||
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
user_id = UserID(attributes.localpart, self._server_name).to_string()
|
||||||
if not await self.store.get_users_by_id_case_insensitive(user_id):
|
if not await self._store.get_users_by_id_case_insensitive(user_id):
|
||||||
# This mxid is free
|
# This mxid is free
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
@ -238,7 +241,47 @@ class SsoHandler(BaseHandler):
|
||||||
user_agent_ips=[(user_agent, ip_address)],
|
user_agent_ips=[(user_agent, ip_address)],
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.store.record_user_external_id(
|
await self._store.record_user_external_id(
|
||||||
auth_provider_id, remote_user_id, registered_user_id
|
auth_provider_id, remote_user_id, registered_user_id
|
||||||
)
|
)
|
||||||
return registered_user_id
|
return registered_user_id
|
||||||
|
|
||||||
|
async def complete_sso_ui_auth_request(
|
||||||
|
self,
|
||||||
|
auth_provider_id: str,
|
||||||
|
remote_user_id: str,
|
||||||
|
ui_auth_session_id: str,
|
||||||
|
request: Request,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Given an SSO ID, retrieve the user ID for it and complete UIA.
|
||||||
|
|
||||||
|
Note that this requires that the user is mapped in the "user_external_ids"
|
||||||
|
table. This will be the case if they have ever logged in via SAML or OIDC in
|
||||||
|
recentish synapse versions, but may not be for older users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||||
|
"oidc" or "saml".
|
||||||
|
remote_user_id: The unique identifier from the SSO provider.
|
||||||
|
ui_auth_session_id: The ID of the user-interactive auth session.
|
||||||
|
request: The request to complete.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id = await self.get_sso_user_by_remote_user_id(
|
||||||
|
auth_provider_id, remote_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
logger.warning(
|
||||||
|
"Remote user %s/%s has not previously logged in here: UIA will fail",
|
||||||
|
auth_provider_id,
|
||||||
|
remote_user_id,
|
||||||
|
)
|
||||||
|
# Let the UIA flow handle this the same as if they presented creds for a
|
||||||
|
# different user.
|
||||||
|
user_id = ""
|
||||||
|
|
||||||
|
await self._auth_handler.complete_sso_ui_auth(
|
||||||
|
user_id, ui_auth_session_id, request
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue