Add type hints to `tests/rest/client` (#12066)
							parent
							
								
									5b2b36809f
								
							
						
					
					
						commit
						64c73c6ac8
					
				|  | @ -0,0 +1 @@ | |||
| Add type hints to `tests/rest/client`. | ||||
|  | @ -13,17 +13,21 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from http import HTTPStatus | ||||
| from typing import Optional, Tuple, Union | ||||
| from typing import Any, Dict, List, Optional, Tuple, Union | ||||
| 
 | ||||
| from twisted.internet.defer import succeed | ||||
| from twisted.test.proto_helpers import MemoryReactor | ||||
| from twisted.web.resource import Resource | ||||
| 
 | ||||
| import synapse.rest.admin | ||||
| from synapse.api.constants import LoginType | ||||
| from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker | ||||
| from synapse.rest.client import account, auth, devices, login, logout, register | ||||
| from synapse.rest.synapse.client import build_synapse_client_resource_tree | ||||
| from synapse.server import HomeServer | ||||
| from synapse.storage.database import LoggingTransaction | ||||
| from synapse.types import JsonDict, UserID | ||||
| from synapse.util import Clock | ||||
| 
 | ||||
| from tests import unittest | ||||
| from tests.handlers.test_oidc import HAS_OIDC | ||||
|  | @ -33,11 +37,11 @@ from tests.unittest import override_config, skip_unless | |||
| 
 | ||||
| 
 | ||||
| class DummyRecaptchaChecker(UserInteractiveAuthChecker): | ||||
|     def __init__(self, hs): | ||||
|     def __init__(self, hs: HomeServer) -> None: | ||||
|         super().__init__(hs) | ||||
|         self.recaptcha_attempts = [] | ||||
|         self.recaptcha_attempts: List[Tuple[dict, str]] = [] | ||||
| 
 | ||||
|     def check_auth(self, authdict, clientip): | ||||
|     def check_auth(self, authdict: dict, clientip: str) -> Any: | ||||
|         self.recaptcha_attempts.append((authdict, clientip)) | ||||
|         return succeed(True) | ||||
| 
 | ||||
|  | @ -50,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): | |||
|     ] | ||||
|     hijack_auth = False | ||||
| 
 | ||||
|     def make_homeserver(self, reactor, clock): | ||||
|     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
| 
 | ||||
|         config = self.default_config() | ||||
| 
 | ||||
|  | @ -61,7 +65,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): | |||
|         hs = self.setup_test_homeserver(config=config) | ||||
|         return hs | ||||
| 
 | ||||
|     def prepare(self, reactor, clock, hs): | ||||
|     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
|         self.recaptcha_checker = DummyRecaptchaChecker(hs) | ||||
|         auth_handler = hs.get_auth_handler() | ||||
|         auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker | ||||
|  | @ -101,7 +105,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): | |||
|         self.assertEqual(len(attempts), 1) | ||||
|         self.assertEqual(attempts[0][0]["response"], "a") | ||||
| 
 | ||||
|     def test_fallback_captcha(self): | ||||
|     def test_fallback_captcha(self) -> None: | ||||
|         """Ensure that fallback auth via a captcha works.""" | ||||
|         # Returns a 401 as per the spec | ||||
|         channel = self.register( | ||||
|  | @ -132,7 +136,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): | |||
|         # We're given a registered user. | ||||
|         self.assertEqual(channel.json_body["user_id"], "@user:test") | ||||
| 
 | ||||
|     def test_complete_operation_unknown_session(self): | ||||
|     def test_complete_operation_unknown_session(self) -> None: | ||||
|         """ | ||||
|         Attempting to mark an invalid session as complete should error. | ||||
|         """ | ||||
|  | @ -165,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
|         register.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def default_config(self): | ||||
|     def default_config(self) -> Dict[str, Any]: | ||||
|         config = super().default_config() | ||||
| 
 | ||||
|         # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns | ||||
|  | @ -182,12 +186,12 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         return config | ||||
| 
 | ||||
|     def create_resource_dict(self): | ||||
|     def create_resource_dict(self) -> Dict[str, Resource]: | ||||
|         resource_dict = super().create_resource_dict() | ||||
|         resource_dict.update(build_synapse_client_resource_tree(self.hs)) | ||||
|         return resource_dict | ||||
| 
 | ||||
|     def prepare(self, reactor, clock, hs): | ||||
|     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
|         self.user_pass = "pass" | ||||
|         self.user = self.register_user("test", self.user_pass) | ||||
|         self.device_id = "dev1" | ||||
|  | @ -229,7 +233,7 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         return channel | ||||
| 
 | ||||
|     def test_ui_auth(self): | ||||
|     def test_ui_auth(self) -> None: | ||||
|         """ | ||||
|         Test user interactive authentication outside of registration. | ||||
|         """ | ||||
|  | @ -259,7 +263,7 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
|             }, | ||||
|         ) | ||||
| 
 | ||||
|     def test_grandfathered_identifier(self): | ||||
|     def test_grandfathered_identifier(self) -> None: | ||||
|         """Check behaviour without "identifier" dict | ||||
| 
 | ||||
|         Synapse used to require clients to submit a "user" field for m.login.password | ||||
|  | @ -286,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
|             }, | ||||
|         ) | ||||
| 
 | ||||
|     def test_can_change_body(self): | ||||
|     def test_can_change_body(self) -> None: | ||||
|         """ | ||||
|         The client dict can be modified during the user interactive authentication session. | ||||
| 
 | ||||
|  | @ -325,7 +329,7 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
|             }, | ||||
|         ) | ||||
| 
 | ||||
|     def test_cannot_change_uri(self): | ||||
|     def test_cannot_change_uri(self) -> None: | ||||
|         """ | ||||
|         The initial requested URI cannot be modified during the user interactive authentication session. | ||||
|         """ | ||||
|  | @ -362,7 +366,7 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
|         ) | ||||
| 
 | ||||
|     @unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) | ||||
|     def test_can_reuse_session(self): | ||||
|     def test_can_reuse_session(self) -> None: | ||||
|         """ | ||||
|         The session can be reused if configured. | ||||
| 
 | ||||
|  | @ -409,7 +413,7 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
| 
 | ||||
|     @skip_unless(HAS_OIDC, "requires OIDC") | ||||
|     @override_config({"oidc_config": TEST_OIDC_CONFIG}) | ||||
|     def test_ui_auth_via_sso(self): | ||||
|     def test_ui_auth_via_sso(self) -> None: | ||||
|         """Test a successful UI Auth flow via SSO | ||||
| 
 | ||||
|         This includes: | ||||
|  | @ -452,7 +456,7 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
| 
 | ||||
|     @skip_unless(HAS_OIDC, "requires OIDC") | ||||
|     @override_config({"oidc_config": TEST_OIDC_CONFIG}) | ||||
|     def test_does_not_offer_password_for_sso_user(self): | ||||
|     def test_does_not_offer_password_for_sso_user(self) -> None: | ||||
|         login_resp = self.helper.login_via_oidc("username") | ||||
|         user_tok = login_resp["access_token"] | ||||
|         device_id = login_resp["device_id"] | ||||
|  | @ -464,7 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
|         flows = channel.json_body["flows"] | ||||
|         self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) | ||||
| 
 | ||||
|     def test_does_not_offer_sso_for_password_user(self): | ||||
|     def test_does_not_offer_sso_for_password_user(self) -> None: | ||||
|         channel = self.delete_device( | ||||
|             self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED | ||||
|         ) | ||||
|  | @ -474,7 +478,7 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
| 
 | ||||
|     @skip_unless(HAS_OIDC, "requires OIDC") | ||||
|     @override_config({"oidc_config": TEST_OIDC_CONFIG}) | ||||
|     def test_offers_both_flows_for_upgraded_user(self): | ||||
|     def test_offers_both_flows_for_upgraded_user(self) -> None: | ||||
|         """A user that had a password and then logged in with SSO should get both flows""" | ||||
|         login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) | ||||
|         self.assertEqual(login_resp["user_id"], self.user) | ||||
|  | @ -491,7 +495,7 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
| 
 | ||||
|     @skip_unless(HAS_OIDC, "requires OIDC") | ||||
|     @override_config({"oidc_config": TEST_OIDC_CONFIG}) | ||||
|     def test_ui_auth_fails_for_incorrect_sso_user(self): | ||||
|     def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: | ||||
|         """If the user tries to authenticate with the wrong SSO user, they get an error""" | ||||
|         # log the user in | ||||
|         login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) | ||||
|  | @ -534,7 +538,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): | |||
|     ] | ||||
|     hijack_auth = False | ||||
| 
 | ||||
|     def prepare(self, reactor, clock, hs): | ||||
|     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
|         self.user_pass = "pass" | ||||
|         self.user = self.register_user("test", self.user_pass) | ||||
| 
 | ||||
|  | @ -548,7 +552,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): | |||
|             {"refresh_token": refresh_token}, | ||||
|         ) | ||||
| 
 | ||||
|     def is_access_token_valid(self, access_token) -> bool: | ||||
|     def is_access_token_valid(self, access_token: str) -> bool: | ||||
|         """ | ||||
|         Checks whether an access token is valid, returning whether it is or not. | ||||
|         """ | ||||
|  | @ -561,7 +565,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         return code == HTTPStatus.OK | ||||
| 
 | ||||
|     def test_login_issue_refresh_token(self): | ||||
|     def test_login_issue_refresh_token(self) -> None: | ||||
|         """ | ||||
|         A login response should include a refresh_token only if asked. | ||||
|         """ | ||||
|  | @ -591,7 +595,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): | |||
|         self.assertIn("refresh_token", login_with_refresh.json_body) | ||||
|         self.assertIn("expires_in_ms", login_with_refresh.json_body) | ||||
| 
 | ||||
|     def test_register_issue_refresh_token(self): | ||||
|     def test_register_issue_refresh_token(self) -> None: | ||||
|         """ | ||||
|         A register response should include a refresh_token only if asked. | ||||
|         """ | ||||
|  | @ -627,7 +631,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): | |||
|         self.assertIn("refresh_token", register_with_refresh.json_body) | ||||
|         self.assertIn("expires_in_ms", register_with_refresh.json_body) | ||||
| 
 | ||||
|     def test_token_refresh(self): | ||||
|     def test_token_refresh(self) -> None: | ||||
|         """ | ||||
|         A refresh token can be used to issue a new access token. | ||||
|         """ | ||||
|  | @ -665,7 +669,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): | |||
|         ) | ||||
| 
 | ||||
|     @override_config({"refreshable_access_token_lifetime": "1m"}) | ||||
|     def test_refreshable_access_token_expiration(self): | ||||
|     def test_refreshable_access_token_expiration(self) -> None: | ||||
|         """ | ||||
|         The access token should have some time as specified in the config. | ||||
|         """ | ||||
|  | @ -722,7 +726,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): | |||
|             "nonrefreshable_access_token_lifetime": "10m", | ||||
|         } | ||||
|     ) | ||||
|     def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): | ||||
|     def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens( | ||||
|         self, | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Tests that the expiry times for refreshable and non-refreshable access | ||||
|         tokens can be different. | ||||
|  | @ -782,7 +788,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): | |||
|     @override_config( | ||||
|         {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} | ||||
|     ) | ||||
|     def test_refresh_token_expiry(self): | ||||
|     def test_refresh_token_expiry(self) -> None: | ||||
|         """ | ||||
|         The refresh token can be configured to have a limited lifetime. | ||||
|         When that lifetime has ended, the refresh token can no longer be used to | ||||
|  | @ -834,7 +840,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): | |||
|             "session_lifetime": "3m", | ||||
|         } | ||||
|     ) | ||||
|     def test_ultimate_session_expiry(self): | ||||
|     def test_ultimate_session_expiry(self) -> None: | ||||
|         """ | ||||
|         The session can be configured to have an ultimate, limited lifetime. | ||||
|         """ | ||||
|  | @ -882,7 +888,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): | |||
|             refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result | ||||
|         ) | ||||
| 
 | ||||
|     def test_refresh_token_invalidation(self): | ||||
|     def test_refresh_token_invalidation(self) -> None: | ||||
|         """Refresh tokens are invalidated after first use of the next token. | ||||
| 
 | ||||
|         A refresh token is considered invalid if: | ||||
|  | @ -987,7 +993,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): | |||
|             fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result | ||||
|         ) | ||||
| 
 | ||||
