Add type hints to `tests/rest/client` (#12066)
parent
5b2b36809f
commit
64c73c6ac8
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to `tests/rest/client`.
|
|
@ -13,17 +13,21 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from twisted.internet.defer import succeed
|
from twisted.internet.defer import succeed
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
||||||
from synapse.rest.client import account, auth, devices, login, logout, register
|
from synapse.rest.client import account, auth, devices, login, logout, register
|
||||||
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.database import LoggingTransaction
|
from synapse.storage.database import LoggingTransaction
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.handlers.test_oidc import HAS_OIDC
|
from tests.handlers.test_oidc import HAS_OIDC
|
||||||
|
@ -33,11 +37,11 @@ from tests.unittest import override_config, skip_unless
|
||||||
|
|
||||||
|
|
||||||
class DummyRecaptchaChecker(UserInteractiveAuthChecker):
|
class DummyRecaptchaChecker(UserInteractiveAuthChecker):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: HomeServer) -> None:
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self.recaptcha_attempts = []
|
self.recaptcha_attempts: List[Tuple[dict, str]] = []
|
||||||
|
|
||||||
def check_auth(self, authdict, clientip):
|
def check_auth(self, authdict: dict, clientip: str) -> Any:
|
||||||
self.recaptcha_attempts.append((authdict, clientip))
|
self.recaptcha_attempts.append((authdict, clientip))
|
||||||
return succeed(True)
|
return succeed(True)
|
||||||
|
|
||||||
|
@ -50,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
hijack_auth = False
|
hijack_auth = False
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
|
|
||||||
|
@ -61,7 +65,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||||
hs = self.setup_test_homeserver(config=config)
|
hs = self.setup_test_homeserver(config=config)
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.recaptcha_checker = DummyRecaptchaChecker(hs)
|
self.recaptcha_checker = DummyRecaptchaChecker(hs)
|
||||||
auth_handler = hs.get_auth_handler()
|
auth_handler = hs.get_auth_handler()
|
||||||
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
|
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
|
||||||
|
@ -101,7 +105,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(attempts), 1)
|
self.assertEqual(len(attempts), 1)
|
||||||
self.assertEqual(attempts[0][0]["response"], "a")
|
self.assertEqual(attempts[0][0]["response"], "a")
|
||||||
|
|
||||||
def test_fallback_captcha(self):
|
def test_fallback_captcha(self) -> None:
|
||||||
"""Ensure that fallback auth via a captcha works."""
|
"""Ensure that fallback auth via a captcha works."""
|
||||||
# Returns a 401 as per the spec
|
# Returns a 401 as per the spec
|
||||||
channel = self.register(
|
channel = self.register(
|
||||||
|
@ -132,7 +136,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||||
# We're given a registered user.
|
# We're given a registered user.
|
||||||
self.assertEqual(channel.json_body["user_id"], "@user:test")
|
self.assertEqual(channel.json_body["user_id"], "@user:test")
|
||||||
|
|
||||||
def test_complete_operation_unknown_session(self):
|
def test_complete_operation_unknown_session(self) -> None:
|
||||||
"""
|
"""
|
||||||
Attempting to mark an invalid session as complete should error.
|
Attempting to mark an invalid session as complete should error.
|
||||||
"""
|
"""
|
||||||
|
@ -165,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
register.register_servlets,
|
register.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
|
|
||||||
# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
|
# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
|
||||||
|
@ -182,12 +186,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def create_resource_dict(self):
|
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||||
resource_dict = super().create_resource_dict()
|
resource_dict = super().create_resource_dict()
|
||||||
resource_dict.update(build_synapse_client_resource_tree(self.hs))
|
resource_dict.update(build_synapse_client_resource_tree(self.hs))
|
||||||
return resource_dict
|
return resource_dict
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.user_pass = "pass"
|
self.user_pass = "pass"
|
||||||
self.user = self.register_user("test", self.user_pass)
|
self.user = self.register_user("test", self.user_pass)
|
||||||
self.device_id = "dev1"
|
self.device_id = "dev1"
|
||||||
|
@ -229,7 +233,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
def test_ui_auth(self):
|
def test_ui_auth(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test user interactive authentication outside of registration.
|
Test user interactive authentication outside of registration.
|
||||||
"""
|
"""
|
||||||
|
@ -259,7 +263,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_grandfathered_identifier(self):
|
def test_grandfathered_identifier(self) -> None:
|
||||||
"""Check behaviour without "identifier" dict
|
"""Check behaviour without "identifier" dict
|
||||||
|
|
||||||
Synapse used to require clients to submit a "user" field for m.login.password
|
Synapse used to require clients to submit a "user" field for m.login.password
|
||||||
|
@ -286,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_can_change_body(self):
|
def test_can_change_body(self) -> None:
|
||||||
"""
|
"""
|
||||||
The client dict can be modified during the user interactive authentication session.
|
The client dict can be modified during the user interactive authentication session.
|
||||||
|
|
||||||
|
@ -325,7 +329,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_cannot_change_uri(self):
|
def test_cannot_change_uri(self) -> None:
|
||||||
"""
|
"""
|
||||||
The initial requested URI cannot be modified during the user interactive authentication session.
|
The initial requested URI cannot be modified during the user interactive authentication session.
|
||||||
"""
|
"""
|
||||||
|
@ -362,7 +366,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
|
@unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
|
||||||
def test_can_reuse_session(self):
|
def test_can_reuse_session(self) -> None:
|
||||||
"""
|
"""
|
||||||
The session can be reused if configured.
|
The session can be reused if configured.
|
||||||
|
|
||||||
|
@ -409,7 +413,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
@skip_unless(HAS_OIDC, "requires OIDC")
|
@skip_unless(HAS_OIDC, "requires OIDC")
|
||||||
@override_config({"oidc_config": TEST_OIDC_CONFIG})
|
@override_config({"oidc_config": TEST_OIDC_CONFIG})
|
||||||
def test_ui_auth_via_sso(self):
|
def test_ui_auth_via_sso(self) -> None:
|
||||||
"""Test a successful UI Auth flow via SSO
|
"""Test a successful UI Auth flow via SSO
|
||||||
|
|
||||||
This includes:
|
This includes:
|
||||||
|
@ -452,7 +456,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
@skip_unless(HAS_OIDC, "requires OIDC")
|
@skip_unless(HAS_OIDC, "requires OIDC")
|
||||||
@override_config({"oidc_config": TEST_OIDC_CONFIG})
|
@override_config({"oidc_config": TEST_OIDC_CONFIG})
|
||||||
def test_does_not_offer_password_for_sso_user(self):
|
def test_does_not_offer_password_for_sso_user(self) -> None:
|
||||||
login_resp = self.helper.login_via_oidc("username")
|
login_resp = self.helper.login_via_oidc("username")
|
||||||
user_tok = login_resp["access_token"]
|
user_tok = login_resp["access_token"]
|
||||||
device_id = login_resp["device_id"]
|
device_id = login_resp["device_id"]
|
||||||
|
@ -464,7 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
flows = channel.json_body["flows"]
|
flows = channel.json_body["flows"]
|
||||||
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
|
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
|
||||||
|
|
||||||
def test_does_not_offer_sso_for_password_user(self):
|
def test_does_not_offer_sso_for_password_user(self) -> None:
|
||||||
channel = self.delete_device(
|
channel = self.delete_device(
|
||||||
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
|
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
|
||||||
)
|
)
|
||||||
|
@ -474,7 +478,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
@skip_unless(HAS_OIDC, "requires OIDC")
|
@skip_unless(HAS_OIDC, "requires OIDC")
|
||||||
@override_config({"oidc_config": TEST_OIDC_CONFIG})
|
@override_config({"oidc_config": TEST_OIDC_CONFIG})
|
||||||
def test_offers_both_flows_for_upgraded_user(self):
|
def test_offers_both_flows_for_upgraded_user(self) -> None:
|
||||||
"""A user that had a password and then logged in with SSO should get both flows"""
|
"""A user that had a password and then logged in with SSO should get both flows"""
|
||||||
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
|
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
|
||||||
self.assertEqual(login_resp["user_id"], self.user)
|
self.assertEqual(login_resp["user_id"], self.user)
|
||||||
|
@ -491,7 +495,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
@skip_unless(HAS_OIDC, "requires OIDC")
|
@skip_unless(HAS_OIDC, "requires OIDC")
|
||||||
@override_config({"oidc_config": TEST_OIDC_CONFIG})
|
@override_config({"oidc_config": TEST_OIDC_CONFIG})
|
||||||
def test_ui_auth_fails_for_incorrect_sso_user(self):
|
def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
|
||||||
"""If the user tries to authenticate with the wrong SSO user, they get an error"""
|
"""If the user tries to authenticate with the wrong SSO user, they get an error"""
|
||||||
# log the user in
|
# log the user in
|
||||||
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
|
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
|
||||||
|
@ -534,7 +538,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
hijack_auth = False
|
hijack_auth = False
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.user_pass = "pass"
|
self.user_pass = "pass"
|
||||||
self.user = self.register_user("test", self.user_pass)
|
self.user = self.register_user("test", self.user_pass)
|
||||||
|
|
||||||
|
@ -548,7 +552,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
{"refresh_token": refresh_token},
|
{"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_access_token_valid(self, access_token) -> bool:
|
def is_access_token_valid(self, access_token: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Checks whether an access token is valid, returning whether it is or not.
|
Checks whether an access token is valid, returning whether it is or not.
|
||||||
"""
|
"""
|
||||||
|
@ -561,7 +565,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return code == HTTPStatus.OK
|
return code == HTTPStatus.OK
|
||||||
|
|
||||||
def test_login_issue_refresh_token(self):
|
def test_login_issue_refresh_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
A login response should include a refresh_token only if asked.
|
A login response should include a refresh_token only if asked.
|
||||||
"""
|
"""
|
||||||
|
@ -591,7 +595,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
self.assertIn("refresh_token", login_with_refresh.json_body)
|
self.assertIn("refresh_token", login_with_refresh.json_body)
|
||||||
self.assertIn("expires_in_ms", login_with_refresh.json_body)
|
self.assertIn("expires_in_ms", login_with_refresh.json_body)
|
||||||
|
|
||||||
def test_register_issue_refresh_token(self):
|
def test_register_issue_refresh_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
A register response should include a refresh_token only if asked.
|
A register response should include a refresh_token only if asked.
|
||||||
"""
|
"""
|
||||||
|
@ -627,7 +631,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
self.assertIn("refresh_token", register_with_refresh.json_body)
|
self.assertIn("refresh_token", register_with_refresh.json_body)
|
||||||
self.assertIn("expires_in_ms", register_with_refresh.json_body)
|
self.assertIn("expires_in_ms", register_with_refresh.json_body)
|
||||||
|
|
||||||
def test_token_refresh(self):
|
def test_token_refresh(self) -> None:
|
||||||
"""
|
"""
|
||||||
A refresh token can be used to issue a new access token.
|
A refresh token can be used to issue a new access token.
|
||||||
"""
|
"""
|
||||||
|
@ -665,7 +669,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"refreshable_access_token_lifetime": "1m"})
|
@override_config({"refreshable_access_token_lifetime": "1m"})
|
||||||
def test_refreshable_access_token_expiration(self):
|
def test_refreshable_access_token_expiration(self) -> None:
|
||||||
"""
|
"""
|
||||||
The access token should have some time as specified in the config.
|
The access token should have some time as specified in the config.
|
||||||
"""
|
"""
|
||||||
|
@ -722,7 +726,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
"nonrefreshable_access_token_lifetime": "10m",
|
"nonrefreshable_access_token_lifetime": "10m",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self):
|
def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that the expiry times for refreshable and non-refreshable access
|
Tests that the expiry times for refreshable and non-refreshable access
|
||||||
tokens can be different.
|
tokens can be different.
|
||||||
|
@ -782,7 +788,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
|
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
|
||||||
)
|
)
|
||||||
def test_refresh_token_expiry(self):
|
def test_refresh_token_expiry(self) -> None:
|
||||||
"""
|
"""
|
||||||
The refresh token can be configured to have a limited lifetime.
|
The refresh token can be configured to have a limited lifetime.
|
||||||
When that lifetime has ended, the refresh token can no longer be used to
|
When that lifetime has ended, the refresh token can no longer be used to
|
||||||
|
@ -834,7 +840,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
"session_lifetime": "3m",
|
"session_lifetime": "3m",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_ultimate_session_expiry(self):
|
def test_ultimate_session_expiry(self) -> None:
|
||||||
"""
|
"""
|
||||||
The session can be configured to have an ultimate, limited lifetime.
|
The session can be configured to have an ultimate, limited lifetime.
|
||||||
"""
|
"""
|
||||||
|
@ -882,7 +888,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
|
refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_refresh_token_invalidation(self):
|
def test_refresh_token_invalidation(self) -> None:
|
||||||
"""Refresh tokens are invalidated after first use of the next token.
|
"""Refresh tokens are invalidated after first use of the next token.
|
||||||
|
|
||||||
A refresh token is considered invalid if:
|
A refresh token is considered invalid if:
|
||||||
|
@ -987,7 +993,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
|
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_many_token_refresh(self):
|
def test_many_token_refresh(self) -> None:
|
||||||
"""
|
"""
|
||||||
If a refresh is performed many times during a session, there shouldn't be
|
If a refresh is performed many times during a session, there shouldn't be
|
||||||
extra 'cruft' built up over time.
|
extra 'cruft' built up over time.
|
||||||
|
|
|
@ -13,9 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.rest.client import capabilities, login
|
from synapse.rest.client import capabilities, login
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
|
@ -29,24 +33,24 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.url = b"/capabilities"
|
self.url = b"/capabilities"
|
||||||
hs = self.setup_test_homeserver()
|
hs = self.setup_test_homeserver()
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.localpart = "user"
|
self.localpart = "user"
|
||||||
self.password = "pass"
|
self.password = "pass"
|
||||||
self.user = self.register_user(self.localpart, self.password)
|
self.user = self.register_user(self.localpart, self.password)
|
||||||
|
|
||||||
def test_check_auth_required(self):
|
def test_check_auth_required(self) -> None:
|
||||||
channel = self.make_request("GET", self.url)
|
channel = self.make_request("GET", self.url)
|
||||||
|
|
||||||
self.assertEqual(channel.code, 401)
|
self.assertEqual(channel.code, 401)
|
||||||
|
|
||||||
def test_get_room_version_capabilities(self):
|
def test_get_room_version_capabilities(self) -> None:
|
||||||
access_token = self.login(self.localpart, self.password)
|
access_token = self.login(self.localpart, self.password)
|
||||||
|
|
||||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||||
|
@ -61,7 +65,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||||
capabilities["m.room_versions"]["default"],
|
capabilities["m.room_versions"]["default"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_change_password_capabilities_password_login(self):
|
def test_get_change_password_capabilities_password_login(self) -> None:
|
||||||
access_token = self.login(self.localpart, self.password)
|
access_token = self.login(self.localpart, self.password)
|
||||||
|
|
||||||
channel = self.make_request("GET", self.url, access_token=access_token)
|
channel = self.make_request("GET", self.url, access_token=access_token)
|
||||||
|
@ -71,7 +75,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(capabilities["m.change_password"]["enabled"])
|
self.assertTrue(capabilities["m.change_password"]["enabled"])
|
||||||
|
|
||||||
@override_config({"password_config": {"localdb_enabled": False}})
|
@override_config({"password_config": {"localdb_enabled": False}})
|
||||||
def test_get_change_password_capabilities_localdb_disabled(self):
|
def test_get_change_password_capabilities_localdb_disabled(self) -> None:
|
||||||
access_token = self.get_success(
|
access_token = self.get_success(
|
||||||
self.auth_handler.create_access_token_for_user_id(
|
self.auth_handler.create_access_token_for_user_id(
|
||||||
self.user, device_id=None, valid_until_ms=None
|
self.user, device_id=None, valid_until_ms=None
|
||||||
|
@ -85,7 +89,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(capabilities["m.change_password"]["enabled"])
|
self.assertFalse(capabilities["m.change_password"]["enabled"])
|
||||||
|
|
||||||
@override_config({"password_config": {"enabled": False}})
|
@override_config({"password_config": {"enabled": False}})
|
||||||
def test_get_change_password_capabilities_password_disabled(self):
|
def test_get_change_password_capabilities_password_disabled(self) -> None:
|
||||||
access_token = self.get_success(
|
access_token = self.get_success(
|
||||||
self.auth_handler.create_access_token_for_user_id(
|
self.auth_handler.create_access_token_for_user_id(
|
||||||
self.user, device_id=None, valid_until_ms=None
|
self.user, device_id=None, valid_until_ms=None
|
||||||
|
@ -98,7 +102,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertFalse(capabilities["m.change_password"]["enabled"])
|
self.assertFalse(capabilities["m.change_password"]["enabled"])
|
||||||
|
|
||||||
def test_get_change_users_attributes_capabilities(self):
|
def test_get_change_users_attributes_capabilities(self) -> None:
|
||||||
"""Test that server returns capabilities by default."""
|
"""Test that server returns capabilities by default."""
|
||||||
access_token = self.login(self.localpart, self.password)
|
access_token = self.login(self.localpart, self.password)
|
||||||
|
|
||||||
|
@ -112,7 +116,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(capabilities["m.3pid_changes"]["enabled"])
|
self.assertTrue(capabilities["m.3pid_changes"]["enabled"])
|
||||||
|
|
||||||
@override_config({"enable_set_displayname": False})
|
@override_config({"enable_set_displayname": False})
|
||||||
def test_get_set_displayname_capabilities_displayname_disabled(self):
|
def test_get_set_displayname_capabilities_displayname_disabled(self) -> None:
|
||||||
"""Test if set displayname is disabled that the server responds it."""
|
"""Test if set displayname is disabled that the server responds it."""
|
||||||
access_token = self.login(self.localpart, self.password)
|
access_token = self.login(self.localpart, self.password)
|
||||||
|
|
||||||
|
@ -123,7 +127,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(capabilities["m.set_displayname"]["enabled"])
|
self.assertFalse(capabilities["m.set_displayname"]["enabled"])
|
||||||
|
|
||||||
@override_config({"enable_set_avatar_url": False})
|
@override_config({"enable_set_avatar_url": False})
|
||||||
def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
|
def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None:
|
||||||
"""Test if set avatar_url is disabled that the server responds it."""
|
"""Test if set avatar_url is disabled that the server responds it."""
|
||||||
access_token = self.login(self.localpart, self.password)
|
access_token = self.login(self.localpart, self.password)
|
||||||
|
|
||||||
|
@ -134,7 +138,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
|
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
|
||||||
|
|
||||||
@override_config({"enable_3pid_changes": False})
|
@override_config({"enable_3pid_changes": False})
|
||||||
def test_get_change_3pid_capabilities_3pid_disabled(self):
|
def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
|
||||||
"""Test if change 3pid is disabled that the server responds it."""
|
"""Test if change 3pid is disabled that the server responds it."""
|
||||||
access_token = self.login(self.localpart, self.password)
|
access_token = self.login(self.localpart, self.password)
|
||||||
|
|
||||||
|
@ -145,7 +149,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(capabilities["m.3pid_changes"]["enabled"])
|
self.assertFalse(capabilities["m.3pid_changes"]["enabled"])
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc3244_enabled": False}})
|
@override_config({"experimental_features": {"msc3244_enabled": False}})
|
||||||
def test_get_does_not_include_msc3244_fields_when_disabled(self):
|
def test_get_does_not_include_msc3244_fields_when_disabled(self) -> None:
|
||||||
access_token = self.get_success(
|
access_token = self.get_success(
|
||||||
self.auth_handler.create_access_token_for_user_id(
|
self.auth_handler.create_access_token_for_user_id(
|
||||||
self.user, device_id=None, valid_until_ms=None
|
self.user, device_id=None, valid_until_ms=None
|
||||||
|
@ -160,7 +164,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
|
||||||
"org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"]
|
"org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_does_include_msc3244_fields_when_enabled(self):
|
def test_get_does_include_msc3244_fields_when_enabled(self) -> None:
|
||||||
access_token = self.get_success(
|
access_token = self.get_success(
|
||||||
self.auth_handler.create_access_token_for_user_id(
|
self.auth_handler.create_access_token_for_user_id(
|
||||||
self.user, device_id=None, valid_until_ms=None
|
self.user, device_id=None, valid_until_ms=None
|
||||||
|
|
|
@ -20,6 +20,7 @@ from urllib.parse import urlencode
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
|
@ -27,12 +28,15 @@ from synapse.appservice import ApplicationService
|
||||||
from synapse.rest.client import devices, login, logout, register
|
from synapse.rest.client import devices, login, logout, register
|
||||||
from synapse.rest.client.account import WhoamiRestServlet
|
from synapse.rest.client.account import WhoamiRestServlet
|
||||||
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import create_requester
|
from synapse.types import create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.handlers.test_oidc import HAS_OIDC
|
from tests.handlers.test_oidc import HAS_OIDC
|
||||||
from tests.handlers.test_saml import has_saml2
|
from tests.handlers.test_saml import has_saml2
|
||||||
from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
|
from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
|
||||||
|
from tests.server import FakeChannel
|
||||||
from tests.test_utils.html_parsers import TestHtmlParser
|
from tests.test_utils.html_parsers import TestHtmlParser
|
||||||
from tests.unittest import HomeserverTestCase, override_config, skip_unless
|
from tests.unittest import HomeserverTestCase, override_config, skip_unless
|
||||||
|
|
||||||
|
@ -95,7 +99,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
|
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.hs = self.setup_test_homeserver()
|
self.hs = self.setup_test_homeserver()
|
||||||
self.hs.config.registration.enable_registration = True
|
self.hs.config.registration.enable_registration = True
|
||||||
self.hs.config.registration.registrations_require_3pid = []
|
self.hs.config.registration.registrations_require_3pid = []
|
||||||
|
@ -117,7 +121,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_POST_ratelimiting_per_address(self):
|
def test_POST_ratelimiting_per_address(self) -> None:
|
||||||
# Create different users so we're sure not to be bothered by the per-user
|
# Create different users so we're sure not to be bothered by the per-user
|
||||||
# ratelimiter.
|
# ratelimiter.
|
||||||
for i in range(0, 6):
|
for i in range(0, 6):
|
||||||
|
@ -165,7 +169,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_POST_ratelimiting_per_account(self):
|
def test_POST_ratelimiting_per_account(self) -> None:
|
||||||
self.register_user("kermit", "monkey")
|
self.register_user("kermit", "monkey")
|
||||||
|
|
||||||
for i in range(0, 6):
|
for i in range(0, 6):
|
||||||
|
@ -210,7 +214,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_POST_ratelimiting_per_account_failed_attempts(self):
|
def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
|
||||||
self.register_user("kermit", "monkey")
|
self.register_user("kermit", "monkey")
|
||||||
|
|
||||||
for i in range(0, 6):
|
for i in range(0, 6):
|
||||||
|
@ -243,7 +247,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||||
|
|
||||||
@override_config({"session_lifetime": "24h"})
|
@override_config({"session_lifetime": "24h"})
|
||||||
def test_soft_logout(self):
|
def test_soft_logout(self) -> None:
|
||||||
self.register_user("kermit", "monkey")
|
self.register_user("kermit", "monkey")
|
||||||
|
|
||||||
# we shouldn't be able to make requests without an access token
|
# we shouldn't be able to make requests without an access token
|
||||||
|
@ -298,7 +302,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||||
self.assertEquals(channel.json_body["soft_logout"], False)
|
self.assertEquals(channel.json_body["soft_logout"], False)
|
||||||
|
|
||||||
def _delete_device(self, access_token, user_id, password, device_id):
|
def _delete_device(
|
||||||
|
self, access_token: str, user_id: str, password: str, device_id: str
|
||||||
|
) -> None:
|
||||||
"""Perform the UI-Auth to delete a device"""
|
"""Perform the UI-Auth to delete a device"""
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
b"DELETE", "devices/" + device_id, access_token=access_token
|
b"DELETE", "devices/" + device_id, access_token=access_token
|
||||||
|
@ -329,7 +335,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEquals(channel.code, 200, channel.result)
|
self.assertEquals(channel.code, 200, channel.result)
|
||||||
|
|
||||||
@override_config({"session_lifetime": "24h"})
|
@override_config({"session_lifetime": "24h"})
|
||||||
def test_session_can_hard_logout_after_being_soft_logged_out(self):
|
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
|
||||||
self.register_user("kermit", "monkey")
|
self.register_user("kermit", "monkey")
|
||||||
|
|
||||||
# log in as normal
|
# log in as normal
|
||||||
|
@ -353,7 +359,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
@override_config({"session_lifetime": "24h"})
|
@override_config({"session_lifetime": "24h"})
|
||||||
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self):
|
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
self.register_user("kermit", "monkey")
|
self.register_user("kermit", "monkey")
|
||||||
|
|
||||||
# log in as normal
|
# log in as normal
|
||||||
|
@ -432,7 +440,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
d.update(build_synapse_client_resource_tree(self.hs))
|
d.update(build_synapse_client_resource_tree(self.hs))
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def test_get_login_flows(self):
|
def test_get_login_flows(self) -> None:
|
||||||
"""GET /login should return password and SSO flows"""
|
"""GET /login should return password and SSO flows"""
|
||||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
@ -459,12 +467,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_multi_sso_redirect(self):
|
def test_multi_sso_redirect(self) -> None:
|
||||||
"""/login/sso/redirect should redirect to an identity picker"""
|
"""/login/sso/redirect should redirect to an identity picker"""
|
||||||
# first hit the redirect url, which should redirect to our idp picker
|
# first hit the redirect url, which should redirect to our idp picker
|
||||||
channel = self._make_sso_redirect_request(None)
|
channel = self._make_sso_redirect_request(None)
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
uri = channel.headers.getRawHeaders("Location")[0]
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
|
assert location_headers
|
||||||
|
uri = location_headers[0]
|
||||||
|
|
||||||
# hitting that picker should give us some HTML
|
# hitting that picker should give us some HTML
|
||||||
channel = self.make_request("GET", uri)
|
channel = self.make_request("GET", uri)
|
||||||
|
@ -487,7 +497,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
|
self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
|
||||||
|
|
||||||
def test_multi_sso_redirect_to_cas(self):
|
def test_multi_sso_redirect_to_cas(self) -> None:
|
||||||
"""If CAS is chosen, should redirect to the CAS server"""
|
"""If CAS is chosen, should redirect to the CAS server"""
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -514,7 +524,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
service_uri_params = urllib.parse.parse_qs(service_uri_query)
|
service_uri_params = urllib.parse.parse_qs(service_uri_query)
|
||||||
self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
|
self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
|
||||||
|
|
||||||
def test_multi_sso_redirect_to_saml(self):
|
def test_multi_sso_redirect_to_saml(self) -> None:
|
||||||
"""If SAML is chosen, should redirect to the SAML server"""
|
"""If SAML is chosen, should redirect to the SAML server"""
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
|
@ -536,7 +546,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
relay_state_param = saml_uri_params["RelayState"][0]
|
relay_state_param = saml_uri_params["RelayState"][0]
|
||||||
self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
|
self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
|
||||||
|
|
||||||
def test_login_via_oidc(self):
|
def test_login_via_oidc(self) -> None:
|
||||||
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
|
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
|
||||||
|
|
||||||
# pick the default OIDC provider
|
# pick the default OIDC provider
|
||||||
|
@ -604,7 +614,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(chan.code, 200, chan.result)
|
self.assertEqual(chan.code, 200, chan.result)
|
||||||
self.assertEqual(chan.json_body["user_id"], "@user1:test")
|
self.assertEqual(chan.json_body["user_id"], "@user1:test")
|
||||||
|
|
||||||
def test_multi_sso_redirect_to_unknown(self):
|
def test_multi_sso_redirect_to_unknown(self) -> None:
|
||||||
"""An unknown IdP should cause a 400"""
|
"""An unknown IdP should cause a 400"""
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
|
@ -612,23 +622,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
|
||||||
def test_client_idp_redirect_to_unknown(self):
|
def test_client_idp_redirect_to_unknown(self) -> None:
|
||||||
"""If the client tries to pick an unknown IdP, return a 404"""
|
"""If the client tries to pick an unknown IdP, return a 404"""
|
||||||
channel = self._make_sso_redirect_request("xxx")
|
channel = self._make_sso_redirect_request("xxx")
|
||||||
self.assertEqual(channel.code, 404, channel.result)
|
self.assertEqual(channel.code, 404, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||||
|
|
||||||
def test_client_idp_redirect_to_oidc(self):
|
def test_client_idp_redirect_to_oidc(self) -> None:
|
||||||
"""If the client pick a known IdP, redirect to it"""
|
"""If the client pick a known IdP, redirect to it"""
|
||||||
channel = self._make_sso_redirect_request("oidc")
|
channel = self._make_sso_redirect_request("oidc")
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, 302, channel.result)
|
||||||
oidc_uri = channel.headers.getRawHeaders("Location")[0]
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
|
assert location_headers
|
||||||
|
oidc_uri = location_headers[0]
|
||||||
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
|
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
|
||||||
|
|
||||||
# it should redirect us to the auth page of the OIDC server
|
# it should redirect us to the auth page of the OIDC server
|
||||||
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
|
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
|
||||||
|
|
||||||
def _make_sso_redirect_request(self, idp_prov: Optional[str] = None):
|
def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel:
|
||||||
"""Send a request to /_matrix/client/r0/login/sso/redirect
|
"""Send a request to /_matrix/client/r0/login/sso/redirect
|
||||||
|
|
||||||
... possibly specifying an IDP provider
|
... possibly specifying an IDP provider
|
||||||
|
@ -659,7 +671,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.base_url = "https://matrix.goodserver.com/"
|
self.base_url = "https://matrix.goodserver.com/"
|
||||||
self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
|
self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
|
||||||
|
|
||||||
|
@ -675,7 +687,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||||
cas_user_id = "username"
|
cas_user_id = "username"
|
||||||
self.user_id = "@%s:test" % cas_user_id
|
self.user_id = "@%s:test" % cas_user_id
|
||||||
|
|
||||||
async def get_raw(uri, args):
|
async def get_raw(uri: str, args: Any) -> bytes:
|
||||||
"""Return an example response payload from a call to the `/proxyValidate`
|
"""Return an example response payload from a call to the `/proxyValidate`
|
||||||
endpoint of a CAS server, copied from
|
endpoint of a CAS server, copied from
|
||||||
https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
|
https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
|
||||||
|
@ -709,10 +721,10 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return self.hs
|
return self.hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.deactivate_account_handler = hs.get_deactivate_account_handler()
|
self.deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||||
|
|
||||||
def test_cas_redirect_confirm(self):
|
def test_cas_redirect_confirm(self) -> None:
|
||||||
"""Tests that the SSO login flow serves a confirmation page before redirecting a
|
"""Tests that the SSO login flow serves a confirmation page before redirecting a
|
||||||
user to the redirect URL.
|
user to the redirect URL.
|
||||||
"""
|
"""
|
||||||
|
@ -754,15 +766,15 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def test_cas_redirect_whitelisted(self):
|
def test_cas_redirect_whitelisted(self) -> None:
|
||||||
"""Tests that the SSO login flow serves a redirect to a whitelisted url"""
|
"""Tests that the SSO login flow serves a redirect to a whitelisted url"""
|
||||||
self._test_redirect("https://legit-site.com/")
|
self._test_redirect("https://legit-site.com/")
|
||||||
|
|
||||||
@override_config({"public_baseurl": "https://example.com"})
|
@override_config({"public_baseurl": "https://example.com"})
|
||||||
def test_cas_redirect_login_fallback(self):
|
def test_cas_redirect_login_fallback(self) -> None:
|
||||||
self._test_redirect("https://example.com/_matrix/static/client/login")
|
self._test_redirect("https://example.com/_matrix/static/client/login")
|
||||||
|
|
||||||
def _test_redirect(self, redirect_url):
|
def _test_redirect(self, redirect_url: str) -> None:
|
||||||
"""Tests that the SSO login flow serves a redirect for the given redirect URL."""
|
"""Tests that the SSO login flow serves a redirect for the given redirect URL."""
|
||||||
cas_ticket_url = (
|
cas_ticket_url = (
|
||||||
"/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
|
"/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
|
||||||
|
@ -778,7 +790,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
|
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
|
||||||
|
|
||||||
@override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
|
@override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
|
||||||
def test_deactivated_user(self):
|
def test_deactivated_user(self) -> None:
|
||||||
"""Logging in as a deactivated account should error."""
|
"""Logging in as a deactivated account should error."""
|
||||||
redirect_url = "https://legit-site.com/"
|
redirect_url = "https://legit-site.com/"
|
||||||
|
|
||||||
|
@ -821,7 +833,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
"algorithm": jwt_algorithm,
|
"algorithm": jwt_algorithm,
|
||||||
}
|
}
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
|
|
||||||
# If jwt_config has been defined (eg via @override_config), don't replace it.
|
# If jwt_config has been defined (eg via @override_config), don't replace it.
|
||||||
|
@ -837,23 +849,23 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
return result.decode("ascii")
|
return result.decode("ascii")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def jwt_login(self, *args):
|
def jwt_login(self, *args: Any) -> FakeChannel:
|
||||||
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
||||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
def test_login_jwt_valid_registered(self):
|
def test_login_jwt_valid_registered(self) -> None:
|
||||||
self.register_user("kermit", "monkey")
|
self.register_user("kermit", "monkey")
|
||||||
channel = self.jwt_login({"sub": "kermit"})
|
channel = self.jwt_login({"sub": "kermit"})
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||||
|
|
||||||
def test_login_jwt_valid_unregistered(self):
|
def test_login_jwt_valid_unregistered(self) -> None:
|
||||||
channel = self.jwt_login({"sub": "frog"})
|
channel = self.jwt_login({"sub": "frog"})
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
||||||
|
|
||||||
def test_login_jwt_invalid_signature(self):
|
def test_login_jwt_invalid_signature(self) -> None:
|
||||||
channel = self.jwt_login({"sub": "frog"}, "notsecret")
|
channel = self.jwt_login({"sub": "frog"}, "notsecret")
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
@ -862,7 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
"JWT validation failed: Signature verification failed",
|
"JWT validation failed: Signature verification failed",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_login_jwt_expired(self):
|
def test_login_jwt_expired(self) -> None:
|
||||||
channel = self.jwt_login({"sub": "frog", "exp": 864000})
|
channel = self.jwt_login({"sub": "frog", "exp": 864000})
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
@ -870,7 +882,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
channel.json_body["error"], "JWT validation failed: Signature has expired"
|
channel.json_body["error"], "JWT validation failed: Signature has expired"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_login_jwt_not_before(self):
|
def test_login_jwt_not_before(self) -> None:
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
|
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
|
@ -880,14 +892,14 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
"JWT validation failed: The token is not yet valid (nbf)",
|
"JWT validation failed: The token is not yet valid (nbf)",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_login_no_sub(self):
|
def test_login_no_sub(self) -> None:
|
||||||
channel = self.jwt_login({"username": "root"})
|
channel = self.jwt_login({"username": "root"})
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
self.assertEqual(channel.json_body["error"], "Invalid JWT")
|
self.assertEqual(channel.json_body["error"], "Invalid JWT")
|
||||||
|
|
||||||
@override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
|
@override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
|
||||||
def test_login_iss(self):
|
def test_login_iss(self) -> None:
|
||||||
"""Test validating the issuer claim."""
|
"""Test validating the issuer claim."""
|
||||||
# A valid issuer.
|
# A valid issuer.
|
||||||
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
|
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
|
||||||
|
@ -911,14 +923,14 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
'JWT validation failed: Token is missing the "iss" claim',
|
'JWT validation failed: Token is missing the "iss" claim',
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_login_iss_no_config(self):
|
def test_login_iss_no_config(self) -> None:
|
||||||
"""Test providing an issuer claim without requiring it in the configuration."""
|
"""Test providing an issuer claim without requiring it in the configuration."""
|
||||||
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
|
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||||
|
|
||||||
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
|
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
|
||||||
def test_login_aud(self):
|
def test_login_aud(self) -> None:
|
||||||
"""Test validating the audience claim."""
|
"""Test validating the audience claim."""
|
||||||
# A valid audience.
|
# A valid audience.
|
||||||
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
|
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
|
||||||
|
@ -942,7 +954,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
'JWT validation failed: Token is missing the "aud" claim',
|
'JWT validation failed: Token is missing the "aud" claim',
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_login_aud_no_config(self):
|
def test_login_aud_no_config(self) -> None:
|
||||||
"""Test providing an audience without requiring it in the configuration."""
|
"""Test providing an audience without requiring it in the configuration."""
|
||||||
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
|
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
|
@ -951,20 +963,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
channel.json_body["error"], "JWT validation failed: Invalid audience"
|
channel.json_body["error"], "JWT validation failed: Invalid audience"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_login_default_sub(self):
|
def test_login_default_sub(self) -> None:
|
||||||
"""Test reading user ID from the default subject claim."""
|
"""Test reading user ID from the default subject claim."""
|
||||||
channel = self.jwt_login({"sub": "kermit"})
|
channel = self.jwt_login({"sub": "kermit"})
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||||
|
|
||||||
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
|
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
|
||||||
def test_login_custom_sub(self):
|
def test_login_custom_sub(self) -> None:
|
||||||
"""Test reading user ID from a custom subject claim."""
|
"""Test reading user ID from a custom subject claim."""
|
||||||
channel = self.jwt_login({"username": "frog"})
|
channel = self.jwt_login({"username": "frog"})
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
||||||
|
|
||||||
def test_login_no_token(self):
|
def test_login_no_token(self) -> None:
|
||||||
params = {"type": "org.matrix.login.jwt"}
|
params = {"type": "org.matrix.login.jwt"}
|
||||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
|
@ -1026,7 +1038,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["jwt_config"] = {
|
config["jwt_config"] = {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
|
@ -1042,17 +1054,17 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
|
||||||
return result.decode("ascii")
|
return result.decode("ascii")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def jwt_login(self, *args):
|
def jwt_login(self, *args: Any) -> FakeChannel:
|
||||||
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
||||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
def test_login_jwt_valid(self):
|
def test_login_jwt_valid(self) -> None:
|
||||||
channel = self.jwt_login({"sub": "kermit"})
|
channel = self.jwt_login({"sub": "kermit"})
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||||
|
|
||||||
def test_login_jwt_invalid_signature(self):
|
def test_login_jwt_invalid_signature(self) -> None:
|
||||||
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
|
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
|
||||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
@ -1071,7 +1083,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
register.register_servlets,
|
register.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.hs = self.setup_test_homeserver()
|
self.hs = self.setup_test_homeserver()
|
||||||
|
|
||||||
self.service = ApplicationService(
|
self.service = ApplicationService(
|
||||||
|
@ -1105,7 +1117,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.hs.get_datastores().main.services_cache.append(self.another_service)
|
self.hs.get_datastores().main.services_cache.append(self.another_service)
|
||||||
return self.hs
|
return self.hs
|
||||||
|
|
||||||
def test_login_appservice_user(self):
|
def test_login_appservice_user(self) -> None:
|
||||||
"""Test that an appservice user can use /login"""
|
"""Test that an appservice user can use /login"""
|
||||||
self.register_appservice_user(AS_USER, self.service.token)
|
self.register_appservice_user(AS_USER, self.service.token)
|
||||||
|
|
||||||
|
@ -1119,7 +1131,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
def test_login_appservice_user_bot(self):
|
def test_login_appservice_user_bot(self) -> None:
|
||||||
"""Test that the appservice bot can use /login"""
|
"""Test that the appservice bot can use /login"""
|
||||||
self.register_appservice_user(AS_USER, self.service.token)
|
self.register_appservice_user(AS_USER, self.service.token)
|
||||||
|
|
||||||
|
@ -1133,7 +1145,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
def test_login_appservice_wrong_user(self):
|
def test_login_appservice_wrong_user(self) -> None:
|
||||||
"""Test that non-as users cannot login with the as token"""
|
"""Test that non-as users cannot login with the as token"""
|
||||||
self.register_appservice_user(AS_USER, self.service.token)
|
self.register_appservice_user(AS_USER, self.service.token)
|
||||||
|
|
||||||
|
@ -1147,7 +1159,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||||
|
|
||||||
def test_login_appservice_wrong_as(self):
|
def test_login_appservice_wrong_as(self) -> None:
|
||||||
"""Test that as users cannot login with wrong as token"""
|
"""Test that as users cannot login with wrong as token"""
|
||||||
self.register_appservice_user(AS_USER, self.service.token)
|
self.register_appservice_user(AS_USER, self.service.token)
|
||||||
|
|
||||||
|
@ -1161,7 +1173,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||||
|
|
||||||
def test_login_appservice_no_token(self):
|
def test_login_appservice_no_token(self) -> None:
|
||||||
"""Test that users must provide a token when using the appservice
|
"""Test that users must provide a token when using the appservice
|
||||||
login method
|
login method
|
||||||
"""
|
"""
|
||||||
|
@ -1182,7 +1194,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [login.register_servlets]
|
servlets = [login.register_servlets]
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = BASE_URL
|
||||||
|
|
||||||
|
@ -1202,7 +1214,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||||
d.update(build_synapse_client_resource_tree(self.hs))
|
d.update(build_synapse_client_resource_tree(self.hs))
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def test_username_picker(self):
|
def test_username_picker(self) -> None:
|
||||||
"""Test the happy path of a username picker flow."""
|
"""Test the happy path of a username picker flow."""
|
||||||
|
|
||||||
# do the start of the login flow
|
# do the start of the login flow
|
||||||
|
|
|
@ -13,9 +13,12 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
EventContentFields,
|
EventContentFields,
|
||||||
|
@ -24,6 +27,9 @@ from synapse.api.constants import (
|
||||||
RelationTypes,
|
RelationTypes,
|
||||||
)
|
)
|
||||||
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
|
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.federation.transport.test_knocking import (
|
from tests.federation.transport.test_knocking import (
|
||||||
|
@ -43,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||||
sync.register_servlets,
|
sync.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_sync_argless(self):
|
def test_sync_argless(self) -> None:
|
||||||
channel = self.make_request("GET", "/sync")
|
channel = self.make_request("GET", "/sync")
|
||||||
|
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
|
@ -58,7 +64,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
|
||||||
sync.register_servlets,
|
sync.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_sync_filter_labels(self):
|
def test_sync_filter_labels(self) -> None:
|
||||||
"""Test that we can filter by a label."""
|
"""Test that we can filter by a label."""
|
||||||
sync_filter = json.dumps(
|
sync_filter = json.dumps(
|
||||||
{
|
{
|
||||||
|
@ -77,7 +83,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
|
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
|
||||||
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
|
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
|
||||||
|
|
||||||
def test_sync_filter_not_labels(self):
|
def test_sync_filter_not_labels(self) -> None:
|
||||||
"""Test that we can filter by the absence of a label."""
|
"""Test that we can filter by the absence of a label."""
|
||||||
sync_filter = json.dumps(
|
sync_filter = json.dumps(
|
||||||
{
|
{
|
||||||
|
@ -99,7 +105,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
|
||||||
events[2]["content"]["body"], "with two wrong labels", events[2]
|
events[2]["content"]["body"], "with two wrong labels", events[2]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_sync_filter_labels_not_labels(self):
|
def test_sync_filter_labels_not_labels(self) -> None:
|
||||||
"""Test that we can filter by both a label and the absence of another label."""
|
"""Test that we can filter by both a label and the absence of another label."""
|
||||||
sync_filter = json.dumps(
|
sync_filter = json.dumps(
|
||||||
{
|
{
|
||||||
|
@ -118,7 +124,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(events), 1, [event["content"] for event in events])
|
self.assertEqual(len(events), 1, [event["content"] for event in events])
|
||||||
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
|
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
|
||||||
|
|
||||||
def _test_sync_filter_labels(self, sync_filter):
|
def _test_sync_filter_labels(self, sync_filter: str) -> List[JsonDict]:
|
||||||
user_id = self.register_user("kermit", "test")
|
user_id = self.register_user("kermit", "test")
|
||||||
tok = self.login("kermit", "test")
|
tok = self.login("kermit", "test")
|
||||||
|
|
||||||
|
@ -194,7 +200,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
|
||||||
user_id = True
|
user_id = True
|
||||||
hijack_auth = False
|
hijack_auth = False
|
||||||
|
|
||||||
def test_sync_backwards_typing(self):
|
def test_sync_backwards_typing(self) -> None:
|
||||||
"""
|
"""
|
||||||
If the typing serial goes backwards and the typing handler is then reset
|
If the typing serial goes backwards and the typing handler is then reset
|
||||||
(such as when the master restarts and sets the typing serial to 0), we
|
(such as when the master restarts and sets the typing serial to 0), we
|
||||||
|
@ -298,7 +304,7 @@ class SyncKnockTestCase(
|
||||||
knock.register_servlets,
|
knock.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.url = "/sync?since=%s"
|
self.url = "/sync?since=%s"
|
||||||
self.next_batch = "s0"
|
self.next_batch = "s0"
|
||||||
|
@ -336,7 +342,7 @@ class SyncKnockTestCase(
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc2403_enabled": True}})
|
@override_config({"experimental_features": {"msc2403_enabled": True}})
|
||||||
def test_knock_room_state(self):
|
def test_knock_room_state(self) -> None:
|
||||||
"""Tests that /sync returns state from a room after knocking on it."""
|
"""Tests that /sync returns state from a room after knocking on it."""
|
||||||
# Knock on a room
|
# Knock on a room
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -383,7 +389,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
|
||||||
sync.register_servlets,
|
sync.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.url = "/sync?since=%s"
|
self.url = "/sync?since=%s"
|
||||||
self.next_batch = "s0"
|
self.next_batch = "s0"
|
||||||
|
|
||||||
|
@ -402,7 +408,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
|
||||||
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
|
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc2285_enabled": True}})
|
@override_config({"experimental_features": {"msc2285_enabled": True}})
|
||||||
def test_hidden_read_receipts(self):
|
def test_hidden_read_receipts(self) -> None:
|
||||||
# Send a message as the first user
|
# Send a message as the first user
|
||||||
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
|
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
|
||||||
|
|
||||||
|
@ -441,8 +447,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_read_receipt_with_empty_body(
|
def test_read_receipt_with_empty_body(
|
||||||
self, name, user_agent: str, expected_status_code: int
|
self, name: str, user_agent: str, expected_status_code: int
|
||||||
):
|
) -> None:
|
||||||
# Send a message as the first user
|
# Send a message as the first user
|
||||||
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
|
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
|
||||||
|
|
||||||
|
@ -455,11 +461,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, expected_status_code)
|
self.assertEqual(channel.code, expected_status_code)
|
||||||
|
|
||||||
def _get_read_receipt(self):
|
def _get_read_receipt(self) -> Optional[JsonDict]:
|
||||||
"""Syncs and returns the read receipt."""
|
"""Syncs and returns the read receipt."""
|
||||||
|
|
||||||
# Checks if event is a read receipt
|
# Checks if event is a read receipt
|
||||||
def is_read_receipt(event):
|
def is_read_receipt(event: JsonDict) -> bool:
|
||||||
return event["type"] == "m.receipt"
|
return event["type"] == "m.receipt"
|
||||||
|
|
||||||
# Sync
|
# Sync
|
||||||
|
@ -477,7 +483,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
|
||||||
ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][
|
ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][
|
||||||
"ephemeral"
|
"ephemeral"
|
||||||
]["events"]
|
]["events"]
|
||||||
return next(filter(is_read_receipt, ephemeral_events), None)
|
receipt_event = filter(is_read_receipt, ephemeral_events)
|
||||||
|
return next(receipt_event, None)
|
||||||
|
|
||||||
|
|
||||||
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
|
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
|
||||||
|
@ -490,7 +497,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
|
||||||
receipts.register_servlets,
|
receipts.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.url = "/sync?since=%s"
|
self.url = "/sync?since=%s"
|
||||||
self.next_batch = "s0"
|
self.next_batch = "s0"
|
||||||
|
|
||||||
|
@ -533,7 +540,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
|
||||||
tok=self.tok,
|
tok=self.tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_unread_counts(self):
|
def test_unread_counts(self) -> None:
|
||||||
"""Tests that /sync returns the right value for the unread count (MSC2654)."""
|
"""Tests that /sync returns the right value for the unread count (MSC2654)."""
|
||||||
|
|
||||||
# Check that our own messages don't increase the unread count.
|
# Check that our own messages don't increase the unread count.
|
||||||
|
@ -640,7 +647,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self._check_unread_count(5)
|
self._check_unread_count(5)
|
||||||
|
|
||||||
def _check_unread_count(self, expected_count: int):
|
def _check_unread_count(self, expected_count: int) -> None:
|
||||||
"""Syncs and compares the unread count with the expected value."""
|
"""Syncs and compares the unread count with the expected value."""
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -669,7 +676,7 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
|
||||||
sync.register_servlets,
|
sync.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_noop_sync_does_not_tightloop(self):
|
def test_noop_sync_does_not_tightloop(self) -> None:
|
||||||
"""If the sync times out, we shouldn't cache the result
|
"""If the sync times out, we shouldn't cache the result
|
||||||
|
|
||||||
Essentially a regression test for #8518.
|
Essentially a regression test for #8518.
|
||||||
|
@ -720,7 +727,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
|
||||||
devices.register_servlets,
|
devices.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_user_with_no_rooms_receives_self_device_list_updates(self):
|
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
|
||||||
"""Tests that a user with no rooms still receives their own device list updates"""
|
"""Tests that a user with no rooms still receives their own device list updates"""
|
||||||
device_id = "TESTDEVICE"
|
device_id = "TESTDEVICE"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue