Abstract shared SSO code. (#8765)
De-duplicates code between the SAML and OIDC implementations.pull/8777/head
parent
e487d9fabc
commit
ee382025b0
|
@ -0,0 +1 @@
|
|||
Consolidate logic between the OpenID Connect and SAML code.
|
|
@ -34,7 +34,8 @@ from typing_extensions import TypedDict
|
|||
from twisted.web.client import readBody
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.http.server import respond_with_html
|
||||
from synapse.handlers._base import BaseHandler
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||
|
@ -83,17 +84,12 @@ class OidcError(Exception):
|
|||
return self.error
|
||||
|
||||
|
||||
class MappingException(Exception):
|
||||
"""Used to catch errors when mapping the UserInfo object
|
||||
"""
|
||||
|
||||
|
||||
class OidcHandler:
|
||||
class OidcHandler(BaseHandler):
|
||||
"""Handles requests related to the OpenID Connect login flow.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
super().__init__(hs)
|
||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||
self._scopes = hs.config.oidc_scopes # type: List[str]
|
||||
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
|
||||
|
@ -120,36 +116,13 @@ class OidcHandler:
|
|||
self._http_client = hs.get_proxied_http_client()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
self._datastore = hs.get_datastore()
|
||||
self._clock = hs.get_clock()
|
||||
self._hostname = hs.hostname # type: str
|
||||
self._server_name = hs.config.server_name # type: str
|
||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||
self._error_template = hs.config.sso_error_template
|
||||
|
||||
# identifier for the external_ids table
|
||||
self._auth_provider_id = "oidc"
|
||||
|
||||
def _render_error(
|
||||
self, request, error: str, error_description: Optional[str] = None
|
||||
) -> None:
|
||||
"""Render the error template and respond to the request with it.
|
||||
|
||||
This is used to show errors to the user. The template of this page can
|
||||
be found under `synapse/res/templates/sso_error.html`.
|
||||
|
||||
Args:
|
||||
request: The incoming request from the browser.
|
||||
We'll respond with an HTML page describing the error.
|
||||
error: A technical identifier for this error. Those include
|
||||
well-known OAuth2/OIDC error types like invalid_request or
|
||||
access_denied.
|
||||
error_description: A human-readable description of the error.
|
||||
"""
|
||||
html = self._error_template.render(
|
||||
error=error, error_description=error_description
|
||||
)
|
||||
respond_with_html(request, 400, html)
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
|
||||
def _validate_metadata(self):
|
||||
"""Verifies the provider metadata.
|
||||
|
@ -571,7 +544,7 @@ class OidcHandler:
|
|||
|
||||
Since we might want to display OIDC-related errors in a user-friendly
|
||||
way, we don't raise SynapseError from here. Instead, we call
|
||||
``self._render_error`` which displays an HTML page for the error.
|
||||
``self._sso_handler.render_error`` which displays an HTML page for the error.
|
||||
|
||||
Most of the OpenID Connect logic happens here:
|
||||
|
||||
|
@ -609,7 +582,7 @@ class OidcHandler:
|
|||
if error != "access_denied":
|
||||
logger.error("Error from the OIDC provider: %s %s", error, description)
|
||||
|
||||
self._render_error(request, error, description)
|
||||
self._sso_handler.render_error(request, error, description)
|
||||
return
|
||||
|
||||
# otherwise, it is presumably a successful response. see:
|
||||
|
@ -619,7 +592,9 @@ class OidcHandler:
|
|||
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
|
||||
if session is None:
|
||||
logger.info("No session cookie found")
|
||||
self._render_error(request, "missing_session", "No session cookie found")
|
||||
self._sso_handler.render_error(
|
||||
request, "missing_session", "No session cookie found"
|
||||
)
|
||||
return
|
||||
|
||||
# Remove the cookie. There is a good chance that if the callback failed
|
||||
|
@ -637,7 +612,9 @@ class OidcHandler:
|
|||
# Check for the state query parameter
|
||||
if b"state" not in request.args:
|
||||
logger.info("State parameter is missing")
|
||||
self._render_error(request, "invalid_request", "State parameter is missing")
|
||||
self._sso_handler.render_error(
|
||||
request, "invalid_request", "State parameter is missing"
|
||||
)
|
||||
return
|
||||
|
||||
state = request.args[b"state"][0].decode()
|
||||
|
@ -651,17 +628,19 @@ class OidcHandler:
|
|||
) = self._verify_oidc_session_token(session, state)
|
||||
except MacaroonDeserializationException as e:
|
||||
logger.exception("Invalid session")
|
||||
self._render_error(request, "invalid_session", str(e))
|
||||
self._sso_handler.render_error(request, "invalid_session", str(e))
|
||||
return
|
||||
except MacaroonInvalidSignatureException as e:
|
||||
logger.exception("Could not verify session")
|
||||
self._render_error(request, "mismatching_session", str(e))
|
||||
self._sso_handler.render_error(request, "mismatching_session", str(e))
|
||||
return
|
||||
|
||||
# Exchange the code with the provider
|
||||
if b"code" not in request.args:
|
||||
logger.info("Code parameter is missing")
|
||||
self._render_error(request, "invalid_request", "Code parameter is missing")
|
||||
self._sso_handler.render_error(
|
||||
request, "invalid_request", "Code parameter is missing"
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug("Exchanging code")
|
||||
|
@ -670,7 +649,7 @@ class OidcHandler:
|
|||
token = await self._exchange_code(code)
|
||||
except OidcError as e:
|
||||
logger.exception("Could not exchange code")
|
||||
self._render_error(request, e.error, e.error_description)
|
||||
self._sso_handler.render_error(request, e.error, e.error_description)
|
||||
return
|
||||
|
||||
logger.debug("Successfully obtained OAuth2 access token")
|
||||
|
@ -683,7 +662,7 @@ class OidcHandler:
|
|||
userinfo = await self._fetch_userinfo(token)
|
||||
except Exception as e:
|
||||
logger.exception("Could not fetch userinfo")
|
||||
self._render_error(request, "fetch_error", str(e))
|
||||
self._sso_handler.render_error(request, "fetch_error", str(e))
|
||||
return
|
||||
else:
|
||||
logger.debug("Extracting userinfo from id_token")
|
||||
|
@ -691,7 +670,7 @@ class OidcHandler:
|
|||
userinfo = await self._parse_id_token(token, nonce=nonce)
|
||||
except Exception as e:
|
||||
logger.exception("Invalid id_token")
|
||||
self._render_error(request, "invalid_token", str(e))
|
||||
self._sso_handler.render_error(request, "invalid_token", str(e))
|
||||
return
|
||||
|
||||
# Pull out the user-agent and IP from the request.
|
||||
|
@ -705,7 +684,7 @@ class OidcHandler:
|
|||
)
|
||||
except MappingException as e:
|
||||
logger.exception("Could not map user")
|
||||
self._render_error(request, "mapping_error", str(e))
|
||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||
return
|
||||
|
||||
# Mapping providers might not have get_extra_attributes: only call this
|
||||
|
@ -770,7 +749,7 @@ class OidcHandler:
|
|||
macaroon.add_first_party_caveat(
|
||||
"ui_auth_session_id = %s" % (ui_auth_session_id,)
|
||||
)
|
||||
now = self._clock.time_msec()
|
||||
now = self.clock.time_msec()
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
|
||||
|
@ -845,7 +824,7 @@ class OidcHandler:
|
|||
if not caveat.startswith(prefix):
|
||||
return False
|
||||
expiry = int(caveat[len(prefix) :])
|
||||
now = self._clock.time_msec()
|
||||
now = self.clock.time_msec()
|
||||
return now < expiry
|
||||
|
||||
async def _map_userinfo_to_user(
|
||||
|
@ -885,20 +864,14 @@ class OidcHandler:
|
|||
# to be strings.
|
||||
remote_user_id = str(remote_user_id)
|
||||
|
||||
logger.info(
|
||||
"Looking for existing mapping for user %s:%s",
|
||||
self._auth_provider_id,
|
||||
remote_user_id,
|
||||
)
|
||||
|
||||
registered_user_id = await self._datastore.get_user_by_external_id(
|
||||
# first of all, check if we already have a mapping for this user
|
||||
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
|
||||
self._auth_provider_id, remote_user_id,
|
||||
)
|
||||
if previously_registered_user_id:
|
||||
return previously_registered_user_id
|
||||
|
||||
if registered_user_id is not None:
|
||||
logger.info("Found existing mapping %s", registered_user_id)
|
||||
return registered_user_id
|
||||
|
||||
# Otherwise, generate a new user.
|
||||
try:
|
||||
attributes = await self._user_mapping_provider.map_user_attributes(
|
||||
userinfo, token
|
||||
|
@ -917,8 +890,8 @@ class OidcHandler:
|
|||
|
||||
localpart = map_username_to_mxid_localpart(attributes["localpart"])
|
||||
|
||||
user_id = UserID(localpart, self._hostname).to_string()
|
||||
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
|
||||
user_id = UserID(localpart, self.server_name).to_string()
|
||||
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||
if users:
|
||||
if self._allow_existing_users:
|
||||
if len(users) == 1:
|
||||
|
@ -942,7 +915,8 @@ class OidcHandler:
|
|||
default_display_name=attributes["display_name"],
|
||||
user_agent_ips=(user_agent, ip_address),
|
||||
)
|
||||
await self._datastore.record_user_external_id(
|
||||
|
||||
await self.store.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id,
|
||||
)
|
||||
return registered_user_id
|
||||
|
|
|
@ -24,7 +24,8 @@ from saml2.client import Saml2Client
|
|||
from synapse.api.errors import SynapseError
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.saml2_config import SamlAttributeRequirement
|
||||
from synapse.http.server import respond_with_html
|
||||
from synapse.handlers._base import BaseHandler
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.module_api import ModuleApi
|
||||
|
@ -42,10 +43,6 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MappingException(Exception):
|
||||
"""Used to catch errors when mapping the SAML2 response to a user."""
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class Saml2SessionData:
|
||||
"""Data we track about SAML2 sessions"""
|
||||
|
@ -57,17 +54,13 @@ class Saml2SessionData:
|
|||
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
|
||||
|
||||
|
||||
class SamlHandler:
|
||||
class SamlHandler(BaseHandler):
|
||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||
self.hs = hs
|
||||
super().__init__(hs)
|
||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||
self._auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
self._datastore = hs.get_datastore()
|
||||
self._hostname = hs.hostname
|
||||
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
|
||||
self._grandfathered_mxid_source_attribute = (
|
||||
hs.config.saml2_grandfathered_mxid_source_attribute
|
||||
|
@ -88,26 +81,9 @@ class SamlHandler:
|
|||
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
||||
|
||||
# a lock on the mappings
|
||||
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
|
||||
self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
|
||||
|
||||
def _render_error(
|
||||
self, request, error: str, error_description: Optional[str] = None
|
||||
) -> None:
|
||||
"""Render the error template and respond to the request with it.
|
||||
|
||||
This is used to show errors to the user. The template of this page can
|
||||
be found under `synapse/res/templates/sso_error.html`.
|
||||
|
||||
Args:
|
||||
request: The incoming request from the browser.
|
||||
We'll respond with an HTML page describing the error.
|
||||
error: A technical identifier for this error.
|
||||
error_description: A human-readable description of the error.
|
||||
"""
|
||||
html = self._error_template.render(
|
||||
error=error, error_description=error_description
|
||||
)
|
||||
respond_with_html(request, 400, html)
|
||||
self._sso_handler = hs.get_sso_handler()
|
||||
|
||||
def handle_redirect_request(
|
||||
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
|
||||
|
@ -130,7 +106,7 @@ class SamlHandler:
|
|||
# Since SAML sessions timeout it is useful to log when they were created.
|
||||
logger.info("Initiating a new SAML session: %s" % (reqid,))
|
||||
|
||||
now = self._clock.time_msec()
|
||||
now = self.clock.time_msec()
|
||||
self._outstanding_requests_dict[reqid] = Saml2SessionData(
|
||||
creation_time=now, ui_auth_session_id=ui_auth_session_id,
|
||||
)
|
||||
|
@ -171,12 +147,12 @@ class SamlHandler:
|
|||
# in the (user-visible) exception message, so let's log the exception here
|
||||
# so we can track down the session IDs later.
|
||||
logger.warning(str(e))
|
||||
self._render_error(
|
||||
self._sso_handler.render_error(
|
||||
request, "unsolicited_response", "Unexpected SAML2 login."
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
self._render_error(
|
||||
self._sso_handler.render_error(
|
||||
request,
|
||||
"invalid_response",
|
||||
"Unable to parse SAML2 response: %s." % (e,),
|
||||
|
@ -184,7 +160,7 @@ class SamlHandler:
|
|||
return
|
||||
|
||||
if saml2_auth.not_signed:
|
||||
self._render_error(
|
||||
self._sso_handler.render_error(
|
||||
request, "unsigned_respond", "SAML2 response was not signed."
|
||||
)
|
||||
return
|
||||
|
@ -210,7 +186,7 @@ class SamlHandler:
|
|||
# attributes.
|
||||
for requirement in self._saml2_attribute_requirements:
|
||||
if not _check_attribute_requirement(saml2_auth.ava, requirement):
|
||||
self._render_error(
|
||||
self._sso_handler.render_error(
|
||||
request, "unauthorised", "You are not authorised to log in here."
|
||||
)
|
||||
return
|
||||
|
@ -226,7 +202,7 @@ class SamlHandler:
|
|||
)
|
||||
except MappingException as e:
|
||||
logger.exception("Could not map user")
|
||||
self._render_error(request, "mapping_error", str(e))
|
||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||
return
|
||||
|
||||
# Complete the interactive auth session or the login.
|
||||
|
@ -274,17 +250,11 @@ class SamlHandler:
|
|||
|
||||
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
||||
# first of all, check if we already have a mapping for this user
|
||||
logger.info(
|
||||
"Looking for existing mapping for user %s:%s",
|
||||
self._auth_provider_id,
|
||||
remote_user_id,
|
||||
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
|
||||
self._auth_provider_id, remote_user_id,
|
||||
)
|
||||
registered_user_id = await self._datastore.get_user_by_external_id(
|
||||
self._auth_provider_id, remote_user_id
|
||||
)
|
||||
if registered_user_id is not None:
|
||||
logger.info("Found existing mapping %s", registered_user_id)
|
||||
return registered_user_id
|
||||
if previously_registered_user_id:
|
||||
return previously_registered_user_id
|
||||
|
||||
# backwards-compatibility hack: see if there is an existing user with a
|
||||
# suitable mapping from the uid
|
||||
|
@ -294,7 +264,7 @@ class SamlHandler:
|
|||
):
|
||||
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
|
||||
user_id = UserID(
|
||||
map_username_to_mxid_localpart(attrval), self._hostname
|
||||
map_username_to_mxid_localpart(attrval), self.server_name
|
||||
).to_string()
|
||||
logger.info(
|
||||
"Looking for existing account based on mapped %s %s",
|
||||
|
@ -302,11 +272,11 @@ class SamlHandler:
|
|||
user_id,
|
||||
)
|
||||
|
||||
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
|
||||
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||
if users:
|
||||
registered_user_id = list(users.keys())[0]
|
||||
logger.info("Grandfathering mapping to %s", registered_user_id)
|
||||
await self._datastore.record_user_external_id(
|
||||
await self.store.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id
|
||||
)
|
||||
return registered_user_id
|
||||
|
@ -335,8 +305,8 @@ class SamlHandler:
|
|||
emails = attribute_dict.get("emails", [])
|
||||
|
||||
# Check if this mxid already exists
|
||||
if not await self._datastore.get_users_by_id_case_insensitive(
|
||||
UserID(localpart, self._hostname).to_string()
|
||||
if not await self.store.get_users_by_id_case_insensitive(
|
||||
UserID(localpart, self.server_name).to_string()
|
||||
):
|
||||
# This mxid is free
|
||||
break
|
||||
|
@ -348,7 +318,6 @@ class SamlHandler:
|
|||
)
|
||||
|
||||
logger.info("Mapped SAML user to local part %s", localpart)
|
||||
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart,
|
||||
default_display_name=displayname,
|
||||
|
@ -356,13 +325,13 @@ class SamlHandler:
|
|||
user_agent_ips=(user_agent, ip_address),
|
||||
)
|
||||
|
||||
await self._datastore.record_user_external_id(
|
||||
await self.store.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id
|
||||
)
|
||||
return registered_user_id
|
||||
|
||||
def expire_sessions(self):
|
||||
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
|
||||
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
||||
to_expire = set()
|
||||
for reqid, data in self._outstanding_requests_dict.items():
|
||||
if data.creation_time < expire_before:
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.handlers._base import BaseHandler
|
||||
from synapse.http.server import respond_with_html
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MappingException(Exception):
|
||||
"""Used to catch errors when mapping the UserInfo object
|
||||
"""
|
||||
|
||||
|
||||
class SsoHandler(BaseHandler):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self._error_template = hs.config.sso_error_template
|
||||
|
||||
def render_error(
|
||||
self, request, error: str, error_description: Optional[str] = None
|
||||
) -> None:
|
||||
"""Renders the error template and responds with it.
|
||||
|
||||
This is used to show errors to the user. The template of this page can
|
||||
be found under `synapse/res/templates/sso_error.html`.
|
||||
|
||||
Args:
|
||||
request: The incoming request from the browser.
|
||||
We'll respond with an HTML page describing the error.
|
||||
error: A technical identifier for this error.
|
||||
error_description: A human-readable description of the error.
|
||||
"""
|
||||
html = self._error_template.render(
|
||||
error=error, error_description=error_description
|
||||
)
|
||||
respond_with_html(request, 400, html)
|
||||
|
||||
async def get_sso_user_by_remote_user_id(
|
||||
self, auth_provider_id: str, remote_user_id: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Maps the user ID of a remote IdP to a mxid for a previously seen user.
|
||||
|
||||
If the user has not been seen yet, this will return None.
|
||||
|
||||
Args:
|
||||
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||
"oidc" or "saml".
|
||||
remote_user_id: The user ID according to the remote IdP. This might
|
||||
be an e-mail address, a GUID, or some other form. It must be
|
||||
unique and immutable.
|
||||
|
||||
Returns:
|
||||
The mxid of a previously seen user.
|
||||
"""
|
||||
# Check if we already have a mapping for this user.
|
||||
logger.info(
|
||||
"Looking for existing mapping for user %s:%s",
|
||||
auth_provider_id,
|
||||
remote_user_id,
|
||||
)
|
||||
previously_registered_user_id = await self.store.get_user_by_external_id(
|
||||
auth_provider_id, remote_user_id,
|
||||
)
|
||||
|
||||
# A match was found, return the user ID.
|
||||
if previously_registered_user_id is not None:
|
||||
logger.info("Found existing mapping %s", previously_registered_user_id)
|
||||
return previously_registered_user_id
|
||||
|
||||
# No match.
|
||||
return None
|
|
@ -89,6 +89,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler
|
|||
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
|
||||
from synapse.handlers.search import SearchHandler
|
||||
from synapse.handlers.set_password import SetPasswordHandler
|
||||
from synapse.handlers.sso import SsoHandler
|
||||
from synapse.handlers.stats import StatsHandler
|
||||
from synapse.handlers.sync import SyncHandler
|
||||
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
|
||||
|
@ -390,6 +391,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
else:
|
||||
return FollowerTypingHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_sso_handler(self) -> SsoHandler:
|
||||
return SsoHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_sync_handler(self) -> SyncHandler:
|
||||
return SyncHandler(self)
|
||||
|
|
|
@ -154,6 +154,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
self.handler = OidcHandler(hs)
|
||||
# Mock the render error method.
|
||||
self.render_error = Mock(return_value=None)
|
||||
self.handler._sso_handler.render_error = self.render_error
|
||||
|
||||
return hs
|
||||
|
||||
|
@ -161,12 +164,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
return patch.dict(self.handler._provider_metadata, values)
|
||||
|
||||
def assertRenderedError(self, error, error_description=None):
|
||||
args = self.handler._render_error.call_args[0]
|
||||
args = self.render_error.call_args[0]
|
||||
self.assertEqual(args[1], error)
|
||||
if error_description is not None:
|
||||
self.assertEqual(args[2], error_description)
|
||||
# Reset the render_error mock
|
||||
self.handler._render_error.reset_mock()
|
||||
self.render_error.reset_mock()
|
||||
|
||||
def test_config(self):
|
||||
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
||||
|
@ -356,7 +359,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
|
||||
def test_callback_error(self):
|
||||
"""Errors from the provider returned in the callback are displayed."""
|
||||
self.handler._render_error = Mock()
|
||||
request = Mock(args={})
|
||||
request.args[b"error"] = [b"invalid_client"]
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
@ -387,7 +389,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
"preferred_username": "bar",
|
||||
}
|
||||
user_id = "@foo:domain.org"
|
||||
self.handler._render_error = Mock(return_value=None)
|
||||
self.handler._exchange_code = simple_async_mock(return_value=token)
|
||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
|
@ -435,7 +436,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
userinfo, token, user_agent, ip_address
|
||||
)
|
||||
self.handler._fetch_userinfo.assert_not_called()
|
||||
self.handler._render_error.assert_not_called()
|
||||
self.render_error.assert_not_called()
|
||||
|
||||
# Handle mapping errors
|
||||
self.handler._map_userinfo_to_user = simple_async_mock(
|
||||
|
@ -469,7 +470,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
userinfo, token, user_agent, ip_address
|
||||
)
|
||||
self.handler._fetch_userinfo.assert_called_once_with(token)
|
||||
self.handler._render_error.assert_not_called()
|
||||
self.render_error.assert_not_called()
|
||||
|
||||
# Handle userinfo fetching error
|
||||
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
|
||||
|
@ -485,7 +486,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
|
||||
def test_callback_session(self):
|
||||
"""The callback verifies the session presence and validity"""
|
||||
self.handler._render_error = Mock(return_value=None)
|
||||
request = Mock(spec=["args", "getCookie", "addCookie"])
|
||||
|
||||
# Missing cookie
|
||||
|
|
Loading…
Reference in New Issue