|     def test_many_token_refresh(self): | ||||
|     def test_many_token_refresh(self) -> None: | ||||
|         """ | ||||
|         If a refresh is performed many times during a session, there shouldn't be | ||||
|         extra 'cruft' built up over time. | ||||
|  |  | |||
|  | @ -13,9 +13,13 @@ | |||
| # limitations under the License. | ||||
| from http import HTTPStatus | ||||
| 
 | ||||
| from twisted.test.proto_helpers import MemoryReactor | ||||
| 
 | ||||
| import synapse.rest.admin | ||||
| from synapse.api.room_versions import KNOWN_ROOM_VERSIONS | ||||
| from synapse.rest.client import capabilities, login | ||||
| from synapse.server import HomeServer | ||||
| from synapse.util import Clock | ||||
| 
 | ||||
| from tests import unittest | ||||
| from tests.unittest import override_config | ||||
|  | @ -29,24 +33,24 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): | |||
|         login.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def make_homeserver(self, reactor, clock): | ||||
|     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
|         self.url = b"/capabilities" | ||||
|         hs = self.setup_test_homeserver() | ||||
|         self.config = hs.config | ||||
|         self.auth_handler = hs.get_auth_handler() | ||||
|         return hs | ||||
| 
 | ||||
|     def prepare(self, reactor, clock, hs): | ||||
|     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
|         self.localpart = "user" | ||||
|         self.password = "pass" | ||||
|         self.user = self.register_user(self.localpart, self.password) | ||||
| 
 | ||||
|     def test_check_auth_required(self): | ||||
|     def test_check_auth_required(self) -> None: | ||||
|         channel = self.make_request("GET", self.url) | ||||
| 
 | ||||
|         self.assertEqual(channel.code, 401) | ||||
| 
 | ||||
|     def test_get_room_version_capabilities(self): | ||||
|     def test_get_room_version_capabilities(self) -> None: | ||||
|         access_token = self.login(self.localpart, self.password) | ||||
| 
 | ||||
|         channel = self.make_request("GET", self.url, access_token=access_token) | ||||
|  | @ -61,7 +65,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): | |||
|             capabilities["m.room_versions"]["default"], | ||||
|         ) | ||||
| 
 | ||||
|     def test_get_change_password_capabilities_password_login(self): | ||||
|     def test_get_change_password_capabilities_password_login(self) -> None: | ||||
|         access_token = self.login(self.localpart, self.password) | ||||
| 
 | ||||
|         channel = self.make_request("GET", self.url, access_token=access_token) | ||||
|  | @ -71,7 +75,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): | |||
|         self.assertTrue(capabilities["m.change_password"]["enabled"]) | ||||
| 
 | ||||
|     @override_config({"password_config": {"localdb_enabled": False}}) | ||||
|     def test_get_change_password_capabilities_localdb_disabled(self): | ||||
|     def test_get_change_password_capabilities_localdb_disabled(self) -> None: | ||||
|         access_token = self.get_success( | ||||
|             self.auth_handler.create_access_token_for_user_id( | ||||
|                 self.user, device_id=None, valid_until_ms=None | ||||
|  | @ -85,7 +89,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): | |||
|         self.assertFalse(capabilities["m.change_password"]["enabled"]) | ||||
| 
 | ||||
|     @override_config({"password_config": {"enabled": False}}) | ||||
|     def test_get_change_password_capabilities_password_disabled(self): | ||||
|     def test_get_change_password_capabilities_password_disabled(self) -> None: | ||||
|         access_token = self.get_success( | ||||
|             self.auth_handler.create_access_token_for_user_id( | ||||
|                 self.user, device_id=None, valid_until_ms=None | ||||
|  | @ -98,7 +102,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): | |||
|         self.assertEqual(channel.code, 200) | ||||
|         self.assertFalse(capabilities["m.change_password"]["enabled"]) | ||||
| 
 | ||||
|     def test_get_change_users_attributes_capabilities(self): | ||||
|     def test_get_change_users_attributes_capabilities(self) -> None: | ||||
|         """Test that server returns capabilities by default.""" | ||||
|         access_token = self.login(self.localpart, self.password) | ||||
| 
 | ||||
|  | @ -112,7 +116,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): | |||
|         self.assertTrue(capabilities["m.3pid_changes"]["enabled"]) | ||||
| 
 | ||||
|     @override_config({"enable_set_displayname": False}) | ||||
|     def test_get_set_displayname_capabilities_displayname_disabled(self): | ||||
|     def test_get_set_displayname_capabilities_displayname_disabled(self) -> None: | ||||
|         """Test if set displayname is disabled that the server responds it.""" | ||||
|         access_token = self.login(self.localpart, self.password) | ||||
| 
 | ||||
|  | @ -123,7 +127,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): | |||
|         self.assertFalse(capabilities["m.set_displayname"]["enabled"]) | ||||
| 
 | ||||
|     @override_config({"enable_set_avatar_url": False}) | ||||
|     def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): | ||||
|     def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None: | ||||
|         """Test if set avatar_url is disabled that the server responds it.""" | ||||
|         access_token = self.login(self.localpart, self.password) | ||||
| 
 | ||||
|  | @ -134,7 +138,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): | |||
|         self.assertFalse(capabilities["m.set_avatar_url"]["enabled"]) | ||||
| 
 | ||||
|     @override_config({"enable_3pid_changes": False}) | ||||
|     def test_get_change_3pid_capabilities_3pid_disabled(self): | ||||
|     def test_get_change_3pid_capabilities_3pid_disabled(self) -> None: | ||||
|         """Test if change 3pid is disabled that the server responds it.""" | ||||
|         access_token = self.login(self.localpart, self.password) | ||||
| 
 | ||||
|  | @ -145,7 +149,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): | |||
|         self.assertFalse(capabilities["m.3pid_changes"]["enabled"]) | ||||
| 
 | ||||
|     @override_config({"experimental_features": {"msc3244_enabled": False}}) | ||||
|     def test_get_does_not_include_msc3244_fields_when_disabled(self): | ||||
|     def test_get_does_not_include_msc3244_fields_when_disabled(self) -> None: | ||||
|         access_token = self.get_success( | ||||
|             self.auth_handler.create_access_token_for_user_id( | ||||
|                 self.user, device_id=None, valid_until_ms=None | ||||
|  | @ -160,7 +164,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): | |||
|             "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] | ||||
|         ) | ||||
| 
 | ||||
|     def test_get_does_include_msc3244_fields_when_enabled(self): | ||||
|     def test_get_does_include_msc3244_fields_when_enabled(self) -> None: | ||||
|         access_token = self.get_success( | ||||
|             self.auth_handler.create_access_token_for_user_id( | ||||
|                 self.user, device_id=None, valid_until_ms=None | ||||
|  |  | |||
|  | @ -20,6 +20,7 @@ from urllib.parse import urlencode | |||
| 
 | ||||
| import pymacaroons | ||||
| 
 | ||||
| from twisted.test.proto_helpers import MemoryReactor | ||||
| from twisted.web.resource import Resource | ||||
| 
 | ||||
| import synapse.rest.admin | ||||
|  | @ -27,12 +28,15 @@ from synapse.appservice import ApplicationService | |||
| from synapse.rest.client import devices, login, logout, register | ||||
| from synapse.rest.client.account import WhoamiRestServlet | ||||
| from synapse.rest.synapse.client import build_synapse_client_resource_tree | ||||
| from synapse.server import HomeServer | ||||
| from synapse.types import create_requester | ||||
| from synapse.util import Clock | ||||
| 
 | ||||
| from tests import unittest | ||||
| from tests.handlers.test_oidc import HAS_OIDC | ||||
| from tests.handlers.test_saml import has_saml2 | ||||
| from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG | ||||
| from tests.server import FakeChannel | ||||
| from tests.test_utils.html_parsers import TestHtmlParser | ||||
| from tests.unittest import HomeserverTestCase, override_config, skip_unless | ||||
| 
 | ||||
|  | @ -95,7 +99,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): | |||
|         lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), | ||||
|     ] | ||||
| 
 | ||||
|     def make_homeserver(self, reactor, clock): | ||||
|     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
|         self.hs = self.setup_test_homeserver() | ||||
|         self.hs.config.registration.enable_registration = True | ||||
|         self.hs.config.registration.registrations_require_3pid = [] | ||||
|  | @ -117,7 +121,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): | |||
|             } | ||||
|         } | ||||
|     ) | ||||
|     def test_POST_ratelimiting_per_address(self): | ||||
|     def test_POST_ratelimiting_per_address(self) -> None: | ||||
|         # Create different users so we're sure not to be bothered by the per-user | ||||
|         # ratelimiter. | ||||
|         for i in range(0, 6): | ||||
|  | @ -165,7 +169,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): | |||
|             } | ||||
|         } | ||||
|     ) | ||||
|     def test_POST_ratelimiting_per_account(self): | ||||
|     def test_POST_ratelimiting_per_account(self) -> None: | ||||
|         self.register_user("kermit", "monkey") | ||||
| 
 | ||||
|         for i in range(0, 6): | ||||
|  | @ -210,7 +214,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): | |||
|             } | ||||
|         } | ||||
|     ) | ||||
|     def test_POST_ratelimiting_per_account_failed_attempts(self): | ||||
|     def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: | ||||
|         self.register_user("kermit", "monkey") | ||||
| 
 | ||||
|         for i in range(0, 6): | ||||
|  | @ -243,7 +247,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): | |||
|         self.assertEquals(channel.result["code"], b"403", channel.result) | ||||
| 
 | ||||
|     @override_config({"session_lifetime": "24h"}) | ||||
|     def test_soft_logout(self): | ||||
|     def test_soft_logout(self) -> None: | ||||
|         self.register_user("kermit", "monkey") | ||||
| 
 | ||||
|         # we shouldn't be able to make requests without an access token | ||||
|  | @ -298,7 +302,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): | |||
|         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") | ||||
|         self.assertEquals(channel.json_body["soft_logout"], False) | ||||
| 
 | ||||
|     def _delete_device(self, access_token, user_id, password, device_id): | ||||
|     def _delete_device( | ||||
|         self, access_token: str, user_id: str, password: str, device_id: str | ||||
|     ) -> None: | ||||
|         """Perform the UI-Auth to delete a device""" | ||||
|         channel = self.make_request( | ||||
|             b"DELETE", "devices/" + device_id, access_token=access_token | ||||
|  | @ -329,7 +335,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): | |||
|         self.assertEquals(channel.code, 200, channel.result) | ||||
| 
 | ||||
|     @override_config({"session_lifetime": "24h"}) | ||||
|     def test_session_can_hard_logout_after_being_soft_logged_out(self): | ||||
|     def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: | ||||
|         self.register_user("kermit", "monkey") | ||||
| 
 | ||||
|         # log in as normal | ||||
|  | @ -353,7 +359,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): | |||
|         self.assertEquals(channel.result["code"], b"200", channel.result) | ||||
| 
 | ||||
|     @override_config({"session_lifetime": "24h"}) | ||||
|     def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): | ||||
|     def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( | ||||
|         self, | ||||
|     ) -> None: | ||||
|         self.register_user("kermit", "monkey") | ||||
| 
 | ||||
|         # log in as normal | ||||
|  | @ -432,7 +440,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): | |||
|         d.update(build_synapse_client_resource_tree(self.hs)) | ||||
|         return d | ||||
| 
 | ||||
|     def test_get_login_flows(self): | ||||
|     def test_get_login_flows(self) -> None: | ||||
|         """GET /login should return password and SSO flows""" | ||||
|         channel = self.make_request("GET", "/_matrix/client/r0/login") | ||||
|         self.assertEqual(channel.code, 200, channel.result) | ||||
|  | @ -459,12 +467,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): | |||
|             ], | ||||
|         ) | ||||
| 
 | ||||
|     def test_multi_sso_redirect(self): | ||||
|     def test_multi_sso_redirect(self) -> None: | ||||
|         """/login/sso/redirect should redirect to an identity picker""" | ||||
|         # first hit the redirect url, which should redirect to our idp picker | ||||
|         channel = self._make_sso_redirect_request(None) | ||||
|         self.assertEqual(channel.code, 302, channel.result) | ||||
|         uri = channel.headers.getRawHeaders("Location")[0] | ||||
|         location_headers = channel.headers.getRawHeaders("Location") | ||||
|         assert location_headers | ||||
|         uri = location_headers[0] | ||||
| 
 | ||||
|         # hitting that picker should give us some HTML | ||||
|         channel = self.make_request("GET", uri) | ||||
|  | @ -487,7 +497,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) | ||||
| 
 | ||||
|     def test_multi_sso_redirect_to_cas(self): | ||||
|     def test_multi_sso_redirect_to_cas(self) -> None: | ||||
|         """If CAS is chosen, should redirect to the CAS server""" | ||||
| 
 | ||||
|         channel = self.make_request( | ||||
|  | @ -514,7 +524,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): | |||
|         service_uri_params = urllib.parse.parse_qs(service_uri_query) | ||||
|         self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) | ||||
| 
 | ||||
|     def test_multi_sso_redirect_to_saml(self): | ||||
|     def test_multi_sso_redirect_to_saml(self) -> None: | ||||
|         """If SAML is chosen, should redirect to the SAML server""" | ||||
|         channel = self.make_request( | ||||
|             "GET", | ||||
|  | @ -536,7 +546,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): | |||
|         relay_state_param = saml_uri_params["RelayState"][0] | ||||
|         self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) | ||||
| 
 | ||||
|     def test_login_via_oidc(self): | ||||
|     def test_login_via_oidc(self) -> None: | ||||
|         """If OIDC is chosen, should redirect to the OIDC auth endpoint""" | ||||
| 
 | ||||
|         # pick the default OIDC provider | ||||
|  | @ -604,7 +614,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): | |||
|         self.assertEqual(chan.code, 200, chan.result) | ||||
|         self.assertEqual(chan.json_body["user_id"], "@user1:test") | ||||
| 
 | ||||
|     def test_multi_sso_redirect_to_unknown(self): | ||||
|     def test_multi_sso_redirect_to_unknown(self) -> None: | ||||
|         """An unknown IdP should cause a 400""" | ||||
|         channel = self.make_request( | ||||
|             "GET", | ||||
|  | @ -612,23 +622,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): | |||
|         ) | ||||
|         self.assertEqual(channel.code, 400, channel.result) | ||||
| 
 | ||||
|     def test_client_idp_redirect_to_unknown(self): | ||||
|     def test_client_idp_redirect_to_unknown(self) -> None: | ||||
|         """If the client tries to pick an unknown IdP, return a 404""" | ||||
|         channel = self._make_sso_redirect_request("xxx") | ||||
|         self.assertEqual(channel.code, 404, channel.result) | ||||
|         self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") | ||||
| 
 | ||||
|     def test_client_idp_redirect_to_oidc(self): | ||||
|     def test_client_idp_redirect_to_oidc(self) -> None: | ||||
|         """If the client pick a known IdP, redirect to it""" | ||||
|         channel = self._make_sso_redirect_request("oidc") | ||||
|         self.assertEqual(channel.code, 302, channel.result) | ||||
|         oidc_uri = channel.headers.getRawHeaders("Location")[0] | ||||
|         location_headers = channel.headers.getRawHeaders("Location") | ||||
|         assert location_headers | ||||
|         oidc_uri = location_headers[0] | ||||
|         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) | ||||
| 
 | ||||
|         # it should redirect us to the auth page of the OIDC server | ||||
|         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) | ||||
| 
 | ||||
|     def _make_sso_redirect_request(self, idp_prov: Optional[str] = None): | ||||
|     def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: | ||||
|         """Send a request to /_matrix/client/r0/login/sso/redirect | ||||
| 
 | ||||
