Inherit from BaseHandler.

pull/8765/head
Patrick Cloke 2020-10-21 16:00:12 -04:00
parent 3dc1871219
commit e40bcf8e77
2 changed files with 24 additions and 30 deletions

View File

@ -34,6 +34,7 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody from twisted.web.client import readBody
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@ -84,16 +85,15 @@ class OidcError(Exception):
class MappingException(Exception): class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object """Used to catch errors when mapping the SAML2 response to a user."""
"""
class OidcHandler: class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow. """Handles requests related to the OpenID Connect login flow.
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str] self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str self._user_profile_method = hs.config.oidc_user_profile_method # type: str
@ -120,9 +120,6 @@ class OidcHandler:
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._datastore = hs.get_datastore()
self._clock = hs.get_clock()
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = hs.config.sso_error_template self._error_template = hs.config.sso_error_template
@ -770,7 +767,7 @@ class OidcHandler:
macaroon.add_first_party_caveat( macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,) "ui_auth_session_id = %s" % (ui_auth_session_id,)
) )
now = self._clock.time_msec() now = self.clock.time_msec()
expiry = now + duration_in_ms expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
@ -845,7 +842,7 @@ class OidcHandler:
if not caveat.startswith(prefix): if not caveat.startswith(prefix):
return False return False
expiry = int(caveat[len(prefix) :]) expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec() now = self.clock.time_msec()
return now < expiry return now < expiry
async def _map_userinfo_to_user( async def _map_userinfo_to_user(
@ -891,7 +888,7 @@ class OidcHandler:
remote_user_id, remote_user_id,
) )
registered_user_id = await self._datastore.get_user_by_external_id( registered_user_id = await self.store.get_user_by_external_id(
self._auth_provider_id, remote_user_id, self._auth_provider_id, remote_user_id,
) )
@ -917,8 +914,8 @@ class OidcHandler:
localpart = map_username_to_mxid_localpart(attributes["localpart"]) localpart = map_username_to_mxid_localpart(attributes["localpart"])
user_id = UserID(localpart, self._hostname).to_string() user_id = UserID(localpart, self.server_name).to_string()
users = await self._datastore.get_users_by_id_case_insensitive(user_id) users = await self.store.get_users_by_id_case_insensitive(user_id)
if users: if users:
if self._allow_existing_users: if self._allow_existing_users:
if len(users) == 1: if len(users) == 1:
@ -942,7 +939,8 @@ class OidcHandler:
default_display_name=attributes["display_name"], default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address), user_agent_ips=(user_agent, ip_address),
) )
await self._datastore.record_user_external_id(
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id, self._auth_provider_id, remote_user_id, registered_user_id,
) )
return registered_user_id return registered_user_id

View File

@ -24,6 +24,7 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html from synapse.http.server import respond_with_html
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -57,17 +58,13 @@ class Saml2SessionData:
ui_auth_session_id = attr.ib(type=Optional[str], default=None) ui_auth_session_id = attr.ib(type=Optional[str], default=None)
class SamlHandler: class SamlHandler(BaseHandler):
def __init__(self, hs: "synapse.server.HomeServer"): def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._clock = hs.get_clock()
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = ( self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute hs.config.saml2_grandfathered_mxid_source_attribute
@ -88,7 +85,7 @@ class SamlHandler:
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings # a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
def _render_error( def _render_error(
self, request, error: str, error_description: Optional[str] = None self, request, error: str, error_description: Optional[str] = None
@ -130,7 +127,7 @@ class SamlHandler:
# Since SAML sessions timeout it is useful to log when they were created. # Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,)) logger.info("Initiating a new SAML session: %s" % (reqid,))
now = self._clock.time_msec() now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData( self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id, creation_time=now, ui_auth_session_id=ui_auth_session_id,
) )
@ -279,7 +276,7 @@ class SamlHandler:
self._auth_provider_id, self._auth_provider_id,
remote_user_id, remote_user_id,
) )
registered_user_id = await self._datastore.get_user_by_external_id( registered_user_id = await self.store.get_user_by_external_id(
self._auth_provider_id, remote_user_id self._auth_provider_id, remote_user_id
) )
if registered_user_id is not None: if registered_user_id is not None:
@ -294,7 +291,7 @@ class SamlHandler:
): ):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0] attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID( user_id = UserID(
map_username_to_mxid_localpart(attrval), self._hostname map_username_to_mxid_localpart(attrval), self.server_name
).to_string() ).to_string()
logger.info( logger.info(
"Looking for existing account based on mapped %s %s", "Looking for existing account based on mapped %s %s",
@ -302,11 +299,11 @@ class SamlHandler:
user_id, user_id,
) )
users = await self._datastore.get_users_by_id_case_insensitive(user_id) users = await self.store.get_users_by_id_case_insensitive(user_id)
if users: if users:
registered_user_id = list(users.keys())[0] registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id) logger.info("Grandfathering mapping to %s", registered_user_id)
await self._datastore.record_user_external_id( await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id self._auth_provider_id, remote_user_id, registered_user_id
) )
return registered_user_id return registered_user_id
@ -335,8 +332,8 @@ class SamlHandler:
emails = attribute_dict.get("emails", []) emails = attribute_dict.get("emails", [])
# Check if this mxid already exists # Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive( if not await self.store.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string() UserID(localpart, self.server_name).to_string()
): ):
# This mxid is free # This mxid is free
break break
@ -348,7 +345,6 @@ class SamlHandler:
) )
logger.info("Mapped SAML user to local part %s", localpart) logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, localpart=localpart,
default_display_name=displayname, default_display_name=displayname,
@ -356,13 +352,13 @@ class SamlHandler:
user_agent_ips=(user_agent, ip_address), user_agent_ips=(user_agent, ip_address),
) )
await self._datastore.record_user_external_id( await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id self._auth_provider_id, remote_user_id, registered_user_id
) )
return registered_user_id return registered_user_id
def expire_sessions(self): def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set() to_expire = set()
for reqid, data in self._outstanding_requests_dict.items(): for reqid, data in self._outstanding_requests_dict.items():
if data.creation_time < expire_before: if data.creation_time < expire_before: