diff --git a/changelog.d/9036.feature b/changelog.d/9036.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9036.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/mypy.ini b/mypy.ini index 5d15b7bf1c..b996867121 100644 --- a/mypy.ini +++ b/mypy.ini @@ -103,6 +103,7 @@ files = tests/replication, tests/test_utils, tests/handlers/test_password_providers.py, + tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_stream_change_cache.py diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index ebc346105b..be938df962 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -319,9 +319,9 @@ class SsoRedirectServlet(RestServlet): # register themselves with the main SSOHandler. if hs.config.cas_enabled: hs.get_cas_handler() - elif hs.config.saml2_enabled: + if hs.config.saml2_enabled: hs.get_saml_handler() - elif hs.config.oidc_enabled: + if hs.config.oidc_enabled: hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 901c72d36a..1d1dc9f8a2 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1,22 +1,67 @@ -import json +# -*- coding: utf-8 -*- +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import time import urllib.parse +from html.parser import HTMLParser +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from mock import Mock -try: - import jwt -except ImportError: - jwt = None +import pymacaroons + +from twisted.web.resource import Resource import synapse.rest.admin from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet +from synapse.rest.synapse.client.pick_idp import PickIdpResource from tests import unittest -from tests.unittest import override_config +from tests.handlers.test_oidc import HAS_OIDC +from tests.handlers.test_saml import has_saml2 +from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.unittest import override_config, skip_unless + +try: + import jwt + + HAS_JWT = True +except ImportError: + HAS_JWT = False + + +# public_base_url used in some tests +BASE_URL = "https://synapse/" + +# CAS server used in some tests +CAS_SERVER = "https://fake.test" + +# just enough to tell pysaml2 where to redirect to +SAML_SERVER = "https://test.saml.server/idp/sso" +TEST_SAML_METADATA = """ + + + + + +""" % { + "SAML_SERVER": SAML_SERVER, +} LOGIN_URL = b"/_matrix/client/r0/login" TEST_URL = b"/_matrix/client/r0/account/whoami" @@ -314,6 +359,184 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) +@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") +class MultiSSOTestCase(unittest.HomeserverTestCase): + """Tests for homeservers with multiple SSO providers enabled""" + + servlets = [ + login.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + config["public_baseurl"] = BASE_URL + + config["cas_config"] = { + "enabled": True, + "server_url": CAS_SERVER, + "service_url": "https://matrix.goodserver.com:8448", + } + + config["saml2_config"] = { + "sp_config": { + "metadata": {"inline": [TEST_SAML_METADATA]}, + # use the XMLSecurity backend to avoid relying on xmlsec1 + "crypto_backend": "XMLSecurity", + }, + } + + config["oidc_config"] = TEST_OIDC_CONFIG + + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) + return d + + def test_multi_sso_redirect(self): + """/login/sso/redirect should redirect to an identity picker""" + client_redirect_url = "https://x?" + + # first hit the redirect url, which should redirect to our idp picker + channel = self.make_request( + "GET", + "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url, + ) + self.assertEqual(channel.code, 302, channel.result) + uri = channel.headers.getRawHeaders("Location")[0] + + # hitting that picker should give us some HTML + channel = self.make_request("GET", uri) + self.assertEqual(channel.code, 200, channel.result) + + # parse the form to check it has fields assumed elsewhere in this class + class FormPageParser(HTMLParser): + def __init__(self): + super().__init__() + + # the values of the hidden inputs: map from name to value + self.hiddens = {} # type: Dict[str, Optional[str]] + + # the values of the radio buttons + self.radios = [] # type: List[Optional[str]] + + def handle_starttag( + self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] + ) -> None: + attr_dict = dict(attrs) + if tag == "input": + if attr_dict["type"] == "radio" and attr_dict["name"] == "idp": + self.radios.append(attr_dict["value"]) + elif attr_dict["type"] == "hidden": + input_name = attr_dict["name"] + assert input_name + self.hiddens[input_name] = attr_dict["value"] + + def error(_, message): + self.fail(message) + + p = FormPageParser() + p.feed(channel.result["body"].decode("utf-8")) + p.close() + + self.assertCountEqual(p.radios, ["cas", "oidc", "saml"]) + + self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url) + + def test_multi_sso_redirect_to_cas(self): + """If CAS is chosen, should redirect to the CAS server""" + client_redirect_url = "https://x?" + + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas", + shorthand=False, + ) + self.assertEqual(channel.code, 302, channel.result) + cas_uri = channel.headers.getRawHeaders("Location")[0] + cas_uri_path, cas_uri_query = cas_uri.split("?", 1) + + # it should redirect us to the login page of the cas server + self.assertEqual(cas_uri_path, CAS_SERVER + "/login") + + # check that the redirectUrl is correctly encoded in the service param - ie, the + # place that CAS will redirect to + cas_uri_params = urllib.parse.parse_qs(cas_uri_query) + service_uri = cas_uri_params["service"][0] + _, service_uri_query = service_uri.split("?", 1) + service_uri_params = urllib.parse.parse_qs(service_uri_query) + self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url) + + def test_multi_sso_redirect_to_saml(self): + """If SAML is chosen, should redirect to the SAML server""" + client_redirect_url = "https://x?" + + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + client_redirect_url + + "&idp=saml", + ) + self.assertEqual(channel.code, 302, channel.result) + saml_uri = channel.headers.getRawHeaders("Location")[0] + saml_uri_path, saml_uri_query = saml_uri.split("?", 1) + + # it should redirect us to the login page of the SAML server + self.assertEqual(saml_uri_path, SAML_SERVER) + + # the RelayState is used to carry the client redirect url + saml_uri_params = urllib.parse.parse_qs(saml_uri_query) + relay_state_param = saml_uri_params["RelayState"][0] + self.assertEqual(relay_state_param, client_redirect_url) + + def test_multi_sso_redirect_to_oidc(self): + """If OIDC is chosen, should redirect to the OIDC auth endpoint""" + client_redirect_url = "https://x?" + + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + client_redirect_url + + "&idp=oidc", + ) + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + # ... and should have set a cookie including the redirect url + cookies = dict( + h.split(";")[0].split("=", maxsplit=1) + for h in channel.headers.getRawHeaders("Set-Cookie") + ) + + oidc_session_cookie = cookies["oidc_session"] + macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) + self.assertEqual( + self._get_value_from_macaroon(macaroon, "client_redirect_url"), + client_redirect_url, + ) + + def test_multi_sso_redirect_to_unknown(self): + """An unknown IdP should cause a 400""" + channel = self.make_request( + "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", + ) + self.assertEqual(channel.code, 400, channel.result) + + @staticmethod + def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: + prefix = key + " = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(prefix): + return caveat.caveat_id[len(prefix) :] + raise ValueError("No %s caveat in macaroon" % (key,)) + + class CASTestCase(unittest.HomeserverTestCase): servlets = [ @@ -327,7 +550,7 @@ class CASTestCase(unittest.HomeserverTestCase): config = self.default_config() config["cas_config"] = { "enabled": True, - "server_url": "https://fake.test", + "server_url": CAS_SERVER, "service_url": "https://matrix.goodserver.com:8448", } @@ -413,8 +636,7 @@ class CASTestCase(unittest.HomeserverTestCase): } ) def test_cas_redirect_whitelisted(self): - """Tests that the SSO login flow serves a redirect to a whitelisted url - """ + """Tests that the SSO login flow serves a redirect to a whitelisted url""" self._test_redirect("https://legit-site.com/") @override_config({"public_baseurl": "https://example.com"}) @@ -462,10 +684,8 @@ class CASTestCase(unittest.HomeserverTestCase): self.assertIn(b"SSO account deactivated", channel.result["body"]) +@skip_unless(HAS_JWT, "requires jwt") class JWTTestCase(unittest.HomeserverTestCase): - if not jwt: - skip = "requires jwt" - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -481,17 +701,17 @@ class JWTTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = self.jwt_algorithm return self.hs - def jwt_encode(self, token: str, secret: str = jwt_secret) -> str: + def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode(token, secret, self.jwt_algorithm) + result = jwt.encode( + payload, secret, self.jwt_algorithm + ) # type: Union[str, bytes] if isinstance(result, bytes): return result.decode("ascii") return result def jwt_login(self, *args): - params = json.dumps( - {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} - ) + params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel @@ -623,7 +843,7 @@ class JWTTestCase(unittest.HomeserverTestCase): ) def test_login_no_token(self): - params = json.dumps({"type": "org.matrix.login.jwt"}) + params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -633,10 +853,8 @@ class JWTTestCase(unittest.HomeserverTestCase): # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use # RSS256, with a public key configured in synapse as "jwt_secret", and tokens # signed by the private key. +@skip_unless(HAS_JWT, "requires jwt") class JWTPubKeyTestCase(unittest.HomeserverTestCase): - if not jwt: - skip = "requires jwt" - servlets = [ login.register_servlets, ] @@ -693,17 +911,15 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = "RS256" return self.hs - def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str: + def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode(token, secret, "RS256") + result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str] if isinstance(result, bytes): return result.decode("ascii") return result def jwt_login(self, *args): - params = json.dumps( - {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} - ) + params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel @@ -773,8 +989,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): return self.hs def test_login_appservice_user(self): - """Test that an appservice user can use /login - """ + """Test that an appservice user can use /login""" self.register_as_user(AS_USER) params = { @@ -788,8 +1003,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) def test_login_appservice_user_bot(self): - """Test that the appservice bot can use /login - """ + """Test that the appservice bot can use /login""" self.register_as_user(AS_USER) params = { @@ -803,8 +1017,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) def test_login_appservice_wrong_user(self): - """Test that non-as users cannot login with the as token - """ + """Test that non-as users cannot login with the as token""" self.register_as_user(AS_USER) params = { @@ -818,8 +1031,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) def test_login_appservice_wrong_as(self): - """Test that as users cannot login with wrong as token - """ + """Test that as users cannot login with wrong as token""" self.register_as_user(AS_USER) params = { @@ -834,7 +1046,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_no_token(self): """Test that users must provide a token when using the appservice - login method + login method """ self.register_as_user(AS_USER) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index dbc27893b5..81b7f84360 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -444,6 +444,7 @@ class RestHelper: # an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" TEST_OIDC_CONFIG = { "enabled": True, "discover": False, @@ -451,7 +452,7 @@ TEST_OIDC_CONFIG = { "client_id": "test-client-id", "client_secret": "test-client-secret", "scopes": ["profile"], - "authorization_endpoint": "https://z", + "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, "token_endpoint": "https://issuer.test/token", "userinfo_endpoint": "https://issuer.test/userinfo", "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},