|         ... possibly specifying an IDP provider | ||||
|  | @ -659,7 +671,7 @@ class CASTestCase(unittest.HomeserverTestCase): | |||
|         login.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def make_homeserver(self, reactor, clock): | ||||
|     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
|         self.base_url = "https://matrix.goodserver.com/" | ||||
|         self.redirect_path = "_synapse/client/login/sso/redirect/confirm" | ||||
| 
 | ||||
|  | @ -675,7 +687,7 @@ class CASTestCase(unittest.HomeserverTestCase): | |||
|         cas_user_id = "username" | ||||
|         self.user_id = "@%s:test" % cas_user_id | ||||
| 
 | ||||
|         async def get_raw(uri, args): | ||||
|         async def get_raw(uri: str, args: Any) -> bytes: | ||||
|             """Return an example response payload from a call to the `/proxyValidate` | ||||
|             endpoint of a CAS server, copied from | ||||
|             https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 | ||||
|  | @ -709,10 +721,10 @@ class CASTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         return self.hs | ||||
| 
 | ||||
|     def prepare(self, reactor, clock, hs): | ||||
|     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
|         self.deactivate_account_handler = hs.get_deactivate_account_handler() | ||||
| 
 | ||||
|     def test_cas_redirect_confirm(self): | ||||
|     def test_cas_redirect_confirm(self) -> None: | ||||
|         """Tests that the SSO login flow serves a confirmation page before redirecting a | ||||
|         user to the redirect URL. | ||||
|         """ | ||||
|  | @ -754,15 +766,15 @@ class CASTestCase(unittest.HomeserverTestCase): | |||
|             } | ||||
|         } | ||||
|     ) | ||||
|     def test_cas_redirect_whitelisted(self): | ||||
|     def test_cas_redirect_whitelisted(self) -> None: | ||||
|         """Tests that the SSO login flow serves a redirect to a whitelisted url""" | ||||
|         self._test_redirect("https://legit-site.com/") | ||||
| 
 | ||||
|     @override_config({"public_baseurl": "https://example.com"}) | ||||
|     def test_cas_redirect_login_fallback(self): | ||||
|     def test_cas_redirect_login_fallback(self) -> None: | ||||
|         self._test_redirect("https://example.com/_matrix/static/client/login") | ||||
| 
 | ||||
|     def _test_redirect(self, redirect_url): | ||||
|     def _test_redirect(self, redirect_url: str) -> None: | ||||
|         """Tests that the SSO login flow serves a redirect for the given redirect URL.""" | ||||
|         cas_ticket_url = ( | ||||
|             "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" | ||||
|  | @ -778,7 +790,7 @@ class CASTestCase(unittest.HomeserverTestCase): | |||
|         self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) | ||||
| 
 | ||||
|     @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) | ||||
|     def test_deactivated_user(self): | ||||
|     def test_deactivated_user(self) -> None: | ||||
|         """Logging in as a deactivated account should error.""" | ||||
|         redirect_url = "https://legit-site.com/" | ||||
| 
 | ||||
|  | @ -821,7 +833,7 @@ class JWTTestCase(unittest.HomeserverTestCase): | |||
|         "algorithm": jwt_algorithm, | ||||
|     } | ||||
| 
 | ||||
|     def default_config(self): | ||||
|     def default_config(self) -> Dict[str, Any]: | ||||
|         config = super().default_config() | ||||
| 
 | ||||
|         # If jwt_config has been defined (eg via @override_config), don't replace it. | ||||
|  | @ -837,23 +849,23 @@ class JWTTestCase(unittest.HomeserverTestCase): | |||
|             return result.decode("ascii") | ||||
|         return result | ||||
| 
 | ||||
|     def jwt_login(self, *args): | ||||
|     def jwt_login(self, *args: Any) -> FakeChannel: | ||||
|         params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} | ||||
|         channel = self.make_request(b"POST", LOGIN_URL, params) | ||||
|         return channel | ||||
| 
 | ||||
|     def test_login_jwt_valid_registered(self): | ||||
|     def test_login_jwt_valid_registered(self) -> None: | ||||
|         self.register_user("kermit", "monkey") | ||||
|         channel = self.jwt_login({"sub": "kermit"}) | ||||
|         self.assertEqual(channel.result["code"], b"200", channel.result) | ||||
|         self.assertEqual(channel.json_body["user_id"], "@kermit:test") | ||||
| 
 | ||||
|     def test_login_jwt_valid_unregistered(self): | ||||
|     def test_login_jwt_valid_unregistered(self) -> None: | ||||
|         channel = self.jwt_login({"sub": "frog"}) | ||||
|         self.assertEqual(channel.result["code"], b"200", channel.result) | ||||
|         self.assertEqual(channel.json_body["user_id"], "@frog:test") | ||||
| 
 | ||||
|     def test_login_jwt_invalid_signature(self): | ||||
|     def test_login_jwt_invalid_signature(self) -> None: | ||||
|         channel = self.jwt_login({"sub": "frog"}, "notsecret") | ||||
|         self.assertEqual(channel.result["code"], b"403", channel.result) | ||||
|         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") | ||||
|  | @ -862,7 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase): | |||
|             "JWT validation failed: Signature verification failed", | ||||
|         ) | ||||
| 
 | ||||
|     def test_login_jwt_expired(self): | ||||
|     def test_login_jwt_expired(self) -> None: | ||||
|         channel = self.jwt_login({"sub": "frog", "exp": 864000}) | ||||
|         self.assertEqual(channel.result["code"], b"403", channel.result) | ||||
|         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") | ||||
|  | @ -870,7 +882,7 @@ class JWTTestCase(unittest.HomeserverTestCase): | |||
|             channel.json_body["error"], "JWT validation failed: Signature has expired" | ||||
|         ) | ||||
| 
 | ||||
|     def test_login_jwt_not_before(self): | ||||
|     def test_login_jwt_not_before(self) -> None: | ||||
|         now = int(time.time()) | ||||
|         channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) | ||||
|         self.assertEqual(channel.result["code"], b"403", channel.result) | ||||
|  | @ -880,14 +892,14 @@ class JWTTestCase(unittest.HomeserverTestCase): | |||
|             "JWT validation failed: The token is not yet valid (nbf)", | ||||
|         ) | ||||
| 
 | ||||
|     def test_login_no_sub(self): | ||||
|     def test_login_no_sub(self) -> None: | ||||
|         channel = self.jwt_login({"username": "root"}) | ||||
|         self.assertEqual(channel.result["code"], b"403", channel.result) | ||||
|         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") | ||||
|         self.assertEqual(channel.json_body["error"], "Invalid JWT") | ||||
| 
 | ||||
|     @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) | ||||
|     def test_login_iss(self): | ||||
|     def test_login_iss(self) -> None: | ||||
|         """Test validating the issuer claim.""" | ||||
|         # A valid issuer. | ||||
|         channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) | ||||
|  | @ -911,14 +923,14 @@ class JWTTestCase(unittest.HomeserverTestCase): | |||
|             'JWT validation failed: Token is missing the "iss" claim', | ||||
|         ) | ||||
| 
 | ||||
|     def test_login_iss_no_config(self): | ||||
|     def test_login_iss_no_config(self) -> None: | ||||
|         """Test providing an issuer claim without requiring it in the configuration.""" | ||||
|         channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) | ||||
|         self.assertEqual(channel.result["code"], b"200", channel.result) | ||||
|         self.assertEqual(channel.json_body["user_id"], "@kermit:test") | ||||
| 
 | ||||
|     @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) | ||||
|     def test_login_aud(self): | ||||
|     def test_login_aud(self) -> None: | ||||
|         """Test validating the audience claim.""" | ||||
|         # A valid audience. | ||||
|         channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) | ||||
|  | @ -942,7 +954,7 @@ class JWTTestCase(unittest.HomeserverTestCase): | |||
|             'JWT validation failed: Token is missing the "aud" claim', | ||||
|         ) | ||||
| 
 | ||||
|     def test_login_aud_no_config(self): | ||||
|     def test_login_aud_no_config(self) -> None: | ||||
|         """Test providing an audience without requiring it in the configuration.""" | ||||
|         channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) | ||||
|         self.assertEqual(channel.result["code"], b"403", channel.result) | ||||
|  | @ -951,20 +963,20 @@ class JWTTestCase(unittest.HomeserverTestCase): | |||
|             channel.json_body["error"], "JWT validation failed: Invalid audience" | ||||
|         ) | ||||
| 
 | ||||
|     def test_login_default_sub(self): | ||||
|     def test_login_default_sub(self) -> None: | ||||
|         """Test reading user ID from the default subject claim.""" | ||||
|         channel = self.jwt_login({"sub": "kermit"}) | ||||
|         self.assertEqual(channel.result["code"], b"200", channel.result) | ||||
|         self.assertEqual(channel.json_body["user_id"], "@kermit:test") | ||||
| 
 | ||||
|     @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) | ||||
|     def test_login_custom_sub(self): | ||||
|     def test_login_custom_sub(self) -> None: | ||||
|         """Test reading user ID from a custom subject claim.""" | ||||
|         channel = self.jwt_login({"username": "frog"}) | ||||
|         self.assertEqual(channel.result["code"], b"200", channel.result) | ||||
|         self.assertEqual(channel.json_body["user_id"], "@frog:test") | ||||
| 
 | ||||
|     def test_login_no_token(self): | ||||
|     def test_login_no_token(self) -> None: | ||||
|         params = {"type": "org.matrix.login.jwt"} | ||||
|         channel = self.make_request(b"POST", LOGIN_URL, params) | ||||
|         self.assertEqual(channel.result["code"], b"403", channel.result) | ||||
|  | @ -1026,7 +1038,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): | |||
|         ] | ||||
|     ) | ||||
| 
 | ||||
|     def default_config(self): | ||||
|     def default_config(self) -> Dict[str, Any]: | ||||
|         config = super().default_config() | ||||
|         config["jwt_config"] = { | ||||
|             "enabled": True, | ||||
|  | @ -1042,17 +1054,17 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): | |||
|             return result.decode("ascii") | ||||
|         return result | ||||
| 
 | ||||
|     def jwt_login(self, *args): | ||||
|     def jwt_login(self, *args: Any) -> FakeChannel: | ||||
|         params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} | ||||
|         channel = self.make_request(b"POST", LOGIN_URL, params) | ||||
|         return channel | ||||
| 
 | ||||
|     def test_login_jwt_valid(self): | ||||
|     def test_login_jwt_valid(self) -> None: | ||||
|         channel = self.jwt_login({"sub": "kermit"}) | ||||
|         self.assertEqual(channel.result["code"], b"200", channel.result) | ||||
|         self.assertEqual(channel.json_body["user_id"], "@kermit:test") | ||||
| 
 | ||||
|     def test_login_jwt_invalid_signature(self): | ||||
|     def test_login_jwt_invalid_signature(self) -> None: | ||||
|         channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) | ||||
|         self.assertEqual(channel.result["code"], b"403", channel.result) | ||||
|         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") | ||||
|  | @ -1071,7 +1083,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): | |||
|         register.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def make_homeserver(self, reactor, clock): | ||||
|     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
|         self.hs = self.setup_test_homeserver() | ||||
| 
 | ||||
|         self.service = ApplicationService( | ||||
|  | @ -1105,7 +1117,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): | |||
|         self.hs.get_datastores().main.services_cache.append(self.another_service) | ||||
|         return self.hs | ||||
| 
 | ||||
|     def test_login_appservice_user(self): | ||||
|     def test_login_appservice_user(self) -> None: | ||||
|         """Test that an appservice user can use /login""" | ||||
|         self.register_appservice_user(AS_USER, self.service.token) | ||||
| 
 | ||||
|  | @ -1119,7 +1131,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         self.assertEquals(channel.result["code"], b"200", channel.result) | ||||
| 
 | ||||
|     def test_login_appservice_user_bot(self): | ||||
|     def test_login_appservice_user_bot(self) -> None: | ||||
|         """Test that the appservice bot can use /login""" | ||||
|         self.register_appservice_user(AS_USER, self.service.token) | ||||
| 
 | ||||
|  | @ -1133,7 +1145,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         self.assertEquals(channel.result["code"], b"200", channel.result) | ||||
| 
 | ||||
|     def test_login_appservice_wrong_user(self): | ||||
|     def test_login_appservice_wrong_user(self) -> None: | ||||
|         """Test that non-as users cannot login with the as token""" | ||||
|         self.register_appservice_user(AS_USER, self.service.token) | ||||
| 
 | ||||
|  | @ -1147,7 +1159,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         self.assertEquals(channel.result["code"], b"403", channel.result) | ||||
| 
 | ||||
|     def test_login_appservice_wrong_as(self): | ||||
|     def test_login_appservice_wrong_as(self) -> None: | ||||
|         """Test that as users cannot login with wrong as token""" | ||||
|         self.register_appservice_user(AS_USER, self.service.token) | ||||
| 
 | ||||
|  | @ -1161,7 +1173,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         self.assertEquals(channel.result["code"], b"403", channel.result) | ||||
| 
 | ||||
|     def test_login_appservice_no_token(self): | ||||
|     def test_login_appservice_no_token(self) -> None: | ||||
|         """Test that users must provide a token when using the appservice | ||||
|         login method | ||||
|         """ | ||||
|  | @ -1182,7 +1194,7 @@ class UsernamePickerTestCase(HomeserverTestCase): | |||
| 
 | ||||
|     servlets = [login.register_servlets] | ||||
| 
 | ||||
|     def default_config(self): | ||||
|     def default_config(self) -> Dict[str, Any]: | ||||
|         config = super().default_config() | ||||
|         config["public_baseurl"] = BASE_URL | ||||
| 
 | ||||
|  | @ -1202,7 +1214,7 @@ class UsernamePickerTestCase(HomeserverTestCase): | |||
|         d.update(build_synapse_client_resource_tree(self.hs)) | ||||
|         return d | ||||
| 
 | ||||
|     def test_username_picker(self): | ||||
|     def test_username_picker(self) -> None: | ||||
|         """Test the happy path of a username picker flow.""" | ||||
| 
 | ||||
|         # do the start of the login flow | ||||
|  |  | |||
|  | @ -13,9 +13,12 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import json | ||||
| from typing import List, Optional | ||||
| 
 | ||||
| from parameterized import parameterized | ||||
| 
 | ||||
| from twisted.test.proto_helpers import MemoryReactor | ||||
| 
 | ||||
| import synapse.rest.admin | ||||
| from synapse.api.constants import ( | ||||
|     EventContentFields, | ||||
|  | @ -24,6 +27,9 @@ from synapse.api.constants import ( | |||
|     RelationTypes, | ||||
| ) | ||||
| from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync | ||||
| from synapse.server import HomeServer | ||||
| from synapse.types import JsonDict | ||||
| from synapse.util import Clock | ||||
| 
 | ||||
| from tests import unittest | ||||
| from tests.federation.transport.test_knocking import ( | ||||
|  | @ -43,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase): | |||
|         sync.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def test_sync_argless(self): | ||||
|     def test_sync_argless(self) -> None: | ||||
|         channel = self.make_request("GET", "/sync") | ||||
| 
 | ||||
|         self.assertEqual(channel.code, 200) | ||||
|  | @ -58,7 +64,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): | |||
|         sync.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def test_sync_filter_labels(self): | ||||
|     def test_sync_filter_labels(self) -> None: | ||||
|         """Test that we can filter by a label.""" | ||||
|         sync_filter = json.dumps( | ||||
|             { | ||||
|  | @ -77,7 +83,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): | |||
|         self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) | ||||
|         self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) | ||||
| 
 | ||||
|     def test_sync_filter_not_labels(self): | ||||
|     def test_sync_filter_not_labels(self) -> None: | ||||
|         """Test that we can filter by the absence of a label.""" | ||||
|         sync_filter = json.dumps( | ||||
|             { | ||||
|  | @ -99,7 +105,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): | |||
|             events[2]["content"]["body"], "with two wrong labels", events[2] | ||||
|         ) | ||||
| 
 | ||||
|     def test_sync_filter_labels_not_labels(self): | ||||
|     def test_sync_filter_labels_not_labels(self) -> None: | ||||
|         """Test that we can filter by both a label and the absence of another label.""" | ||||
|         sync_filter = json.dumps( | ||||
|             { | ||||
|  | @ -118,7 +124,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): | |||
|         self.assertEqual(len(events), 1, [event["content"] for event in events]) | ||||
|         self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) | ||||
| 
 | ||||
|     def _test_sync_filter_labels(self, sync_filter): | ||||
|     def _test_sync_filter_labels(self, sync_filter: str) -> List[JsonDict]: | ||||
|         user_id = self.register_user("kermit", "test") | ||||
|         tok = self.login("kermit", "test") | ||||
| 
 | ||||
|  | @ -194,7 +200,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): | |||
|     user_id = True | ||||
|     hijack_auth = False | ||||
| 
 | ||||
|     def test_sync_backwards_typing(self): | ||||
|     def test_sync_backwards_typing(self) -> None: | ||||
|         """ | ||||
|         If the typing serial goes backwards and the typing handler is then reset | ||||
|         (such as when the master restarts and sets the typing serial to 0), we | ||||
|  | @ -298,7 +304,7 @@ class SyncKnockTestCase( | |||
|         knock.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def prepare(self, reactor, clock, hs): | ||||
|     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
|         self.store = hs.get_datastores().main | ||||
|         self.url = "/sync?since=%s" | ||||
|         self.next_batch = "s0" | ||||
|  | @ -336,7 +342,7 @@ class SyncKnockTestCase( | |||
|         ) | ||||
| 
 | ||||
|     @override_config({"experimental_features": {"msc2403_enabled": True}}) | ||||
|     def test_knock_room_state(self): | ||||
|     def test_knock_room_state(self) -> None: | ||||
|         """Tests that /sync returns state from a room after knocking on it.""" | ||||
|         # Knock on a room | ||||
|         channel = self.make_request( | ||||
|  | @ -383,7 +389,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): | |||
|         sync.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def prepare(self, reactor, clock, hs): | ||||
|     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
|         self.url = "/sync?since=%s" | ||||
|         self.next_batch = "s0" | ||||
| 
 | ||||
|  | @ -402,7 +408,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): | |||
|         self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) | ||||
| 
 | ||||
