| 
							
							
							
						 |  |  | @ -1,22 +1,67 @@ | 
		
	
		
			
				|  |  |  |  | import json | 
		
	
		
			
				|  |  |  |  | # -*- coding: utf-8 -*- | 
		
	
		
			
				|  |  |  |  | # Copyright 2019-2021 The Matrix.org Foundation C.I.C. | 
		
	
		
			
				|  |  |  |  | # | 
		
	
		
			
				|  |  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
		
	
		
			
				|  |  |  |  | # you may not use this file except in compliance with the License. | 
		
	
		
			
				|  |  |  |  | # You may obtain a copy of the License at | 
		
	
		
			
				|  |  |  |  | # | 
		
	
		
			
				|  |  |  |  | #     http://www.apache.org/licenses/LICENSE-2.0 | 
		
	
		
			
				|  |  |  |  | # | 
		
	
		
			
				|  |  |  |  | # Unless required by applicable law or agreed to in writing, software | 
		
	
		
			
				|  |  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
		
	
		
			
				|  |  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
		
	
		
			
				|  |  |  |  | # See the License for the specific language governing permissions and | 
		
	
		
			
				|  |  |  |  | # limitations under the License. | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | import time | 
		
	
		
			
				|  |  |  |  | import urllib.parse | 
		
	
		
			
				|  |  |  |  | from html.parser import HTMLParser | 
		
	
		
			
				|  |  |  |  | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | from mock import Mock | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | try: | 
		
	
		
			
				|  |  |  |  |     import jwt | 
		
	
		
			
				|  |  |  |  | except ImportError: | 
		
	
		
			
				|  |  |  |  |     jwt = None | 
		
	
		
			
				|  |  |  |  | import pymacaroons | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | from twisted.web.resource import Resource | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | import synapse.rest.admin | 
		
	
		
			
				|  |  |  |  | from synapse.appservice import ApplicationService | 
		
	
		
			
				|  |  |  |  | from synapse.rest.client.v1 import login, logout | 
		
	
		
			
				|  |  |  |  | from synapse.rest.client.v2_alpha import devices, register | 
		
	
		
			
				|  |  |  |  | from synapse.rest.client.v2_alpha.account import WhoamiRestServlet | 
		
	
		
			
				|  |  |  |  | from synapse.rest.synapse.client.pick_idp import PickIdpResource | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | from tests import unittest | 
		
	
		
			
				|  |  |  |  | from tests.unittest import override_config | 
		
	
		
			
				|  |  |  |  | from tests.handlers.test_oidc import HAS_OIDC | 
		
	
		
			
				|  |  |  |  | from tests.handlers.test_saml import has_saml2 | 
		
	
		
			
				|  |  |  |  | from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG | 
		
	
		
			
				|  |  |  |  | from tests.unittest import override_config, skip_unless | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | try: | 
		
	
		
			
				|  |  |  |  |     import jwt | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     HAS_JWT = True | 
		
	
		
			
				|  |  |  |  | except ImportError: | 
		
	
		
			
				|  |  |  |  |     HAS_JWT = False | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | # public_base_url used in some tests | 
		
	
		
			
				|  |  |  |  | BASE_URL = "https://synapse/" | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | # CAS server used in some tests | 
		
	
		
			
				|  |  |  |  | CAS_SERVER = "https://fake.test" | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | # just enough to tell pysaml2 where to redirect to | 
		
	
		
			
				|  |  |  |  | SAML_SERVER = "https://test.saml.server/idp/sso" | 
		
	
		
			
				|  |  |  |  | TEST_SAML_METADATA = """ | 
		
	
		
			
				|  |  |  |  | <md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata"> | 
		
	
		
			
				|  |  |  |  |   <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol"> | 
		
	
		
			
				|  |  |  |  |       <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/> | 
		
	
		
			
				|  |  |  |  |   </md:IDPSSODescriptor> | 
		
	
		
			
				|  |  |  |  | </md:EntityDescriptor> | 
		
	
		
			
				|  |  |  |  | """ % { | 
		
	
		
			
				|  |  |  |  |     "SAML_SERVER": SAML_SERVER, | 
		
	
		
			
				|  |  |  |  | } | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | LOGIN_URL = b"/_matrix/client/r0/login" | 
		
	
		
			
				|  |  |  |  | TEST_URL = b"/_matrix/client/r0/account/whoami" | 
		
	
	
		
			
				
					|  |  |  | @ -314,6 +359,184 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |         self.assertEquals(channel.result["code"], b"200", channel.result) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") | 
		
	
		
			
				|  |  |  |  | class MultiSSOTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |     """Tests for homeservers with multiple SSO providers enabled""" | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     servlets = [ | 
		
	
		
			
				|  |  |  |  |         login.register_servlets, | 
		
	
		
			
				|  |  |  |  |     ] | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def default_config(self) -> Dict[str, Any]: | 
		
	
		
			
				|  |  |  |  |         config = super().default_config() | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         config["public_baseurl"] = BASE_URL | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         config["cas_config"] = { | 
		
	
		
			
				|  |  |  |  |             "enabled": True, | 
		
	
		
			
				|  |  |  |  |             "server_url": CAS_SERVER, | 
		
	
		
			
				|  |  |  |  |             "service_url": "https://matrix.goodserver.com:8448", | 
		
	
		
			
				|  |  |  |  |         } | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         config["saml2_config"] = { | 
		
	
		
			
				|  |  |  |  |             "sp_config": { | 
		
	
		
			
				|  |  |  |  |                 "metadata": {"inline": [TEST_SAML_METADATA]}, | 
		
	
		
			
				|  |  |  |  |                 # use the XMLSecurity backend to avoid relying on xmlsec1 | 
		
	
		
			
				|  |  |  |  |                 "crypto_backend": "XMLSecurity", | 
		
	
		
			
				|  |  |  |  |             }, | 
		
	
		
			
				|  |  |  |  |         } | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         config["oidc_config"] = TEST_OIDC_CONFIG | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         return config | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def create_resource_dict(self) -> Dict[str, Resource]: | 
		
	
		
			
				|  |  |  |  |         d = super().create_resource_dict() | 
		
	
		
			
				|  |  |  |  |         d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) | 
		
	
		
			
				|  |  |  |  |         return d | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def test_multi_sso_redirect(self): | 
		
	
		
			
				|  |  |  |  |         """/login/sso/redirect should redirect to an identity picker""" | 
		
	
		
			
				|  |  |  |  |         client_redirect_url = "https://x?<abc>" | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # first hit the redirect url, which should redirect to our idp picker | 
		
	
		
			
				|  |  |  |  |         channel = self.make_request( | 
		
	
		
			
				|  |  |  |  |             "GET", | 
		
	
		
			
				|  |  |  |  |             "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url, | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(channel.code, 302, channel.result) | 
		
	
		
			
				|  |  |  |  |         uri = channel.headers.getRawHeaders("Location")[0] | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # hitting that picker should give us some HTML | 
		
	
		
			
				|  |  |  |  |         channel = self.make_request("GET", uri) | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(channel.code, 200, channel.result) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # parse the form to check it has fields assumed elsewhere in this class | 
		
	
		
			
				|  |  |  |  |         class FormPageParser(HTMLParser): | 
		
	
		
			
				|  |  |  |  |             def __init__(self): | 
		
	
		
			
				|  |  |  |  |                 super().__init__() | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |                 # the values of the hidden inputs: map from name to value | 
		
	
		
			
				|  |  |  |  |                 self.hiddens = {}  # type: Dict[str, Optional[str]] | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |                 # the values of the radio buttons | 
		
	
		
			
				|  |  |  |  |                 self.radios = []  # type: List[Optional[str]] | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |             def handle_starttag( | 
		
	
		
			
				|  |  |  |  |                 self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] | 
		
	
		
			
				|  |  |  |  |             ) -> None: | 
		
	
		
			
				|  |  |  |  |                 attr_dict = dict(attrs) | 
		
	
		
			
				|  |  |  |  |                 if tag == "input": | 
		
	
		
			
				|  |  |  |  |                     if attr_dict["type"] == "radio" and attr_dict["name"] == "idp": | 
		
	
		
			
				|  |  |  |  |                         self.radios.append(attr_dict["value"]) | 
		
	
		
			
				|  |  |  |  |                     elif attr_dict["type"] == "hidden": | 
		
	
		
			
				|  |  |  |  |                         input_name = attr_dict["name"] | 
		
	
		
			
				|  |  |  |  |                         assert input_name | 
		
	
		
			
				|  |  |  |  |                         self.hiddens[input_name] = attr_dict["value"] | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |             def error(_, message): | 
		
	
		
			
				|  |  |  |  |                 self.fail(message) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         p = FormPageParser() | 
		
	
		
			
				|  |  |  |  |         p.feed(channel.result["body"].decode("utf-8")) | 
		
	
		
			
				|  |  |  |  |         p.close() | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         self.assertCountEqual(p.radios, ["cas", "oidc", "saml"]) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def test_multi_sso_redirect_to_cas(self): | 
		
	
		
			
				|  |  |  |  |         """If CAS is chosen, should redirect to the CAS server""" | 
		
	
		
			
				|  |  |  |  |         client_redirect_url = "https://x?<abc>" | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         channel = self.make_request( | 
		
	
		
			
				|  |  |  |  |             "GET", | 
		
	
		
			
				|  |  |  |  |             "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas", | 
		
	
		
			
				|  |  |  |  |             shorthand=False, | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(channel.code, 302, channel.result) | 
		
	
		
			
				|  |  |  |  |         cas_uri = channel.headers.getRawHeaders("Location")[0] | 
		
	
		
			
				|  |  |  |  |         cas_uri_path, cas_uri_query = cas_uri.split("?", 1) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # it should redirect us to the login page of the cas server | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(cas_uri_path, CAS_SERVER + "/login") | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # check that the redirectUrl is correctly encoded in the service param - ie, the | 
		
	
		
			
				|  |  |  |  |         # place that CAS will redirect to | 
		
	
		
			
				|  |  |  |  |         cas_uri_params = urllib.parse.parse_qs(cas_uri_query) | 
		
	
		
			
				|  |  |  |  |         service_uri = cas_uri_params["service"][0] | 
		
	
		
			
				|  |  |  |  |         _, service_uri_query = service_uri.split("?", 1) | 
		
	
		
			
				|  |  |  |  |         service_uri_params = urllib.parse.parse_qs(service_uri_query) | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def test_multi_sso_redirect_to_saml(self): | 
		
	
		
			
				|  |  |  |  |         """If SAML is chosen, should redirect to the SAML server""" | 
		
	
		
			
				|  |  |  |  |         client_redirect_url = "https://x?<abc>" | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         channel = self.make_request( | 
		
	
		
			
				|  |  |  |  |             "GET", | 
		
	
		
			
				|  |  |  |  |             "/_synapse/client/pick_idp?redirectUrl=" | 
		
	
		
			
				|  |  |  |  |             + client_redirect_url | 
		
	
		
			
				|  |  |  |  |             + "&idp=saml", | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(channel.code, 302, channel.result) | 
		
	
		
			
				|  |  |  |  |         saml_uri = channel.headers.getRawHeaders("Location")[0] | 
		
	
		
			
				|  |  |  |  |         saml_uri_path, saml_uri_query = saml_uri.split("?", 1) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # it should redirect us to the login page of the SAML server | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(saml_uri_path, SAML_SERVER) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # the RelayState is used to carry the client redirect url | 
		
	
		
			
				|  |  |  |  |         saml_uri_params = urllib.parse.parse_qs(saml_uri_query) | 
		
	
		
			
				|  |  |  |  |         relay_state_param = saml_uri_params["RelayState"][0] | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(relay_state_param, client_redirect_url) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def test_multi_sso_redirect_to_oidc(self): | 
		
	
		
			
				|  |  |  |  |         """If OIDC is chosen, should redirect to the OIDC auth endpoint""" | 
		
	
		
			
				|  |  |  |  |         client_redirect_url = "https://x?<abc>" | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         channel = self.make_request( | 
		
	
		
			
				|  |  |  |  |             "GET", | 
		
	
		
			
				|  |  |  |  |             "/_synapse/client/pick_idp?redirectUrl=" | 
		
	
		
			
				|  |  |  |  |             + client_redirect_url | 
		
	
		
			
				|  |  |  |  |             + "&idp=oidc", | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(channel.code, 302, channel.result) | 
		
	
		
			
				|  |  |  |  |         oidc_uri = channel.headers.getRawHeaders("Location")[0] | 
		
	
		
			
				|  |  |  |  |         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # it should redirect us to the auth page of the OIDC server | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # ... and should have set a cookie including the redirect url | 
		
	
		
			
				|  |  |  |  |         cookies = dict( | 
		
	
		
			
				|  |  |  |  |             h.split(";")[0].split("=", maxsplit=1) | 
		
	
		
			
				|  |  |  |  |             for h in channel.headers.getRawHeaders("Set-Cookie") | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         oidc_session_cookie = cookies["oidc_session"] | 
		
	
		
			
				|  |  |  |  |         macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) | 
		
	
		
			
				|  |  |  |  |         self.assertEqual( | 
		
	
		
			
				|  |  |  |  |             self._get_value_from_macaroon(macaroon, "client_redirect_url"), | 
		
	
		
			
				|  |  |  |  |             client_redirect_url, | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def test_multi_sso_redirect_to_unknown(self): | 
		
	
		
			
				|  |  |  |  |         """An unknown IdP should cause a 400""" | 
		
	
		
			
				|  |  |  |  |         channel = self.make_request( | 
		
	
		
			
				|  |  |  |  |             "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(channel.code, 400, channel.result) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     @staticmethod | 
		
	
		
			
				|  |  |  |  |     def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: | 
		
	
		
			
				|  |  |  |  |         prefix = key + " = " | 
		
	
		
			
				|  |  |  |  |         for caveat in macaroon.caveats: | 
		
	
		
			
				|  |  |  |  |             if caveat.caveat_id.startswith(prefix): | 
		
	
		
			
				|  |  |  |  |                 return caveat.caveat_id[len(prefix) :] | 
		
	
		
			
				|  |  |  |  |         raise ValueError("No %s caveat in macaroon" % (key,)) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | class CASTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     servlets = [ | 
		
	
	
		
			
				
					|  |  |  | @ -327,7 +550,7 @@ class CASTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |         config = self.default_config() | 
		
	
		
			
				|  |  |  |  |         config["cas_config"] = { | 
		
	
		
			
				|  |  |  |  |             "enabled": True, | 
		
	
		
			
				|  |  |  |  |             "server_url": "https://fake.test", | 
		
	
		
			
				|  |  |  |  |             "server_url": CAS_SERVER, | 
		
	
		
			
				|  |  |  |  |             "service_url": "https://matrix.goodserver.com:8448", | 
		
	
		
			
				|  |  |  |  |         } | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
	
		
			
				
					|  |  |  | @ -413,8 +636,7 @@ class CASTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |         } | 
		
	
		
			
				|  |  |  |  |     ) | 
		
	
		
			
				|  |  |  |  |     def test_cas_redirect_whitelisted(self): | 
		
	
		
			
				|  |  |  |  |         """Tests that the SSO login flow serves a redirect to a whitelisted url | 
		
	
		
			
				|  |  |  |  |         """ | 
		
	
		
			
				|  |  |  |  |         """Tests that the SSO login flow serves a redirect to a whitelisted url""" | 
		
	
		
			
				|  |  |  |  |         self._test_redirect("https://legit-site.com/") | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     @override_config({"public_baseurl": "https://example.com"}) | 
		
	
	
		
			
				
					|  |  |  | @ -462,10 +684,8 @@ class CASTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |         self.assertIn(b"SSO account deactivated", channel.result["body"]) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | @skip_unless(HAS_JWT, "requires jwt") | 
		
	
		
			
				|  |  |  |  | class JWTTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |     if not jwt: | 
		
	
		
			
				|  |  |  |  |         skip = "requires jwt" | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     servlets = [ | 
		
	
		
			
				|  |  |  |  |         synapse.rest.admin.register_servlets_for_client_rest_resource, | 
		
	
		
			
				|  |  |  |  |         login.register_servlets, | 
		
	
	
		
			
				
					|  |  |  | @ -481,17 +701,17 @@ class JWTTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |         self.hs.config.jwt_algorithm = self.jwt_algorithm | 
		
	
		
			
				|  |  |  |  |         return self.hs | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def jwt_encode(self, token: str, secret: str = jwt_secret) -> str: | 
		
	
		
			
				|  |  |  |  |     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: | 
		
	
		
			
				|  |  |  |  |         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. | 
		
	
		
			
				|  |  |  |  |         result = jwt.encode(token, secret, self.jwt_algorithm) | 
		
	
		
			
				|  |  |  |  |         result = jwt.encode( | 
		
	
		
			
				|  |  |  |  |             payload, secret, self.jwt_algorithm | 
		
	
		
			
				|  |  |  |  |         )  # type: Union[str, bytes] | 
		
	
		
			
				|  |  |  |  |         if isinstance(result, bytes): | 
		
	
		
			
				|  |  |  |  |             return result.decode("ascii") | 
		
	
		
			
				|  |  |  |  |         return result | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def jwt_login(self, *args): | 
		
	
		
			
				|  |  |  |  |         params = json.dumps( | 
		
	
		
			
				|  |  |  |  |             {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  |         params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} | 
		
	
		
			
				|  |  |  |  |         channel = self.make_request(b"POST", LOGIN_URL, params) | 
		
	
		
			
				|  |  |  |  |         return channel | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
	
		
			
				
					|  |  |  | @ -623,7 +843,7 @@ class JWTTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def test_login_no_token(self): | 
		
	
		
			
				|  |  |  |  |         params = json.dumps({"type": "org.matrix.login.jwt"}) | 
		
	
		
			
				|  |  |  |  |         params = {"type": "org.matrix.login.jwt"} | 
		
	
		
			
				|  |  |  |  |         channel = self.make_request(b"POST", LOGIN_URL, params) | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(channel.result["code"], b"403", channel.result) | 
		
	
		
			
				|  |  |  |  |         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") | 
		
	
	
		
			
				
					|  |  |  | @ -633,10 +853,8 @@ class JWTTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  | # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use | 
		
	
		
			
				|  |  |  |  | # RSS256, with a public key configured in synapse as "jwt_secret", and tokens | 
		
	
		
			
				|  |  |  |  | # signed by the private key. | 
		
	
		
			
				|  |  |  |  | @skip_unless(HAS_JWT, "requires jwt") | 
		
	
		
			
				|  |  |  |  | class JWTPubKeyTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |     if not jwt: | 
		
	
		
			
				|  |  |  |  |         skip = "requires jwt" | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     servlets = [ | 
		
	
		
			
				|  |  |  |  |         login.register_servlets, | 
		
	
		
			
				|  |  |  |  |     ] | 
		
	
	
		
			
				
					|  |  |  | @ -693,17 +911,15 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |         self.hs.config.jwt_algorithm = "RS256" | 
		
	
		
			
				|  |  |  |  |         return self.hs | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str: | 
		
	
		
			
				|  |  |  |  |     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: | 
		
	
		
			
				|  |  |  |  |         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. | 
		
	
		
			
				|  |  |  |  |         result = jwt.encode(token, secret, "RS256") | 
		
	
		
			
				|  |  |  |  |         result = jwt.encode(payload, secret, "RS256")  # type: Union[bytes,str] | 
		
	
		
			
				|  |  |  |  |         if isinstance(result, bytes): | 
		
	
		
			
				|  |  |  |  |             return result.decode("ascii") | 
		
	
		
			
				|  |  |  |  |         return result | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def jwt_login(self, *args): | 
		
	
		
			
				|  |  |  |  |         params = json.dumps( | 
		
	
		
			
				|  |  |  |  |             {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  |         params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} | 
		
	
		
			
				|  |  |  |  |         channel = self.make_request(b"POST", LOGIN_URL, params) | 
		
	
		
			
				|  |  |  |  |         return channel | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
	
		
			
				
					|  |  |  | @ -773,8 +989,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |         return self.hs | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def test_login_appservice_user(self): | 
		
	
		
			
				|  |  |  |  |         """Test that an appservice user can use /login | 
		
	
		
			
				|  |  |  |  |         """ | 
		
	
		
			
				|  |  |  |  |         """Test that an appservice user can use /login""" | 
		
	
		
			
				|  |  |  |  |         self.register_as_user(AS_USER) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         params = { | 
		
	
	
		
			
				
					|  |  |  | @ -788,8 +1003,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |         self.assertEquals(channel.result["code"], b"200", channel.result) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def test_login_appservice_user_bot(self): | 
		
	
		
			
				|  |  |  |  |         """Test that the appservice bot can use /login | 
		
	
		
			
				|  |  |  |  |         """ | 
		
	
		
			
				|  |  |  |  |         """Test that the appservice bot can use /login""" | 
		
	
		
			
				|  |  |  |  |         self.register_as_user(AS_USER) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         params = { | 
		
	
	
		
			
				
					|  |  |  | @ -803,8 +1017,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |         self.assertEquals(channel.result["code"], b"200", channel.result) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def test_login_appservice_wrong_user(self): | 
		
	
		
			
				|  |  |  |  |         """Test that non-as users cannot login with the as token | 
		
	
		
			
				|  |  |  |  |         """ | 
		
	
		
			
				|  |  |  |  |         """Test that non-as users cannot login with the as token""" | 
		
	
		
			
				|  |  |  |  |         self.register_as_user(AS_USER) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         params = { | 
		
	
	
		
			
				
					|  |  |  | @ -818,8 +1031,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  |         self.assertEquals(channel.result["code"], b"403", channel.result) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def test_login_appservice_wrong_as(self): | 
		
	
		
			
				|  |  |  |  |         """Test that as users cannot login with wrong as token | 
		
	
		
			
				|  |  |  |  |         """ | 
		
	
		
			
				|  |  |  |  |         """Test that as users cannot login with wrong as token""" | 
		
	
		
			
				|  |  |  |  |         self.register_as_user(AS_USER) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         params = { | 
		
	
	
		
			
				
					|  |  |  | @ -834,7 +1046,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def test_login_appservice_no_token(self): | 
		
	
		
			
				|  |  |  |  |         """Test that users must provide a token when using the appservice | 
		
	
		
			
				|  |  |  |  |            login method | 
		
	
		
			
				|  |  |  |  |         login method | 
		
	
		
			
				|  |  |  |  |         """ | 
		
	
		
			
				|  |  |  |  |         self.register_as_user(AS_USER) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
	
		
			
				
					|  |  |  | 
 |