Re-structure CAS code to be more similar to SAML/OIDC.

pull/8784/head
Patrick Cloke 2020-11-20 13:03:57 -05:00
parent 2819f837c9
commit 3b9ce2a608
1 changed files with 44 additions and 19 deletions

View File

@ -203,32 +203,57 @@ class CasHandler:
args["session"] = session args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args) username, user_display_name = await self._validate_ticket(ticket, args)
localpart = map_username_to_mxid_localpart(username) # Pull out the user-agent and IP from the request.
user_id = UserID(localpart, self._hostname).to_string() user_agent = request.get_user_agent("")
registered_user_id = await self._auth_handler.check_user_exists(user_id) ip_address = self.hs.get_ip_from_request(request)
# Get the matrix ID from the CAS username.
user_id = await self._map_cas_user_to_matrix_user(
username, user_display_name, user_agent, ip_address
)
if session: if session:
# If there's a session then the user must already exist.
assert registered_user_id
await self._auth_handler.complete_sso_ui_auth( await self._auth_handler.complete_sso_ui_auth(
registered_user_id, session, request, user_id, session, request,
) )
else: else:
# If this not a UI auth request than there must be a redirect URL. # If this not a UI auth request than there must be a redirect URL.
assert client_redirect_url assert client_redirect_url
if not registered_user_id: await self._auth_handler.complete_sso_login(
# Pull out the user-agent and IP from the request. user_id, request, client_redirect_url
user_agent = request.get_user_agent("") )
ip_address = self.hs.get_ip_from_request(request)
async def _map_cas_user_to_matrix_user(
self,
remote_user_id: str,
display_name: Optional[str],
user_agent: str,
ip_address: str,
) -> str:
"""
Given a CAS username, retrieve the user ID for it and possibly register the user.
Args:
remote_user_id: The username from the CAS response.
display_name: The display name from the CAS response.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
The user ID associated with this response.
"""
localpart = map_username_to_mxid_localpart(remote_user_id)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = await self._auth_handler.check_user_exists(user_id)
# If the user does not exist, register it.
if not registered_user_id:
registered_user_id = await self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, localpart=localpart,
default_display_name=user_display_name, default_display_name=display_name,
user_agent_ips=[(user_agent, ip_address)], user_agent_ips=[(user_agent, ip_address)],
) )
await self._auth_handler.complete_sso_login( return registered_user_id
registered_user_id, request, client_redirect_url
)