Do not use a separate endpoint for UI Auth for CAS.
parent
32581bf832
commit
7e7e48628d
|
@ -48,44 +48,37 @@ class CasHandler:
|
||||||
|
|
||||||
self._http_client = hs.get_proxied_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
|
|
||||||
def _build_service_param(self, service_redirect_endpoint: str, **kwargs) -> str:
|
def _build_service_param(self, args: Dict[str, str]) -> str:
|
||||||
"""
|
"""
|
||||||
Generates a value to use as the "service" parameter when redirecting or
|
Generates a value to use as the "service" parameter when redirecting or
|
||||||
querying the CAS service.
|
querying the CAS service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service_redirect_endpoint: The homeserver endpoint to redirect
|
args: Additional arguments to include in the final redirect URL.
|
||||||
the client to after successful SSO negotiation.
|
|
||||||
kwargs: Additional arguments to include in the final redirect URL.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The URL to use as a "service" parameter.
|
The URL to use as a "service" parameter.
|
||||||
"""
|
"""
|
||||||
return "%s%s?%s" % (
|
return "%s%s?%s" % (
|
||||||
self._cas_service_url,
|
self._cas_service_url,
|
||||||
service_redirect_endpoint,
|
"/_matrix/client/r0/login/cas/ticket",
|
||||||
urllib.parse.urlencode(kwargs),
|
urllib.parse.urlencode(args),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _validate_ticket(
|
async def _validate_ticket(
|
||||||
self, ticket: str, service_redirect_endpoint: str, client_redirect_url: str
|
self, ticket: str, service_args: Dict[str, str]
|
||||||
) -> Tuple[str, Optional[str]]:
|
) -> Tuple[str, Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Validate a CAS ticket with the server, parse the response, and return the user and display name.
|
Validate a CAS ticket with the server, parse the response, and return the user and display name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ticket: The CAS ticket from the client.
|
ticket: The CAS ticket from the client.
|
||||||
service_redirect_endpoint: The homeserver endpoint that the client
|
service_args: Additional arguments to include in the service URL.
|
||||||
accessed to validate the ticket.
|
|
||||||
client_redirect_url: The URL to redirect the client to after
|
|
||||||
validation is done.
|
|
||||||
"""
|
"""
|
||||||
uri = self._cas_server_url + "/proxyValidate"
|
uri = self._cas_server_url + "/proxyValidate"
|
||||||
args = {
|
args = {
|
||||||
"ticket": ticket,
|
"ticket": ticket,
|
||||||
"service": self._build_service_param(
|
"service": self._build_service_param(service_args),
|
||||||
service_redirect_endpoint, redirectUrl=client_redirect_url
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
body = await self._http_client.get_raw(uri, args)
|
body = await self._http_client.get_raw(uri, args)
|
||||||
|
@ -154,26 +147,28 @@ class CasHandler:
|
||||||
)
|
)
|
||||||
return user, attributes
|
return user, attributes
|
||||||
|
|
||||||
def get_redirect_url(self, service_redirect_endpoint: str, **kwargs) -> str:
|
def get_redirect_url(self, service_args: Dict[str, str]) -> str:
|
||||||
"""
|
"""
|
||||||
Generates a URL to the CAS server where the client should be redirected.
|
Generates a URL to the CAS server where the client should be redirected.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service_redirect_endpoint: The homeserver endpoint to redirect
|
service_args: Additional arguments to include in the final redirect URL.
|
||||||
the client to after successful SSO negotiation.
|
|
||||||
kwargs: Additional arguments to include in the final redirect URL.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The URL to redirect the client to.
|
The URL to redirect the client to.
|
||||||
"""
|
"""
|
||||||
args = urllib.parse.urlencode(
|
args = urllib.parse.urlencode(
|
||||||
{"service": self._build_service_param(service_redirect_endpoint, **kwargs)}
|
{"service": self._build_service_param(service_args)}
|
||||||
)
|
)
|
||||||
|
|
||||||
return "%s/login?%s" % (self._cas_server_url, args)
|
return "%s/login?%s" % (self._cas_server_url, args)
|
||||||
|
|
||||||
async def handle_ticket_for_login(
|
async def handle_ticket(
|
||||||
self, request: SynapseRequest, client_redirect_url: str, ticket: str,
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
ticket: str,
|
||||||
|
client_redirect_url: Optional[str],
|
||||||
|
session: Optional[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called once the user has successfully authenticated with the SSO,
|
Called once the user has successfully authenticated with the SSO,
|
||||||
|
@ -186,52 +181,35 @@ class CasHandler:
|
||||||
request: the incoming request from the browser. We'll
|
request: the incoming request from the browser. We'll
|
||||||
respond to it with a redirect.
|
respond to it with a redirect.
|
||||||
|
|
||||||
|
ticket: The CAS ticket provided by the client.
|
||||||
|
|
||||||
client_redirect_url: the redirect_url the client gave us when
|
client_redirect_url: the redirect_url the client gave us when
|
||||||
it first started the process.
|
it first started the process.
|
||||||
|
|
||||||
ticket: The CAS ticket provided by the client.
|
session_id: The UI Auth session ID, if applicable.
|
||||||
"""
|
"""
|
||||||
username, user_display_name = await self._validate_ticket(
|
args = {}
|
||||||
ticket, request.path, client_redirect_url
|
if client_redirect_url:
|
||||||
)
|
args["redirectUrl"] = client_redirect_url
|
||||||
|
if session:
|
||||||
|
args["session"] = session
|
||||||
|
username, user_display_name = await self._validate_ticket(ticket, args)
|
||||||
|
|
||||||
localpart = map_username_to_mxid_localpart(username)
|
localpart = map_username_to_mxid_localpart(username)
|
||||||
user_id = UserID(localpart, self._hostname).to_string()
|
user_id = UserID(localpart, self._hostname).to_string()
|
||||||
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
||||||
if not registered_user_id:
|
|
||||||
registered_user_id = await self._registration_handler.register_user(
|
if session:
|
||||||
localpart=localpart, default_display_name=user_display_name
|
self._auth_handler.complete_sso_ui_auth(
|
||||||
|
registered_user_id, session, request,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._auth_handler.complete_sso_login(
|
else:
|
||||||
registered_user_id, request, client_redirect_url
|
if not registered_user_id:
|
||||||
)
|
registered_user_id = await self._registration_handler.register_user(
|
||||||
|
localpart=localpart, default_display_name=user_display_name
|
||||||
|
)
|
||||||
|
|
||||||
async def handle_ticket_for_ui_auth(
|
self._auth_handler.complete_sso_login(
|
||||||
self, request: SynapseRequest, ticket: str, session_id: str
|
registered_user_id, request, client_redirect_url
|
||||||
) -> None:
|
)
|
||||||
"""
|
|
||||||
Called once the user has successfully authenticated with the SSO,
|
|
||||||
validates a CAS ticket sent by the client and completes user interactive
|
|
||||||
authentication.
|
|
||||||
|
|
||||||
If successful, this completes the SSO step of UI auth and returns a
|
|
||||||
an HTML page to the client.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: the incoming request from the browser.
|
|
||||||
|
|
||||||
ticket: The CAS ticket provided by the client.
|
|
||||||
|
|
||||||
session_id: The UI Auth session ID.
|
|
||||||
"""
|
|
||||||
client_redirect_url = ""
|
|
||||||
user, _ = await self._validate_ticket(ticket, request.path, client_redirect_url)
|
|
||||||
|
|
||||||
localpart = map_username_to_mxid_localpart(user)
|
|
||||||
user_id = UserID(localpart, self._hostname).to_string()
|
|
||||||
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
|
||||||
|
|
||||||
self._auth_handler.complete_sso_ui_auth(
|
|
||||||
registered_user_id, session_id, request,
|
|
||||||
)
|
|
||||||
|
|
|
@ -426,7 +426,7 @@ class CasRedirectServlet(BaseSSORedirectServlet):
|
||||||
|
|
||||||
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
|
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
|
||||||
return self._cas_handler.get_redirect_url(
|
return self._cas_handler.get_redirect_url(
|
||||||
"/_matrix/client/r0/login/cas/ticket", redirectUrl=client_redirect_url
|
{"redirectUrl": client_redirect_url}
|
||||||
).encode("ascii")
|
).encode("ascii")
|
||||||
|
|
||||||
|
|
||||||
|
@ -438,10 +438,20 @@ class CasTicketServlet(RestServlet):
|
||||||
self._cas_handler = hs.get_cas_handler()
|
self._cas_handler = hs.get_cas_handler()
|
||||||
|
|
||||||
async def on_GET(self, request: SynapseRequest) -> None:
|
async def on_GET(self, request: SynapseRequest) -> None:
|
||||||
client_redirect_url = parse_string(request, "redirectUrl", required=True)
|
client_redirect_url = parse_string(request, "redirectUrl")
|
||||||
ticket = parse_string(request, "ticket", required=True)
|
ticket = parse_string(request, "ticket", required=True)
|
||||||
await self._cas_handler.handle_ticket_for_login(
|
|
||||||
request, client_redirect_url, ticket
|
# Maybe get a session ID (if this ticket is from user interactive
|
||||||
|
# authentication).
|
||||||
|
session = parse_string(request, "session")
|
||||||
|
|
||||||
|
# Either client_redirect_url or session must be provided.
|
||||||
|
if not client_redirect_url and not session:
|
||||||
|
message = "Missing string query parameter redirectUrl or session"
|
||||||
|
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
await self._cas_handler.handle_ticket(
|
||||||
|
request, ticket, client_redirect_url, session
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -145,7 +145,7 @@ class AuthRestServlet(RestServlet):
|
||||||
# Generate a request to CAS that redirects back to an endpoint
|
# Generate a request to CAS that redirects back to an endpoint
|
||||||
# to verify the successful authentication.
|
# to verify the successful authentication.
|
||||||
sso_redirect_url = self._cas_handler.get_redirect_url(
|
sso_redirect_url = self._cas_handler.get_redirect_url(
|
||||||
"/_matrix/client/r0/auth/cas/ticket", session=session,
|
{"session": session},
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self._saml_enabled:
|
elif self._saml_enabled:
|
||||||
|
@ -239,35 +239,5 @@ class AuthRestServlet(RestServlet):
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
class CasAuthTicketServlet(RestServlet):
|
|
||||||
"""
|
|
||||||
Completes a user interactive authentication session when using CAS.
|
|
||||||
|
|
||||||
It is called after the user has completed SSO with the CAS provider and
|
|
||||||
received a ticket in response. It does the following:
|
|
||||||
|
|
||||||
* Retrieves the CAS ticket and the UI auth session from the request.
|
|
||||||
* Validates the CAS ticket.
|
|
||||||
* Marks the UI auth session as complete.
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATTERNS = client_patterns(r"/auth/cas/ticket")
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
|
||||||
super(CasAuthTicketServlet, self).__init__()
|
|
||||||
self._cas_handler = hs.get_cas_handler()
|
|
||||||
|
|
||||||
async def on_GET(self, request):
|
|
||||||
ticket = parse_string(request, "ticket", required=True)
|
|
||||||
# Pull the UI Auth session ID out.
|
|
||||||
session_id = parse_string(request, "session", required=True)
|
|
||||||
|
|
||||||
return await self._cas_handler.handle_ticket_for_ui_auth(
|
|
||||||
request, ticket, session_id
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
AuthRestServlet(hs).register(http_server)
|
AuthRestServlet(hs).register(http_server)
|
||||||
if hs.config.cas_enabled:
|
|
||||||
CasAuthTicketServlet(hs).register(http_server)
|
|
||||||
|
|
Loading…
Reference in New Issue