Extract OIDCProviderConfig object

Collect all the config options which related to an OIDC provider into a single
object.
pull/9105/head
Richard van der Hoff 2020-12-18 12:13:03 +00:00
parent 98a64b7f7f
commit 7cc9509eca
2 changed files with 140 additions and 62 deletions

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,7 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Type
import attr
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module
from ._base import Config, ConfigError
@ -25,65 +31,29 @@ class OIDCConfig(Config):
section = "oidc"
def read_config(self, config, **kwargs):
self.oidc_enabled = False
self.oidc_provider = None # type: Optional[OidcProviderConfig]
oidc_config = config.get("oidc_config")
if oidc_config and oidc_config.get("enabled", False):
self.oidc_provider = _parse_oidc_config_dict(oidc_config)
if not oidc_config or not oidc_config.get("enabled", False):
if not self.oidc_provider:
return
try:
check_requirements("oidc")
except DependencyException as e:
raise ConfigError(e.message)
raise ConfigError(e.message) from e
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("oidc_config requires a public_baseurl to be set")
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
self.oidc_enabled = True
self.oidc_discover = oidc_config.get("discover", True)
self.oidc_issuer = oidc_config["issuer"]
self.oidc_client_id = oidc_config["client_id"]
self.oidc_client_secret = oidc_config["client_secret"]
self.oidc_client_auth_method = oidc_config.get(
"client_auth_method", "client_secret_basic"
)
self.oidc_scopes = oidc_config.get("scopes", ["openid"])
self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
self.oidc_token_endpoint = oidc_config.get("token_endpoint")
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
ump_config = oidc_config.get("user_mapping_provider", {})
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
ump_config.setdefault("config", {})
(
self.oidc_user_mapping_provider_class,
self.oidc_user_mapping_provider_config,
) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [
"get_remote_user_id",
"map_user_attributes",
]
missing_methods = [
method
for method in required_methods
if not hasattr(self.oidc_user_mapping_provider_class, method)
]
if missing_methods:
raise ConfigError(
"Class specified by oidc_config."
"user_mapping_provider.module is missing required "
"methods: %s" % (", ".join(missing_methods),)
)
@property
def oidc_enabled(self) -> bool:
# OIDC is enabled if we have a provider
return bool(self.oidc_provider)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
@ -224,3 +194,108 @@ class OIDCConfig(Config):
""".format(
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
)
def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
"""Take the configuration dict and parse it into an OidcProviderConfig
Raises:
ConfigError if the configuration is malformed.
"""
ump_config = oidc_config.get("user_mapping_provider", {})
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
ump_config.setdefault("config", {})
(user_mapping_provider_class, user_mapping_provider_config,) = load_module(
ump_config, ("oidc_config", "user_mapping_provider")
)
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [
"get_remote_user_id",
"map_user_attributes",
]
missing_methods = [
method
for method in required_methods
if not hasattr(user_mapping_provider_class, method)
]
if missing_methods:
raise ConfigError(
"Class specified by oidc_config."
"user_mapping_provider.module is missing required "
"methods: %s" % (", ".join(missing_methods),)
)
return OidcProviderConfig(
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
client_secret=oidc_config["client_secret"],
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
scopes=oidc_config.get("scopes", ["openid"]),
authorization_endpoint=oidc_config.get("authorization_endpoint"),
token_endpoint=oidc_config.get("token_endpoint"),
userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
jwks_uri=oidc_config.get("jwks_uri"),
skip_verification=oidc_config.get("skip_verification", False),
user_profile_method=oidc_config.get("user_profile_method", "auto"),
allow_existing_users=oidc_config.get("allow_existing_users", False),
user_mapping_provider_class=user_mapping_provider_class,
user_mapping_provider_config=user_mapping_provider_config,
)
@attr.s
class OidcProviderConfig:
# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
# discover the provider's endpoints.
issuer = attr.ib(type=str)
# oauth2 client id to use
client_id = attr.ib(type=str)
# oauth2 client secret to use
client_secret = attr.ib(type=str)
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and
# 'none'.
client_auth_method = attr.ib(type=str)
# list of scopes to request
scopes = attr.ib(type=Collection[str])
# the oauth2 authorization endpoint. Required if discovery is disabled.
authorization_endpoint = attr.ib(type=Optional[str])
# the oauth2 token endpoint. Required if discovery is disabled.
token_endpoint = attr.ib(type=Optional[str])
# the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested.
userinfo_endpoint = attr.ib(type=Optional[str])
# URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used.
jwks_uri = attr.ib(type=Optional[str])
# Whether to skip metadata verification
skip_verification = attr.ib(type=bool)
# Whether to fetch the user profile from the userinfo endpoint. Valid
# values are: "auto" or "userinfo_endpoint".
user_profile_method = attr.ib(type=str)
# whether to allow a user logging in via OIDC to match a pre-existing account
# instead of failing
allow_existing_users = attr.ib(type=bool)
# the class of the user mapping provider
user_mapping_provider_class = attr.ib(type=Type)
# the config of the user mapping provider
user_mapping_provider_config = attr.ib()

View File

@ -94,27 +94,30 @@ class OidcHandler:
self._token_generator = OidcSessionTokenGenerator(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
provider = hs.config.oidc.oidc_provider
# we should not have been instantiated if there is no configured provider.
assert provider is not None
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
self._client_auth = ClientAuth(
hs.config.oidc_client_id,
hs.config.oidc_client_secret,
hs.config.oidc_client_auth_method,
provider.client_id, provider.client_secret, provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = hs.config.oidc_client_auth_method # type: str
self._client_auth_method = provider.client_auth_method
self._provider_metadata = OpenIDProviderMetadata(
issuer=hs.config.oidc_issuer,
authorization_endpoint=hs.config.oidc_authorization_endpoint,
token_endpoint=hs.config.oidc_token_endpoint,
userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
jwks_uri=hs.config.oidc_jwks_uri,
issuer=provider.issuer,
authorization_endpoint=provider.authorization_endpoint,
token_endpoint=provider.token_endpoint,
userinfo_endpoint=provider.userinfo_endpoint,
jwks_uri=provider.jwks_uri,
) # type: OpenIDProviderMetadata
self._provider_needs_discovery = hs.config.oidc_discover # type: bool
self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
hs.config.oidc_user_mapping_provider_config
) # type: OidcMappingProvider
self._skip_verification = hs.config.oidc_skip_verification # type: bool
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._provider_needs_discovery = provider.discover
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
)
self._skip_verification = provider.skip_verification
self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
self._server_name = hs.config.server_name # type: str