Inherit from BaseHandler.
parent
3dc1871219
commit
e40bcf8e77
|
@ -34,6 +34,7 @@ from typing_extensions import TypedDict
|
||||||
from twisted.web.client import readBody
|
from twisted.web.client import readBody
|
||||||
|
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
|
from synapse.handlers._base import BaseHandler
|
||||||
from synapse.http.server import respond_with_html
|
from synapse.http.server import respond_with_html
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
|
@ -84,16 +85,15 @@ class OidcError(Exception):
|
||||||
|
|
||||||
|
|
||||||
class MappingException(Exception):
|
class MappingException(Exception):
|
||||||
"""Used to catch errors when mapping the UserInfo object
|
"""Used to catch errors when mapping the SAML2 response to a user."""
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class OidcHandler:
|
class OidcHandler(BaseHandler):
|
||||||
"""Handles requests related to the OpenID Connect login flow.
|
"""Handles requests related to the OpenID Connect login flow.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
super().__init__(hs)
|
||||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||||
self._scopes = hs.config.oidc_scopes # type: List[str]
|
self._scopes = hs.config.oidc_scopes # type: List[str]
|
||||||
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
|
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
|
||||||
|
@ -120,9 +120,6 @@ class OidcHandler:
|
||||||
self._http_client = hs.get_proxied_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self._registration_handler = hs.get_registration_handler()
|
self._registration_handler = hs.get_registration_handler()
|
||||||
self._datastore = hs.get_datastore()
|
|
||||||
self._clock = hs.get_clock()
|
|
||||||
self._hostname = hs.hostname # type: str
|
|
||||||
self._server_name = hs.config.server_name # type: str
|
self._server_name = hs.config.server_name # type: str
|
||||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||||
self._error_template = hs.config.sso_error_template
|
self._error_template = hs.config.sso_error_template
|
||||||
|
@ -770,7 +767,7 @@ class OidcHandler:
|
||||||
macaroon.add_first_party_caveat(
|
macaroon.add_first_party_caveat(
|
||||||
"ui_auth_session_id = %s" % (ui_auth_session_id,)
|
"ui_auth_session_id = %s" % (ui_auth_session_id,)
|
||||||
)
|
)
|
||||||
now = self._clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
expiry = now + duration_in_ms
|
expiry = now + duration_in_ms
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
|
|
||||||
|
@ -845,7 +842,7 @@ class OidcHandler:
|
||||||
if not caveat.startswith(prefix):
|
if not caveat.startswith(prefix):
|
||||||
return False
|
return False
|
||||||
expiry = int(caveat[len(prefix) :])
|
expiry = int(caveat[len(prefix) :])
|
||||||
now = self._clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
return now < expiry
|
return now < expiry
|
||||||
|
|
||||||
async def _map_userinfo_to_user(
|
async def _map_userinfo_to_user(
|
||||||
|
@ -891,7 +888,7 @@ class OidcHandler:
|
||||||
remote_user_id,
|
remote_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
registered_user_id = await self._datastore.get_user_by_external_id(
|
registered_user_id = await self.store.get_user_by_external_id(
|
||||||
self._auth_provider_id, remote_user_id,
|
self._auth_provider_id, remote_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -917,8 +914,8 @@ class OidcHandler:
|
||||||
|
|
||||||
localpart = map_username_to_mxid_localpart(attributes["localpart"])
|
localpart = map_username_to_mxid_localpart(attributes["localpart"])
|
||||||
|
|
||||||
user_id = UserID(localpart, self._hostname).to_string()
|
user_id = UserID(localpart, self.server_name).to_string()
|
||||||
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
|
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
if users:
|
if users:
|
||||||
if self._allow_existing_users:
|
if self._allow_existing_users:
|
||||||
if len(users) == 1:
|
if len(users) == 1:
|
||||||
|
@ -942,7 +939,8 @@ class OidcHandler:
|
||||||
default_display_name=attributes["display_name"],
|
default_display_name=attributes["display_name"],
|
||||||
user_agent_ips=(user_agent, ip_address),
|
user_agent_ips=(user_agent, ip_address),
|
||||||
)
|
)
|
||||||
await self._datastore.record_user_external_id(
|
|
||||||
|
await self.store.record_user_external_id(
|
||||||
self._auth_provider_id, remote_user_id, registered_user_id,
|
self._auth_provider_id, remote_user_id, registered_user_id,
|
||||||
)
|
)
|
||||||
return registered_user_id
|
return registered_user_id
|
||||||
|
|
|
@ -24,6 +24,7 @@ from saml2.client import Saml2Client
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.config.saml2_config import SamlAttributeRequirement
|
from synapse.config.saml2_config import SamlAttributeRequirement
|
||||||
|
from synapse.handlers._base import BaseHandler
|
||||||
from synapse.http.server import respond_with_html
|
from synapse.http.server import respond_with_html
|
||||||
from synapse.http.servlet import parse_string
|
from synapse.http.servlet import parse_string
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
|
@ -57,17 +58,13 @@ class Saml2SessionData:
|
||||||
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
|
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
|
||||||
|
|
||||||
|
|
||||||
class SamlHandler:
|
class SamlHandler(BaseHandler):
|
||||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||||
self.hs = hs
|
super().__init__(hs)
|
||||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||||
self._auth = hs.get_auth()
|
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self._registration_handler = hs.get_registration_handler()
|
self._registration_handler = hs.get_registration_handler()
|
||||||
|
|
||||||
self._clock = hs.get_clock()
|
|
||||||
self._datastore = hs.get_datastore()
|
|
||||||
self._hostname = hs.hostname
|
|
||||||
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
|
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
|
||||||
self._grandfathered_mxid_source_attribute = (
|
self._grandfathered_mxid_source_attribute = (
|
||||||
hs.config.saml2_grandfathered_mxid_source_attribute
|
hs.config.saml2_grandfathered_mxid_source_attribute
|
||||||
|
@ -88,7 +85,7 @@ class SamlHandler:
|
||||||
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
||||||
|
|
||||||
# a lock on the mappings
|
# a lock on the mappings
|
||||||
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
|
self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
|
||||||
|
|
||||||
def _render_error(
|
def _render_error(
|
||||||
self, request, error: str, error_description: Optional[str] = None
|
self, request, error: str, error_description: Optional[str] = None
|
||||||
|
@ -130,7 +127,7 @@ class SamlHandler:
|
||||||
# Since SAML sessions timeout it is useful to log when they were created.
|
# Since SAML sessions timeout it is useful to log when they were created.
|
||||||
logger.info("Initiating a new SAML session: %s" % (reqid,))
|
logger.info("Initiating a new SAML session: %s" % (reqid,))
|
||||||
|
|
||||||
now = self._clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
self._outstanding_requests_dict[reqid] = Saml2SessionData(
|
self._outstanding_requests_dict[reqid] = Saml2SessionData(
|
||||||
creation_time=now, ui_auth_session_id=ui_auth_session_id,
|
creation_time=now, ui_auth_session_id=ui_auth_session_id,
|
||||||
)
|
)
|
||||||
|
@ -279,7 +276,7 @@ class SamlHandler:
|
||||||
self._auth_provider_id,
|
self._auth_provider_id,
|
||||||
remote_user_id,
|
remote_user_id,
|
||||||
)
|
)
|
||||||
registered_user_id = await self._datastore.get_user_by_external_id(
|
registered_user_id = await self.store.get_user_by_external_id(
|
||||||
self._auth_provider_id, remote_user_id
|
self._auth_provider_id, remote_user_id
|
||||||
)
|
)
|
||||||
if registered_user_id is not None:
|
if registered_user_id is not None:
|
||||||
|
@ -294,7 +291,7 @@ class SamlHandler:
|
||||||
):
|
):
|
||||||
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
|
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
|
||||||
user_id = UserID(
|
user_id = UserID(
|
||||||
map_username_to_mxid_localpart(attrval), self._hostname
|
map_username_to_mxid_localpart(attrval), self.server_name
|
||||||
).to_string()
|
).to_string()
|
||||||
logger.info(
|
logger.info(
|
||||||
"Looking for existing account based on mapped %s %s",
|
"Looking for existing account based on mapped %s %s",
|
||||||
|
@ -302,11 +299,11 @@ class SamlHandler:
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
|
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
if users:
|
if users:
|
||||||
registered_user_id = list(users.keys())[0]
|
registered_user_id = list(users.keys())[0]
|
||||||
logger.info("Grandfathering mapping to %s", registered_user_id)
|
logger.info("Grandfathering mapping to %s", registered_user_id)
|
||||||
await self._datastore.record_user_external_id(
|
await self.store.record_user_external_id(
|
||||||
self._auth_provider_id, remote_user_id, registered_user_id
|
self._auth_provider_id, remote_user_id, registered_user_id
|
||||||
)
|
)
|
||||||
return registered_user_id
|
return registered_user_id
|
||||||
|
@ -335,8 +332,8 @@ class SamlHandler:
|
||||||
emails = attribute_dict.get("emails", [])
|
emails = attribute_dict.get("emails", [])
|
||||||
|
|
||||||
# Check if this mxid already exists
|
# Check if this mxid already exists
|
||||||
if not await self._datastore.get_users_by_id_case_insensitive(
|
if not await self.store.get_users_by_id_case_insensitive(
|
||||||
UserID(localpart, self._hostname).to_string()
|
UserID(localpart, self.server_name).to_string()
|
||||||
):
|
):
|
||||||
# This mxid is free
|
# This mxid is free
|
||||||
break
|
break
|
||||||
|
@ -348,7 +345,6 @@ class SamlHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Mapped SAML user to local part %s", localpart)
|
logger.info("Mapped SAML user to local part %s", localpart)
|
||||||
|
|
||||||
registered_user_id = await self._registration_handler.register_user(
|
registered_user_id = await self._registration_handler.register_user(
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
default_display_name=displayname,
|
default_display_name=displayname,
|
||||||
|
@ -356,13 +352,13 @@ class SamlHandler:
|
||||||
user_agent_ips=(user_agent, ip_address),
|
user_agent_ips=(user_agent, ip_address),
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._datastore.record_user_external_id(
|
await self.store.record_user_external_id(
|
||||||
self._auth_provider_id, remote_user_id, registered_user_id
|
self._auth_provider_id, remote_user_id, registered_user_id
|
||||||
)
|
)
|
||||||
return registered_user_id
|
return registered_user_id
|
||||||
|
|
||||||
def expire_sessions(self):
|
def expire_sessions(self):
|
||||||
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
|
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
||||||
to_expire = set()
|
to_expire = set()
|
||||||
for reqid, data in self._outstanding_requests_dict.items():
|
for reqid, data in self._outstanding_requests_dict.items():
|
||||||
if data.creation_time < expire_before:
|
if data.creation_time < expire_before:
|
||||||
|
|
Loading…
Reference in New Issue