Refactor the CAS handler in prep for using the abstracted SSO code. (#8958)
This makes the CAS handler look more like the SAML/OIDC handlers: * Render errors to users instead of throwing JSON errors. * Internal reorganization.pull/8856/head
							parent
							
								
									56e00ca85e
								
							
						
					
					
						commit
						4218473f9e
					
				| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Properly store the mapping of external ID to Matrix ID for CAS users.
 | 
			
		||||
| 
						 | 
				
			
			@ -31,7 +31,7 @@ easy to run CAS implementation built on top of Django.
 | 
			
		|||
You should now have a Django project configured to serve CAS authentication with
 | 
			
		||||
a single user created.
 | 
			
		||||
 | 
			
		||||
## Configure Synapse (and Riot) to use CAS
 | 
			
		||||
## Configure Synapse (and Element) to use CAS
 | 
			
		||||
 | 
			
		||||
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
 | 
			
		||||
   running Django test server:
 | 
			
		||||
| 
						 | 
				
			
			@ -51,9 +51,9 @@ and that the CAS server is on port 8000, both on localhost.
 | 
			
		|||
 | 
			
		||||
## Testing the configuration
 | 
			
		||||
 | 
			
		||||
Then in Riot:
 | 
			
		||||
Then in Element:
 | 
			
		||||
 | 
			
		||||
1. Visit the login page with a Riot pointing at your homeserver.
 | 
			
		||||
1. Visit the login page with a Element pointing at your homeserver.
 | 
			
		||||
2. Click the Single Sign-On button.
 | 
			
		||||
3. Login using the credentials created with `createsuperuser`.
 | 
			
		||||
4. You should be logged in.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,13 +13,15 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import logging
 | 
			
		||||
import urllib
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
 | 
			
		||||
import urllib.parse
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, Optional
 | 
			
		||||
from xml.etree import ElementTree as ET
 | 
			
		||||
 | 
			
		||||
import attr
 | 
			
		||||
 | 
			
		||||
from twisted.web.client import PartialDownloadError
 | 
			
		||||
 | 
			
		||||
from synapse.api.errors import Codes, LoginError
 | 
			
		||||
from synapse.api.errors import HttpResponseException
 | 
			
		||||
from synapse.http.site import SynapseRequest
 | 
			
		||||
from synapse.types import UserID, map_username_to_mxid_localpart
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -29,6 +31,26 @@ if TYPE_CHECKING:
 | 
			
		|||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CasError(Exception):
 | 
			
		||||
    """Used to catch errors when validating the CAS ticket.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, error, error_description=None):
 | 
			
		||||
        self.error = error
 | 
			
		||||
        self.error_description = error_description
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        if self.error_description:
 | 
			
		||||
            return "{}: {}".format(self.error, self.error_description)
 | 
			
		||||
        return self.error
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@attr.s(slots=True, frozen=True)
 | 
			
		||||
class CasResponse:
 | 
			
		||||
    username = attr.ib(type=str)
 | 
			
		||||
    attributes = attr.ib(type=Dict[str, Optional[str]])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CasHandler:
 | 
			
		||||
    """
 | 
			
		||||
    Utility class for to handle the response from a CAS SSO service.
 | 
			
		||||
| 
						 | 
				
			
			@ -50,6 +72,8 @@ class CasHandler:
 | 
			
		|||
 | 
			
		||||
        self._http_client = hs.get_proxied_http_client()
 | 
			
		||||
 | 
			
		||||
        self._sso_handler = hs.get_sso_handler()
 | 
			
		||||
 | 
			
		||||
    def _build_service_param(self, args: Dict[str, str]) -> str:
 | 
			
		||||
        """
 | 
			
		||||
        Generates a value to use as the "service" parameter when redirecting or
 | 
			
		||||
| 
						 | 
				
			
			@ -69,14 +93,20 @@ class CasHandler:
 | 
			
		|||
 | 
			
		||||
    async def _validate_ticket(
 | 
			
		||||
        self, ticket: str, service_args: Dict[str, str]
 | 
			
		||||
    ) -> Tuple[str, Optional[str]]:
 | 
			
		||||
    ) -> CasResponse:
 | 
			
		||||
        """
 | 
			
		||||
        Validate a CAS ticket with the server, parse the response, and return the user and display name.
 | 
			
		||||
        Validate a CAS ticket with the server, and return the parsed the response.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            ticket: The CAS ticket from the client.
 | 
			
		||||
            service_args: Additional arguments to include in the service URL.
 | 
			
		||||
                Should be the same as those passed to `get_redirect_url`.
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            CasError: If there's an error parsing the CAS response.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            The parsed CAS response.
 | 
			
		||||
        """
 | 
			
		||||
        uri = self._cas_server_url + "/proxyValidate"
 | 
			
		||||
        args = {
 | 
			
		||||
| 
						 | 
				
			
			@ -89,43 +119,46 @@ class CasHandler:
 | 
			
		|||
            # Twisted raises this error if the connection is closed,
 | 
			
		||||
            # even if that's being used old-http style to signal end-of-data
 | 
			
		||||
            body = pde.response
 | 
			
		||||
        except HttpResponseException as e:
 | 
			
		||||
            description = (
 | 
			
		||||
                (
 | 
			
		||||
                    'Authorization server responded with a "{status}" error '
 | 
			
		||||
                    "while exchanging the authorization code."
 | 
			
		||||
                ).format(status=e.code),
 | 
			
		||||
            )
 | 
			
		||||
            raise CasError("server_error", description) from e
 | 
			
		||||
 | 
			
		||||
        user, attributes = self._parse_cas_response(body)
 | 
			
		||||
        displayname = attributes.pop(self._cas_displayname_attribute, None)
 | 
			
		||||
        return self._parse_cas_response(body)
 | 
			
		||||
 | 
			
		||||
        for required_attribute, required_value in self._cas_required_attributes.items():
 | 
			
		||||
            # If required attribute was not in CAS Response - Forbidden
 | 
			
		||||
            if required_attribute not in attributes:
 | 
			
		||||
                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 | 
			
		||||
 | 
			
		||||
            # Also need to check value
 | 
			
		||||
            if required_value is not None:
 | 
			
		||||
                actual_value = attributes[required_attribute]
 | 
			
		||||
                # If required attribute value does not match expected - Forbidden
 | 
			
		||||
                if required_value != actual_value:
 | 
			
		||||
                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 | 
			
		||||
 | 
			
		||||
        return user, displayname
 | 
			
		||||
 | 
			
		||||
    def _parse_cas_response(
 | 
			
		||||
        self, cas_response_body: bytes
 | 
			
		||||
    ) -> Tuple[str, Dict[str, Optional[str]]]:
 | 
			
		||||
    def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
 | 
			
		||||
        """
 | 
			
		||||
        Retrieve the user and other parameters from the CAS response.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            cas_response_body: The response from the CAS query.
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            CasError: If there's an error parsing the CAS response.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A tuple of the user and a mapping of other attributes.
 | 
			
		||||
            The parsed CAS response.
 | 
			
		||||
        """
 | 
			
		||||
        user = None
 | 
			
		||||
        attributes = {}
 | 
			
		||||
        try:
 | 
			
		||||
 | 
			
		||||
        # Ensure the response is valid.
 | 
			
		||||
        root = ET.fromstring(cas_response_body)
 | 
			
		||||
        if not root.tag.endswith("serviceResponse"):
 | 
			
		||||
                raise Exception("root of CAS response is not serviceResponse")
 | 
			
		||||
            raise CasError(
 | 
			
		||||
                "missing_service_response",
 | 
			
		||||
                "root of CAS response is not serviceResponse",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        success = root[0].tag.endswith("authenticationSuccess")
 | 
			
		||||
        if not success:
 | 
			
		||||
            raise CasError("unsucessful_response", "Unsuccessful CAS response")
 | 
			
		||||
 | 
			
		||||
        # Iterate through the nodes and pull out the user and any extra attributes.
 | 
			
		||||
        user = None
 | 
			
		||||
        attributes = {}
 | 
			
		||||
        for child in root[0]:
 | 
			
		||||
            if child.tag.endswith("user"):
 | 
			
		||||
                user = child.text
 | 
			
		||||
| 
						 | 
				
			
			@ -139,16 +172,12 @@ class CasHandler:
 | 
			
		|||
                    if "}" in tag:
 | 
			
		||||
                        tag = tag.split("}")[1]
 | 
			
		||||
                    attributes[tag] = attribute.text
 | 
			
		||||
 | 
			
		||||
        # Ensure a user was found.
 | 
			
		||||
        if user is None:
 | 
			
		||||
                raise Exception("CAS response does not contain user")
 | 
			
		||||
        except Exception:
 | 
			
		||||
            logger.exception("Error parsing CAS response")
 | 
			
		||||
            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
 | 
			
		||||
        if not success:
 | 
			
		||||
            raise LoginError(
 | 
			
		||||
                401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
 | 
			
		||||
            )
 | 
			
		||||
        return user, attributes
 | 
			
		||||
            raise CasError("no_user", "CAS response does not contain user")
 | 
			
		||||
 | 
			
		||||
        return CasResponse(user, attributes)
 | 
			
		||||
 | 
			
		||||
    def get_redirect_url(self, service_args: Dict[str, str]) -> str:
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -201,7 +230,68 @@ class CasHandler:
 | 
			
		|||
            args["redirectUrl"] = client_redirect_url
 | 
			
		||||
        if session:
 | 
			
		||||
            args["session"] = session
 | 
			
		||||
        username, user_display_name = await self._validate_ticket(ticket, args)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            cas_response = await self._validate_ticket(ticket, args)
 | 
			
		||||
        except CasError as e:
 | 
			
		||||
            logger.exception("Could not validate ticket")
 | 
			
		||||
            self._sso_handler.render_error(request, e.error, e.error_description, 401)
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        await self._handle_cas_response(
 | 
			
		||||
            request, cas_response, client_redirect_url, session
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def _handle_cas_response(
 | 
			
		||||
        self,
 | 
			
		||||
        request: SynapseRequest,
 | 
			
		||||
        cas_response: CasResponse,
 | 
			
		||||
        client_redirect_url: Optional[str],
 | 
			
		||||
        session: Optional[str],
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Handle a CAS response to a ticket request.
 | 
			
		||||
 | 
			
		||||
        Assumes that the response has been validated. Maps the user onto an MXID,
 | 
			
		||||
        registering them if necessary, and returns a response to the browser.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            request: the incoming request from the browser. We'll respond to it with an
 | 
			
		||||
                HTML page or a redirect
 | 
			
		||||
 | 
			
		||||
            cas_response: The parsed CAS response.
 | 
			
		||||
 | 
			
		||||
            client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
 | 
			
		||||
                This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
 | 
			
		||||
 | 
			
		||||
            session: The session parameter from the `/cas/ticket` HTTP request, if given.
 | 
			
		||||
                This should be the UI Auth session id.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # Ensure that the attributes of the logged in user meet the required
 | 
			
		||||
        # attributes.
 | 
			
		||||
        for required_attribute, required_value in self._cas_required_attributes.items():
 | 
			
		||||
            # If required attribute was not in CAS Response - Forbidden
 | 
			
		||||
            if required_attribute not in cas_response.attributes:
 | 
			
		||||
                self._sso_handler.render_error(
 | 
			
		||||
                    request,
 | 
			
		||||
                    "unauthorised",
 | 
			
		||||
                    "You are not authorised to log in here.",
 | 
			
		||||
                    401,
 | 
			
		||||
                )
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            # Also need to check value
 | 
			
		||||
            if required_value is not None:
 | 
			
		||||
                actual_value = cas_response.attributes[required_attribute]
 | 
			
		||||
                # If required attribute value does not match expected - Forbidden
 | 
			
		||||
                if required_value != actual_value:
 | 
			
		||||
                    self._sso_handler.render_error(
 | 
			
		||||
                        request,
 | 
			
		||||
                        "unauthorised",
 | 
			
		||||
                        "You are not authorised to log in here.",
 | 
			
		||||
                        401,
 | 
			
		||||
                    )
 | 
			
		||||
                    return
 | 
			
		||||
 | 
			
		||||
        # Pull out the user-agent and IP from the request.
 | 
			
		||||
        user_agent = request.get_user_agent("")
 | 
			
		||||
| 
						 | 
				
			
			@ -209,7 +299,7 @@ class CasHandler:
 | 
			
		|||
 | 
			
		||||
        # Get the matrix ID from the CAS username.
 | 
			
		||||
        user_id = await self._map_cas_user_to_matrix_user(
 | 
			
		||||
            username, user_display_name, user_agent, ip_address
 | 
			
		||||
            cas_response, user_agent, ip_address
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if session:
 | 
			
		||||
| 
						 | 
				
			
			@ -225,18 +315,13 @@ class CasHandler:
 | 
			
		|||
            )
 | 
			
		||||
 | 
			
		||||
    async def _map_cas_user_to_matrix_user(
 | 
			
		||||
        self,
 | 
			
		||||
        remote_user_id: str,
 | 
			
		||||
        display_name: Optional[str],
 | 
			
		||||
        user_agent: str,
 | 
			
		||||
        ip_address: str,
 | 
			
		||||
        self, cas_response: CasResponse, user_agent: str, ip_address: str,
 | 
			
		||||
    ) -> str:
 | 
			
		||||
        """
 | 
			
		||||
        Given a CAS username, retrieve the user ID for it and possibly register the user.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            remote_user_id: The username from the CAS response.
 | 
			
		||||
            display_name: The display name from the CAS response.
 | 
			
		||||
            cas_response: The parsed CAS response.
 | 
			
		||||
            user_agent: The user agent of the client making the request.
 | 
			
		||||
            ip_address: The IP address of the client making the request.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -244,15 +329,17 @@ class CasHandler:
 | 
			
		|||
             The user ID associated with this response.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        localpart = map_username_to_mxid_localpart(remote_user_id)
 | 
			
		||||
        localpart = map_username_to_mxid_localpart(cas_response.username)
 | 
			
		||||
        user_id = UserID(localpart, self._hostname).to_string()
 | 
			
		||||
        registered_user_id = await self._auth_handler.check_user_exists(user_id)
 | 
			
		||||
 | 
			
		||||
        displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
 | 
			
		||||
 | 
			
		||||
        # If the user does not exist, register it.
 | 
			
		||||
        if not registered_user_id:
 | 
			
		||||
            registered_user_id = await self._registration_handler.register_user(
 | 
			
		||||
                localpart=localpart,
 | 
			
		||||
                default_display_name=display_name,
 | 
			
		||||
                default_display_name=displayname,
 | 
			
		||||
                user_agent_ips=[(user_agent, ip_address)],
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -101,7 +101,11 @@ class SsoHandler:
 | 
			
		|||
        self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
 | 
			
		||||
 | 
			
		||||
    def render_error(
 | 
			
		||||
        self, request, error: str, error_description: Optional[str] = None
 | 
			
		||||
        self,
 | 
			
		||||
        request: Request,
 | 
			
		||||
        error: str,
 | 
			
		||||
        error_description: Optional[str] = None,
 | 
			
		||||
        code: int = 400,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Renders the error template and responds with it.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -113,11 +117,12 @@ class SsoHandler:
 | 
			
		|||
                We'll respond with an HTML page describing the error.
 | 
			
		||||
            error: A technical identifier for this error.
 | 
			
		||||
            error_description: A human-readable description of the error.
 | 
			
		||||
            code: The integer error code (an HTTP response code)
 | 
			
		||||
        """
 | 
			
		||||
        html = self._error_template.render(
 | 
			
		||||
            error=error, error_description=error_description
 | 
			
		||||
        )
 | 
			
		||||
        respond_with_html(request, 400, html)
 | 
			
		||||
        respond_with_html(request, code, html)
 | 
			
		||||
 | 
			
		||||
    async def get_sso_user_by_remote_user_id(
 | 
			
		||||
        self, auth_provider_id: str, remote_user_id: str
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue