Abstract rendering errors.

pull/8765/head
Patrick Cloke 2020-11-16 13:40:10 -05:00
parent e40bcf8e77
commit 5ad5b73e2a
5 changed files with 92 additions and 72 deletions

View File

@ -35,7 +35,7 @@ from twisted.web.client import readBody
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@ -84,10 +84,6 @@ class OidcError(Exception):
return self.error return self.error
class MappingException(Exception):
"""Used to catch errors when mapping the SAML2 response to a user."""
class OidcHandler(BaseHandler): class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow. """Handles requests related to the OpenID Connect login flow.
""" """
@ -122,31 +118,11 @@ class OidcHandler(BaseHandler):
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._server_name = hs.config.server_name # type: str self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = hs.config.sso_error_template
# identifier for the external_ids table # identifier for the external_ids table
self._auth_provider_id = "oidc" self._auth_provider_id = "oidc"
def _render_error( self._sso_handler = hs.get_sso_handler()
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error. Those include
well-known OAuth2/OIDC error types like invalid_request or
access_denied.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
def _validate_metadata(self): def _validate_metadata(self):
"""Verifies the provider metadata. """Verifies the provider metadata.
@ -568,7 +544,7 @@ class OidcHandler(BaseHandler):
Since we might want to display OIDC-related errors in a user-friendly Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call way, we don't raise SynapseError from here. Instead, we call
``self._render_error`` which displays an HTML page for the error. ``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here: Most of the OpenID Connect logic happens here:
@ -606,7 +582,7 @@ class OidcHandler(BaseHandler):
if error != "access_denied": if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description) logger.error("Error from the OIDC provider: %s %s", error, description)
self._render_error(request, error, description) self._sso_handler.render_error(request, error, description)
return return
# otherwise, it is presumably a successful response. see: # otherwise, it is presumably a successful response. see:
@ -616,7 +592,9 @@ class OidcHandler(BaseHandler):
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None: if session is None:
logger.info("No session cookie found") logger.info("No session cookie found")
self._render_error(request, "missing_session", "No session cookie found") self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return return
# Remove the cookie. There is a good chance that if the callback failed # Remove the cookie. There is a good chance that if the callback failed
@ -634,7 +612,9 @@ class OidcHandler(BaseHandler):
# Check for the state query parameter # Check for the state query parameter
if b"state" not in request.args: if b"state" not in request.args:
logger.info("State parameter is missing") logger.info("State parameter is missing")
self._render_error(request, "invalid_request", "State parameter is missing") self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return return
state = request.args[b"state"][0].decode() state = request.args[b"state"][0].decode()
@ -648,17 +628,19 @@ class OidcHandler(BaseHandler):
) = self._verify_oidc_session_token(session, state) ) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e: except MacaroonDeserializationException as e:
logger.exception("Invalid session") logger.exception("Invalid session")
self._render_error(request, "invalid_session", str(e)) self._sso_handler.render_error(request, "invalid_session", str(e))
return return
except MacaroonInvalidSignatureException as e: except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session") logger.exception("Could not verify session")
self._render_error(request, "mismatching_session", str(e)) self._sso_handler.render_error(request, "mismatching_session", str(e))
return return
# Exchange the code with the provider # Exchange the code with the provider
if b"code" not in request.args: if b"code" not in request.args:
logger.info("Code parameter is missing") logger.info("Code parameter is missing")
self._render_error(request, "invalid_request", "Code parameter is missing") self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return return
logger.debug("Exchanging code") logger.debug("Exchanging code")
@ -667,7 +649,7 @@ class OidcHandler(BaseHandler):
token = await self._exchange_code(code) token = await self._exchange_code(code)
except OidcError as e: except OidcError as e:
logger.exception("Could not exchange code") logger.exception("Could not exchange code")
self._render_error(request, e.error, e.error_description) self._sso_handler.render_error(request, e.error, e.error_description)
return return
logger.debug("Successfully obtained OAuth2 access token") logger.debug("Successfully obtained OAuth2 access token")
@ -680,7 +662,7 @@ class OidcHandler(BaseHandler):
userinfo = await self._fetch_userinfo(token) userinfo = await self._fetch_userinfo(token)
except Exception as e: except Exception as e:
logger.exception("Could not fetch userinfo") logger.exception("Could not fetch userinfo")
self._render_error(request, "fetch_error", str(e)) self._sso_handler.render_error(request, "fetch_error", str(e))
return return
else: else:
logger.debug("Extracting userinfo from id_token") logger.debug("Extracting userinfo from id_token")
@ -688,7 +670,7 @@ class OidcHandler(BaseHandler):
userinfo = await self._parse_id_token(token, nonce=nonce) userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e: except Exception as e:
logger.exception("Invalid id_token") logger.exception("Invalid id_token")
self._render_error(request, "invalid_token", str(e)) self._sso_handler.render_error(request, "invalid_token", str(e))
return return
# Pull out the user-agent and IP from the request. # Pull out the user-agent and IP from the request.
@ -702,7 +684,7 @@ class OidcHandler(BaseHandler):
) )
except MappingException as e: except MappingException as e:
logger.exception("Could not map user") logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e)) self._sso_handler.render_error(request, "mapping_error", str(e))
return return
# Mapping providers might not have get_extra_attributes: only call this # Mapping providers might not have get_extra_attributes: only call this

View File

@ -25,7 +25,7 @@ from synapse.api.errors import SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html from synapse.handlers.sso import MappingException
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
@ -43,10 +43,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MappingException(Exception):
"""Used to catch errors when mapping the SAML2 response to a user."""
@attr.s(slots=True) @attr.s(slots=True)
class Saml2SessionData: class Saml2SessionData:
"""Data we track about SAML2 sessions""" """Data we track about SAML2 sessions"""
@ -87,24 +83,7 @@ class SamlHandler(BaseHandler):
# a lock on the mappings # a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock) self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
def _render_error( self._sso_handler = hs.get_sso_handler()
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
def handle_redirect_request( def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@ -168,12 +147,12 @@ class SamlHandler(BaseHandler):
# in the (user-visible) exception message, so let's log the exception here # in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later. # so we can track down the session IDs later.
logger.warning(str(e)) logger.warning(str(e))
self._render_error( self._sso_handler.render_error(
request, "unsolicited_response", "Unexpected SAML2 login." request, "unsolicited_response", "Unexpected SAML2 login."
) )
return return
except Exception as e: except Exception as e:
self._render_error( self._sso_handler.render_error(
request, request,
"invalid_response", "invalid_response",
"Unable to parse SAML2 response: %s." % (e,), "Unable to parse SAML2 response: %s." % (e,),
@ -181,7 +160,7 @@ class SamlHandler(BaseHandler):
return return
if saml2_auth.not_signed: if saml2_auth.not_signed:
self._render_error( self._sso_handler.render_error(
request, "unsigned_respond", "SAML2 response was not signed." request, "unsigned_respond", "SAML2 response was not signed."
) )
return return
@ -207,7 +186,7 @@ class SamlHandler(BaseHandler):
# attributes. # attributes.
for requirement in self._saml2_attribute_requirements: for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement): if not _check_attribute_requirement(saml2_auth.ava, requirement):
self._render_error( self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here." request, "unauthorised", "You are not authorised to log in here."
) )
return return
@ -223,7 +202,7 @@ class SamlHandler(BaseHandler):
) )
except MappingException as e: except MappingException as e:
logger.exception("Could not map user") logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e)) self._sso_handler.render_error(request, "mapping_error", str(e))
return return
# Complete the interactive auth session or the login. # Complete the interactive auth session or the login.

54
synapse/handlers/sso.py Normal file
View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object
"""
class SsoHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._error_template = hs.config.sso_error_template
def render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Renders the error template and respond with it.
This is used to show errors to the user. The template of this page can
be found under ``synapse/res/templates/sso_error.html``.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)

