Inherit from BaseHandler.

pull/8765/head
Patrick Cloke 2020-10-21 16:00:12 -04:00
parent 3dc1871219
commit e40bcf8e77
2 changed files with 24 additions and 30 deletions

View File

@ -34,6 +34,7 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
@ -84,16 +85,15 @@ class OidcError(Exception):
class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object
"""
"""Used to catch errors when mapping the SAML2 response to a user."""
class OidcHandler:
class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
@ -120,9 +120,6 @@ class OidcHandler:
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._datastore = hs.get_datastore()
self._clock = hs.get_clock()
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = hs.config.sso_error_template
@ -770,7 +767,7 @@ class OidcHandler:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
now = self._clock.time_msec()
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@ -845,7 +842,7 @@ class OidcHandler:
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
now = self.clock.time_msec()
return now < expiry
async def _map_userinfo_to_user(
@ -891,7 +888,7 @@ class OidcHandler:
remote_user_id,
)
registered_user_id = await self._datastore.get_user_by_external_id(
registered_user_id = await self.store.get_user_by_external_id(
self._auth_provider_id, remote_user_id,
)
@ -917,8 +914,8 @@ class OidcHandler:
localpart = map_username_to_mxid_localpart(attributes["localpart"])
user_id = UserID(localpart, self._hostname).to_string()
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
user_id = UserID(localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
if self._allow_existing_users:
if len(users) == 1:
@ -942,7 +939,8 @@ class OidcHandler:
default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
return registered_user_id

View File

@ -24,6 +24,7 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
@ -57,17 +58,13 @@ class Saml2SessionData:
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
class SamlHandler:
class SamlHandler(BaseHandler):
def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._clock = hs.get_clock()
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
@ -88,7 +85,7 @@ class SamlHandler:
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
def _render_error(
self, request, error: str, error_description: Optional[str] = None
@ -130,7 +127,7 @@ class SamlHandler:
# Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,))
now = self._clock.time_msec()
now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id,
)
@ -279,7 +276,7 @@ class SamlHandler:
self._auth_provider_id,
remote_user_id,
)
registered_user_id = await self._datastore.get_user_by_external_id(
registered_user_id = await self.store.get_user_by_external_id(
self._auth_provider_id, remote_user_id
)
if registered_user_id is not None:
@ -294,7 +291,7 @@ class SamlHandler:
):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID(
map_username_to_mxid_localpart(attrval), self._hostname
map_username_to_mxid_localpart(attrval), self.server_name
).to_string()
logger.info(
"Looking for existing account based on mapped %s %s",
@ -302,11 +299,11 @@ class SamlHandler:
user_id,
)
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
await self._datastore.record_user_external_id(
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
@ -335,8 +332,8 @@ class SamlHandler:
emails = attribute_dict.get("emails", [])
# Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
if not await self.store.get_users_by_id_case_insensitive(
UserID(localpart, self.server_name).to_string()
):
# This mxid is free
break
@ -348,7 +345,6 @@ class SamlHandler:
)
logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=displayname,
@ -356,13 +352,13 @@ class SamlHandler:
user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
if data.creation_time < expire_before: