Add support for stable MSC2858 API (#9617)

The stable format uses different brand identifiers, so we need to support two
identifiers for each IdP.
release-v1.30.0
Richard van der Hoff 2021-03-16 11:21:26 +00:00 committed by GitHub
parent 5b5bc188cf
commit dd69110d95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 88 additions and 28 deletions

1
changelog.d/9617.feature Normal file
View File

@ -0,0 +1 @@
Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)).

View File

@ -226,7 +226,7 @@ Synapse config:
oidc_providers: oidc_providers:
- idp_id: github - idp_id: github
idp_name: Github idp_name: Github
idp_brand: "org.matrix.github" # optional: styling hint for clients idp_brand: "github" # optional: styling hint for clients
discover: false discover: false
issuer: "https://github.com/" issuer: "https://github.com/"
client_id: "your-client-id" # TO BE FILLED client_id: "your-client-id" # TO BE FILLED
@ -252,7 +252,7 @@ oidc_providers:
oidc_providers: oidc_providers:
- idp_id: google - idp_id: google
idp_name: Google idp_name: Google
idp_brand: "org.matrix.google" # optional: styling hint for clients idp_brand: "google" # optional: styling hint for clients
issuer: "https://accounts.google.com/" issuer: "https://accounts.google.com/"
client_id: "your-client-id" # TO BE FILLED client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED
@ -299,7 +299,7 @@ Synapse config:
oidc_providers: oidc_providers:
- idp_id: gitlab - idp_id: gitlab
idp_name: Gitlab idp_name: Gitlab
idp_brand: "org.matrix.gitlab" # optional: styling hint for clients idp_brand: "gitlab" # optional: styling hint for clients
issuer: "https://gitlab.com/" issuer: "https://gitlab.com/"
client_id: "your-client-id" # TO BE FILLED client_id: "your-client-id" # TO BE FILLED
client_secret: "your-client-secret" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED
@ -334,7 +334,7 @@ Synapse config:
```yaml ```yaml
- idp_id: facebook - idp_id: facebook
idp_name: Facebook idp_name: Facebook
idp_brand: "org.matrix.facebook" # optional: styling hint for clients idp_brand: "facebook" # optional: styling hint for clients
discover: false discover: false
issuer: "https://facebook.com" issuer: "https://facebook.com"
client_id: "your-client-id" # TO BE FILLED client_id: "your-client-id" # TO BE FILLED

View File

@ -1919,7 +1919,7 @@ oidc_providers:
# #
#- idp_id: github #- idp_id: github
# idp_name: Github # idp_name: Github
# idp_brand: org.matrix.github # idp_brand: github
# discover: false # discover: false
# issuer: "https://github.com/" # issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED # client_id: "your-client-id" # TO BE FILLED

View File

@ -237,7 +237,7 @@ class OIDCConfig(Config):
# #
#- idp_id: github #- idp_id: github
# idp_name: Github # idp_name: Github
# idp_brand: org.matrix.github # idp_brand: github
# discover: false # discover: false
# issuer: "https://github.com/" # issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED # client_id: "your-client-id" # TO BE FILLED
@ -272,7 +272,12 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"idp_icon": {"type": "string"}, "idp_icon": {"type": "string"},
"idp_brand": { "idp_brand": {
"type": "string", "type": "string",
# MSC2758-style namespaced identifier "minLength": 1,
"maxLength": 255,
"pattern": "^[a-z][a-z0-9_.-]*$",
},
"idp_unstable_brand": {
"type": "string",
"minLength": 1, "minLength": 1,
"maxLength": 255, "maxLength": 255,
"pattern": "^[a-z][a-z0-9_.-]*$", "pattern": "^[a-z][a-z0-9_.-]*$",
@ -466,6 +471,7 @@ def _parse_oidc_config_dict(
idp_name=oidc_config.get("idp_name", "OIDC"), idp_name=oidc_config.get("idp_name", "OIDC"),
idp_icon=idp_icon, idp_icon=idp_icon,
idp_brand=oidc_config.get("idp_brand"), idp_brand=oidc_config.get("idp_brand"),
unstable_idp_brand=oidc_config.get("unstable_idp_brand"),
discover=oidc_config.get("discover", True), discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"], issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"], client_id=oidc_config["client_id"],
@ -512,6 +518,9 @@ class OidcProviderConfig:
# Optional brand identifier for this IdP. # Optional brand identifier for this IdP.
idp_brand = attr.ib(type=Optional[str]) idp_brand = attr.ib(type=Optional[str])
# Optional brand identifier for the unstable API (see MSC2858).
unstable_idp_brand = attr.ib(type=Optional[str])
# whether the OIDC discovery mechanism is used to discover endpoints # whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool) discover = attr.ib(type=bool)

View File

@ -83,6 +83,7 @@ class CasHandler:
# the SsoIdentityProvider protocol type. # the SsoIdentityProvider protocol type.
self.idp_icon = None self.idp_icon = None
self.idp_brand = None self.idp_brand = None
self.unstable_idp_brand = None
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()

View File

@ -330,6 +330,9 @@ class OidcProvider:
# optional brand identifier for this auth provider # optional brand identifier for this auth provider
self.idp_brand = provider.idp_brand self.idp_brand = provider.idp_brand
# Optional brand identifier for the unstable API (see MSC2858).
self.unstable_idp_brand = provider.unstable_idp_brand
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self) self._sso_handler.register_identity_provider(self)

View File

@ -81,6 +81,7 @@ class SamlHandler(BaseHandler):
# the SsoIdentityProvider protocol type. # the SsoIdentityProvider protocol type.
self.idp_icon = None self.idp_icon = None
self.idp_brand = None self.idp_brand = None
self.unstable_idp_brand = None
# a map from saml session id to Saml2SessionData object # a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]

View File

@ -98,6 +98,11 @@ class SsoIdentityProvider(Protocol):
"""Optional branding identifier""" """Optional branding identifier"""
return None return None
@property
def unstable_idp_brand(self) -> Optional[str]:
"""Optional brand identifier for the unstable API (see MSC2858)."""
return None
@abc.abstractmethod @abc.abstractmethod
async def handle_redirect_request( async def handle_redirect_request(
self, self,

View File

@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.handlers.sso import SsoIdentityProvider from synapse.handlers.sso import SsoIdentityProvider
from synapse.http import get_request_uri from synapse.http import get_request_uri
@ -94,11 +96,21 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE}) flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict sso_flow = {
"type": LoginRestServlet.SSO_TYPE,
"identity_providers": [
_get_auth_flow_dict_for_idp(
idp,
)
for idp in self._sso_handler.get_identity_providers().values()
],
} # type: JsonDict
if self._msc2858_enabled: if self._msc2858_enabled:
# backwards-compatibility support for clients which don't
# support the stable API yet
sso_flow["org.matrix.msc2858.identity_providers"] = [ sso_flow["org.matrix.msc2858.identity_providers"] = [
_get_auth_flow_dict_for_idp(idp) _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
for idp in self._sso_handler.get_identity_providers().values() for idp in self._sso_handler.get_identity_providers().values()
] ]
@ -331,22 +343,38 @@ class LoginRestServlet(RestServlet):
return result return result
def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: def _get_auth_flow_dict_for_idp(
idp: SsoIdentityProvider, use_unstable_brands: bool = False
) -> JsonDict:
"""Return an entry for the login flow dict """Return an entry for the login flow dict
Returns an entry suitable for inclusion in "identity_providers" in the Returns an entry suitable for inclusion in "identity_providers" in the
response to GET /_matrix/client/r0/login response to GET /_matrix/client/r0/login
Args:
idp: the identity provider to describe
use_unstable_brands: whether we should use brand identifiers suitable
for the unstable API
""" """
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
if idp.idp_icon: if idp.idp_icon:
e["icon"] = idp.idp_icon e["icon"] = idp.idp_icon
if idp.idp_brand: if idp.idp_brand:
e["brand"] = idp.idp_brand e["brand"] = idp.idp_brand
# use the stable brand identifier if the unstable identifier isn't defined.
if use_unstable_brands and idp.unstable_idp_brand:
e["brand"] = idp.unstable_idp_brand
return e return e
class SsoRedirectServlet(RestServlet): class SsoRedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
re.compile(
"^"
+ CLIENT_API_PREFIX
+ "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
)
]
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they # make sure that the relevant handlers are instantiated, so that they
@ -364,7 +392,8 @@ class SsoRedirectServlet(RestServlet):
def register(self, http_server: HttpServer) -> None: def register(self, http_server: HttpServer) -> None:
super().register(http_server) super().register(http_server)
if self._msc2858_enabled: if self._msc2858_enabled:
# expose additional endpoint for MSC2858 support # expose additional endpoint for MSC2858 support: backwards-compat support
# for clients which don't yet support the stable endpoints.
http_server.register_paths( http_server.register_paths(
"GET", "GET",
client_patterns( client_patterns(

View File

@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", "/_matrix/client/r0/login") channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
expected_flows = [ expected_flow_types = [
{"type": "m.login.cas"}, "m.login.cas",
{"type": "m.login.sso"}, "m.login.sso",
{"type": "m.login.token"}, "m.login.token",
{"type": "m.login.password"}, "m.login.password",
] + ADDITIONAL_LOGIN_FLOWS ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]
self.assertCountEqual(channel.json_body["flows"], expected_flows) self.assertCountEqual(
[f["type"] for f in channel.json_body["flows"]], expected_flow_types
)
@override_config({"experimental_features": {"msc2858_enabled": True}}) @override_config({"experimental_features": {"msc2858_enabled": True}})
def test_get_msc2858_login_flows(self): def test_get_msc2858_login_flows(self):
@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
def test_client_idp_redirect_msc2858_disabled(self):
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_unknown(self): def test_client_idp_redirect_to_unknown(self):
"""If the client tries to pick an unknown IdP, return a 404""" """If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request(True, "xxx") channel = self._make_sso_redirect_request(False, "xxx")
self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_oidc(self): def test_client_idp_redirect_to_oidc(self):
"""If the client pick a known IdP, redirect to it""" """If the client pick a known IdP, redirect to it"""
channel = self._make_sso_redirect_request(False, "oidc")
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_msc2858_redirect_to_oidc(self):
"""Test the unstable API"""
channel = self._make_sso_redirect_request(True, "oidc") channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0] oidc_uri = channel.headers.getRawHeaders("Location")[0]
@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# it should redirect us to the auth page of the OIDC server # it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
def test_client_idp_redirect_msc2858_disabled(self):
"""If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
channel = self._make_sso_redirect_request(True, "oidc")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
def _make_sso_redirect_request( def _make_sso_redirect_request(
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
): ):