|     @override_config({"experimental_features": {"msc2285_enabled": True}}) | ||||
|     def test_hidden_read_receipts(self): | ||||
|     def test_hidden_read_receipts(self) -> None: | ||||
|         # Send a message as the first user | ||||
|         res = self.helper.send(self.room_id, body="hello", tok=self.tok) | ||||
| 
 | ||||
|  | @ -441,8 +447,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): | |||
|         ] | ||||
|     ) | ||||
|     def test_read_receipt_with_empty_body( | ||||
|         self, name, user_agent: str, expected_status_code: int | ||||
|     ): | ||||
|         self, name: str, user_agent: str, expected_status_code: int | ||||
|     ) -> None: | ||||
|         # Send a message as the first user | ||||
|         res = self.helper.send(self.room_id, body="hello", tok=self.tok) | ||||
| 
 | ||||
|  | @ -455,11 +461,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): | |||
|         ) | ||||
|         self.assertEqual(channel.code, expected_status_code) | ||||
| 
 | ||||
|     def _get_read_receipt(self): | ||||
|     def _get_read_receipt(self) -> Optional[JsonDict]: | ||||
|         """Syncs and returns the read receipt.""" | ||||
| 
 | ||||
|         # Checks if event is a read receipt | ||||
|         def is_read_receipt(event): | ||||
|         def is_read_receipt(event: JsonDict) -> bool: | ||||
|             return event["type"] == "m.receipt" | ||||
| 
 | ||||