View File

@ -89,6 +89,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.search import SearchHandler from synapse.handlers.search import SearchHandler
from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
@ -390,6 +391,10 @@ class HomeServer(metaclass=abc.ABCMeta):
else: else:
return FollowerTypingHandler(self) return FollowerTypingHandler(self)
@cache_in_self
def get_sso_handler(self):
return SsoHandler(self)
@cache_in_self @cache_in_self
def get_sync_handler(self) -> SyncHandler: def get_sync_handler(self) -> SyncHandler:
return SyncHandler(self) return SyncHandler(self)

View File

@ -154,6 +154,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
self.handler = OidcHandler(hs) self.handler = OidcHandler(hs)
# Mock the render error method.
self.render_error = Mock(return_value=None)
self.handler._sso_handler.render_error = self.render_error
return hs return hs
@ -161,12 +164,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
return patch.dict(self.handler._provider_metadata, values) return patch.dict(self.handler._provider_metadata, values)
def assertRenderedError(self, error, error_description=None): def assertRenderedError(self, error, error_description=None):
args = self.handler._render_error.call_args[0] args = self.render_error.call_args[0]
self.assertEqual(args[1], error) self.assertEqual(args[1], error)
if error_description is not None: if error_description is not None:
self.assertEqual(args[2], error_description) self.assertEqual(args[2], error_description)
# Reset the render_error mock # Reset the render_error mock
self.handler._render_error.reset_mock() self.render_error.reset_mock()
def test_config(self): def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly.""" """Basic config correctly sets up the callback URL and client auth correctly."""
@ -356,7 +359,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_error(self): def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed.""" """Errors from the provider returned in the callback are displayed."""
self.handler._render_error = Mock()
request = Mock(args={}) request = Mock(args={})
request.args[b"error"] = [b"invalid_client"] request.args[b"error"] = [b"invalid_client"]
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
@ -387,7 +389,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
"preferred_username": "bar", "preferred_username": "bar",
} }
user_id = "@foo:domain.org" user_id = "@foo:domain.org"
self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo) self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
@ -435,7 +436,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address userinfo, token, user_agent, ip_address
) )
self.handler._fetch_userinfo.assert_not_called() self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called() self.render_error.assert_not_called()
# Handle mapping errors # Handle mapping errors
self.handler._map_userinfo_to_user = simple_async_mock( self.handler._map_userinfo_to_user = simple_async_mock(
@ -469,7 +470,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address userinfo, token, user_agent, ip_address
) )
self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called() self.render_error.assert_not_called()
# Handle userinfo fetching error # Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
@ -485,7 +486,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_session(self): def test_callback_session(self):
"""The callback verifies the session presence and validity""" """The callback verifies the session presence and validity"""
self.handler._render_error = Mock(return_value=None)
request = Mock(spec=["args", "getCookie", "addCookie"]) request = Mock(spec=["args", "getCookie", "addCookie"])
# Missing cookie # Missing cookie