|         # Sync | ||||
|  | @ -477,7 +483,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): | |||
|         ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ | ||||
|             "ephemeral" | ||||
|         ]["events"] | ||||
|         return next(filter(is_read_receipt, ephemeral_events), None) | ||||
|         receipt_event = filter(is_read_receipt, ephemeral_events) | ||||
|         return next(receipt_event, None) | ||||
| 
 | ||||
| 
 | ||||
| class UnreadMessagesTestCase(unittest.HomeserverTestCase): | ||||
|  | @ -490,7 +497,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): | |||
|         receipts.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def prepare(self, reactor, clock, hs): | ||||
|     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
|         self.url = "/sync?since=%s" | ||||
|         self.next_batch = "s0" | ||||
| 
 | ||||
|  | @ -533,7 +540,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): | |||
|             tok=self.tok, | ||||
|         ) | ||||
| 
 | ||||
|     def test_unread_counts(self): | ||||
|     def test_unread_counts(self) -> None: | ||||
|         """Tests that /sync returns the right value for the unread count (MSC2654).""" | ||||
| 
 | ||||
|         # Check that our own messages don't increase the unread count. | ||||
|  | @ -640,7 +647,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): | |||
|         ) | ||||
|         self._check_unread_count(5) | ||||
| 
 | ||||
|     def _check_unread_count(self, expected_count: int): | ||||
|     def _check_unread_count(self, expected_count: int) -> None: | ||||
|         """Syncs and compares the unread count with the expected value.""" | ||||
| 
 | ||||
|         channel = self.make_request( | ||||
|  | @ -669,7 +676,7 @@ class SyncCacheTestCase(unittest.HomeserverTestCase): | |||
|         sync.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def test_noop_sync_does_not_tightloop(self): | ||||
|     def test_noop_sync_does_not_tightloop(self) -> None: | ||||
|         """If the sync times out, we shouldn't cache the result | ||||
| 
 | ||||
|         Essentially a regression test for #8518. | ||||
|  | @ -720,7 +727,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): | |||
|         devices.register_servlets, | ||||
|     ] | ||||
| 
 | ||||
|     def test_user_with_no_rooms_receives_self_device_list_updates(self): | ||||
|     def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None: | ||||
|         """Tests that a user with no rooms still receives their own device list updates""" | ||||
|         device_id = "TESTDEVICE" | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Dirk Klimpel
						Dirk Klimpel