Add type hints for `tests/unittest.py`. (#12347)
In particular, add type hints for get_success and friends, which are then helpful in a bunch of places.hs/deactivate-leave-metadata
parent
33ebee47e4
commit
f0b03186d9
|
@ -0,0 +1 @@
|
||||||
|
Add type annotations for `tests/unittest.py`.
|
1
mypy.ini
1
mypy.ini
|
@ -83,7 +83,6 @@ exclude = (?x)
|
||||||
|tests/test_server.py
|
|tests/test_server.py
|
||||||
|tests/test_state.py
|
|tests/test_state.py
|
||||||
|tests/test_terms_auth.py
|
|tests/test_terms_auth.py
|
||||||
|tests/unittest.py
|
|
||||||
|tests/util/caches/test_cached_call.py
|
|tests/util/caches/test_cached_call.py
|
||||||
|tests/util/caches/test_deferred_cache.py
|
|tests/util/caches/test_deferred_cache.py
|
||||||
|tests/util/caches/test_descriptors.py
|
|tests/util/caches/test_descriptors.py
|
||||||
|
|
|
@ -463,8 +463,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res = e.value.code
|
res = e.value.code
|
||||||
self.assertEqual(res, 400)
|
self.assertEqual(res, 400)
|
||||||
|
|
||||||
res = self.get_success(self.handler.query_local_devices({local_user: None}))
|
query_res = self.get_success(
|
||||||
self.assertDictEqual(res, {local_user: {}})
|
self.handler.query_local_devices({local_user: None})
|
||||||
|
)
|
||||||
|
self.assertDictEqual(query_res, {local_user: {}})
|
||||||
|
|
||||||
def test_upload_signatures(self) -> None:
|
def test_upload_signatures(self) -> None:
|
||||||
"""should check signatures that are uploaded"""
|
"""should check signatures that are uploaded"""
|
||||||
|
|
|
@ -375,7 +375,8 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
member_event.signatures = member_event_dict["signatures"]
|
member_event.signatures = member_event_dict["signatures"]
|
||||||
|
|
||||||
# Add the new member_event to the StateMap
|
# Add the new member_event to the StateMap
|
||||||
prev_state_map[
|
updated_state_map = dict(prev_state_map)
|
||||||
|
updated_state_map[
|
||||||
(member_event.type, member_event.state_key)
|
(member_event.type, member_event.state_key)
|
||||||
] = member_event.event_id
|
] = member_event.event_id
|
||||||
auth_events.append(member_event)
|
auth_events.append(member_event)
|
||||||
|
@ -399,7 +400,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
prev_event_ids=message_event_dict["prev_events"],
|
prev_event_ids=message_event_dict["prev_events"],
|
||||||
auth_event_ids=self._event_auth_handler.compute_auth_events(
|
auth_event_ids=self._event_auth_handler.compute_auth_events(
|
||||||
builder,
|
builder,
|
||||||
prev_state_map,
|
updated_state_map,
|
||||||
for_verification=False,
|
for_verification=False,
|
||||||
),
|
),
|
||||||
depth=message_event_dict["depth"],
|
depth=message_event_dict["depth"],
|
||||||
|
|
|
@ -354,10 +354,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
req = Mock(spec=["cookies"])
|
req = Mock(spec=["cookies"])
|
||||||
req.cookies = []
|
req.cookies = []
|
||||||
|
|
||||||
url = self.get_success(
|
url = urlparse(
|
||||||
self.provider.handle_redirect_request(req, b"http://client/redirect")
|
self.get_success(
|
||||||
|
self.provider.handle_redirect_request(req, b"http://client/redirect")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
url = urlparse(url)
|
|
||||||
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
|
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
|
||||||
|
|
||||||
self.assertEqual(url.scheme, auth_endpoint.scheme)
|
self.assertEqual(url.scheme, auth_endpoint.scheme)
|
||||||
|
|
|
@ -351,6 +351,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
self.handler.handle_local_profile_change(regular_user_id, profile_info)
|
self.handler.handle_local_profile_change(regular_user_id, profile_info)
|
||||||
)
|
)
|
||||||
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
|
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
|
||||||
|
assert profile is not None
|
||||||
self.assertTrue(profile["display_name"] == display_name)
|
self.assertTrue(profile["display_name"] == display_name)
|
||||||
|
|
||||||
def test_handle_local_profile_change_with_deactivated_user(self) -> None:
|
def test_handle_local_profile_change_with_deactivated_user(self) -> None:
|
||||||
|
@ -369,6 +370,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# profile is in directory
|
# profile is in directory
|
||||||
profile = self.get_success(self.store.get_user_in_directory(r_user_id))
|
profile = self.get_success(self.store.get_user_in_directory(r_user_id))
|
||||||
|
assert profile is not None
|
||||||
self.assertTrue(profile["display_name"] == display_name)
|
self.assertTrue(profile["display_name"] == display_name)
|
||||||
|
|
||||||
# deactivate user
|
# deactivate user
|
||||||
|
|
|
@ -702,6 +702,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||||
|
assert media_info is not None
|
||||||
self.assertFalse(media_info["quarantined_by"])
|
self.assertFalse(media_info["quarantined_by"])
|
||||||
|
|
||||||
# quarantining
|
# quarantining
|
||||||
|
@ -715,6 +716,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(channel.json_body)
|
self.assertFalse(channel.json_body)
|
||||||
|
|
||||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||||
|
assert media_info is not None
|
||||||
self.assertTrue(media_info["quarantined_by"])
|
self.assertTrue(media_info["quarantined_by"])
|
||||||
|
|
||||||
# remove from quarantine
|
# remove from quarantine
|
||||||
|
@ -728,6 +730,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(channel.json_body)
|
self.assertFalse(channel.json_body)
|
||||||
|
|
||||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||||
|
assert media_info is not None
|
||||||
self.assertFalse(media_info["quarantined_by"])
|
self.assertFalse(media_info["quarantined_by"])
|
||||||
|
|
||||||
def test_quarantine_protected_media(self) -> None:
|
def test_quarantine_protected_media(self) -> None:
|
||||||
|
@ -740,6 +743,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# verify protection
|
# verify protection
|
||||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||||
|
assert media_info is not None
|
||||||
self.assertTrue(media_info["safe_from_quarantine"])
|
self.assertTrue(media_info["safe_from_quarantine"])
|
||||||
|
|
||||||
# quarantining
|
# quarantining
|
||||||
|
@ -754,6 +758,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# verify that is not in quarantine
|
# verify that is not in quarantine
|
||||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||||
|
assert media_info is not None
|
||||||
self.assertFalse(media_info["quarantined_by"])
|
self.assertFalse(media_info["quarantined_by"])
|
||||||
|
|
||||||
|
|
||||||
|
@ -830,6 +835,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||||
|
assert media_info is not None
|
||||||
self.assertFalse(media_info["safe_from_quarantine"])
|
self.assertFalse(media_info["safe_from_quarantine"])
|
||||||
|
|
||||||
# protect
|
# protect
|
||||||
|
@ -843,6 +849,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(channel.json_body)
|
self.assertFalse(channel.json_body)
|
||||||
|
|
||||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||||
|
assert media_info is not None
|
||||||
self.assertTrue(media_info["safe_from_quarantine"])
|
self.assertTrue(media_info["safe_from_quarantine"])
|
||||||
|
|
||||||
# unprotect
|
# unprotect
|
||||||
|
@ -856,6 +863,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(channel.json_body)
|
self.assertFalse(channel.json_body)
|
||||||
|
|
||||||
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
media_info = self.get_success(self.store.get_local_media(self.media_id))
|
||||||
|
assert media_info is not None
|
||||||
self.assertFalse(media_info["safe_from_quarantine"])
|
self.assertFalse(media_info["safe_from_quarantine"])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1590,10 +1590,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||||
|
|
||||||
pushers = self.get_success(
|
pushers = list(
|
||||||
self.store.get_pushers_by({"user_name": "@bob:test"})
|
self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertEqual("@bob:test", pushers[0].user_name)
|
self.assertEqual("@bob:test", pushers[0].user_name)
|
||||||
|
|
||||||
|
@ -1632,10 +1631,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||||
|
|
||||||
pushers = self.get_success(
|
pushers = list(
|
||||||
self.store.get_pushers_by({"user_name": "@bob:test"})
|
self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
|
||||||
self.assertEqual(len(pushers), 0)
|
self.assertEqual(len(pushers), 0)
|
||||||
|
|
||||||
def test_set_password(self) -> None:
|
def test_set_password(self) -> None:
|
||||||
|
@ -2144,6 +2142,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# is in user directory
|
# is in user directory
|
||||||
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
|
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
|
||||||
|
assert profile is not None
|
||||||
self.assertTrue(profile["display_name"] == "User")
|
self.assertTrue(profile["display_name"] == "User")
|
||||||
|
|
||||||
# Deactivate user
|
# Deactivate user
|
||||||
|
@ -2711,6 +2710,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.store.get_user_by_access_token(other_user_token)
|
self.store.get_user_by_access_token(other_user_token)
|
||||||
)
|
)
|
||||||
|
assert user_tuple is not None
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -3676,6 +3676,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
|
||||||
# The user starts off as not shadow-banned.
|
# The user starts off as not shadow-banned.
|
||||||
other_user_token = self.login("user", "pass")
|
other_user_token = self.login("user", "pass")
|
||||||
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
|
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
|
||||||
|
assert result is not None
|
||||||
self.assertFalse(result.shadow_banned)
|
self.assertFalse(result.shadow_banned)
|
||||||
|
|
||||||
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
|
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
|
||||||
|
@ -3684,6 +3685,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Ensure the user is shadow-banned (and the cache was cleared).
|
# Ensure the user is shadow-banned (and the cache was cleared).
|
||||||
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
|
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
|
||||||
|
assert result is not None
|
||||||
self.assertTrue(result.shadow_banned)
|
self.assertTrue(result.shadow_banned)
|
||||||
|
|
||||||
# Un-shadow-ban the user.
|
# Un-shadow-ban the user.
|
||||||
|
@ -3695,6 +3697,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Ensure the user is no longer shadow-banned (and the cache was cleared).
|
# Ensure the user is no longer shadow-banned (and the cache was cleared).
|
||||||
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
|
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
|
||||||
|
assert result is not None
|
||||||
self.assertFalse(result.shadow_banned)
|
self.assertFalse(result.shadow_banned)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,6 @@ import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from io import SEEK_END, BytesIO
|
from io import SEEK_END, BytesIO
|
||||||
from typing import (
|
from typing import (
|
||||||
AnyStr,
|
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
@ -86,6 +85,9 @@ from tests.utils import (
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# the type of thing that can be passed into `make_request` in the headers list
|
||||||
|
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
|
||||||
|
|
||||||
|
|
||||||
class TimedOutException(Exception):
|
class TimedOutException(Exception):
|
||||||
"""
|
"""
|
||||||
|
@ -260,7 +262,7 @@ def make_request(
|
||||||
federation_auth_origin: Optional[bytes] = None,
|
federation_auth_origin: Optional[bytes] = None,
|
||||||
content_is_form: bool = False,
|
content_is_form: bool = False,
|
||||||
await_result: bool = True,
|
await_result: bool = True,
|
||||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
|
||||||
client_ip: str = "127.0.0.1",
|
client_ip: str = "127.0.0.1",
|
||||||
) -> FakeChannel:
|
) -> FakeChannel:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -28,7 +28,7 @@ class LockTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
# First to acquire this lock, so it should complete
|
# First to acquire this lock, so it should complete
|
||||||
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
||||||
self.assertIsNotNone(lock)
|
assert lock is not None
|
||||||
|
|
||||||
# Enter the context manager
|
# Enter the context manager
|
||||||
self.get_success(lock.__aenter__())
|
self.get_success(lock.__aenter__())
|
||||||
|
@ -45,7 +45,7 @@ class LockTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# We can now acquire the lock again.
|
# We can now acquire the lock again.
|
||||||
lock3 = self.get_success(self.store.try_acquire_lock("name", "key"))
|
lock3 = self.get_success(self.store.try_acquire_lock("name", "key"))
|
||||||
self.assertIsNotNone(lock3)
|
assert lock3 is not None
|
||||||
self.get_success(lock3.__aenter__())
|
self.get_success(lock3.__aenter__())
|
||||||
self.get_success(lock3.__aexit__(None, None, None))
|
self.get_success(lock3.__aexit__(None, None, None))
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ class LockTestCase(unittest.HomeserverTestCase):
|
||||||
"""Test that we don't time out locks while they're still active"""
|
"""Test that we don't time out locks while they're still active"""
|
||||||
|
|
||||||
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
||||||
self.assertIsNotNone(lock)
|
assert lock is not None
|
||||||
|
|
||||||
self.get_success(lock.__aenter__())
|
self.get_success(lock.__aenter__())
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ class LockTestCase(unittest.HomeserverTestCase):
|
||||||
"""Test that we time out locks if they're not updated for ages"""
|
"""Test that we time out locks if they're not updated for ages"""
|
||||||
|
|
||||||
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
|
||||||
self.assertIsNotNone(lock)
|
assert lock is not None
|
||||||
|
|
||||||
self.get_success(lock.__aenter__())
|
self.get_success(lock.__aenter__())
|
||||||
|
|
||||||
|
|
|
@ -358,6 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(self._insert_txn(service.id, 12, other_events))
|
self.get_success(self._insert_txn(service.id, 12, other_events))
|
||||||
|
|
||||||
txn = self.get_success(self.store.get_oldest_unsent_txn(service))
|
txn = self.get_success(self.store.get_oldest_unsent_txn(service))
|
||||||
|
assert txn is not None
|
||||||
self.assertEqual(service, txn.service)
|
self.assertEqual(service, txn.service)
|
||||||
self.assertEqual(10, txn.id)
|
self.assertEqual(10, txn.id)
|
||||||
self.assertEqual(events, txn.events)
|
self.assertEqual(events, txn.events)
|
||||||
|
|
|
@ -22,10 +22,11 @@ import secrets
|
||||||
import time
|
import time
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AnyStr,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
@ -39,6 +40,7 @@ from unittest.mock import Mock, patch
|
||||||
import canonicaljson
|
import canonicaljson
|
||||||
import signedjson.key
|
import signedjson.key
|
||||||
import unpaddedbase64
|
import unpaddedbase64
|
||||||
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred, ensureDeferred
|
from twisted.internet.defer import Deferred, ensureDeferred
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
@ -49,7 +51,7 @@ from twisted.web.resource import Resource
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse import events
|
from synapse import events
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||||
|
@ -70,7 +72,13 @@ from synapse.types import JsonDict, UserID, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
|
|
||||||
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
|
from tests.server import (
|
||||||
|
CustomHeaderType,
|
||||||
|
FakeChannel,
|
||||||
|
get_clock,
|
||||||
|
make_request,
|
||||||
|
setup_test_homeserver,
|
||||||
|
)
|
||||||
from tests.test_utils import event_injection, setup_awaitable_errors
|
from tests.test_utils import event_injection, setup_awaitable_errors
|
||||||
from tests.test_utils.logging_setup import setup_logging
|
from tests.test_utils.logging_setup import setup_logging
|
||||||
from tests.utils import default_config, setupdb
|
from tests.utils import default_config, setupdb
|
||||||
|
@ -78,6 +86,17 @@ from tests.utils import default_config, setupdb
|
||||||
setupdb()
|
setupdb()
|
||||||
setup_logging()
|
setup_logging()
|
||||||
|
|
||||||
|
TV = TypeVar("TV")
|
||||||
|
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class _TypedFailure(Generic[_ExcType], Protocol):
|
||||||
|
"""Extension to twisted.Failure, where the 'value' has a certain type."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self) -> _ExcType:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
def around(target):
|
def around(target):
|
||||||
"""A CLOS-style 'around' modifier, which wraps the original method of the
|
"""A CLOS-style 'around' modifier, which wraps the original method of the
|
||||||
|
@ -276,6 +295,7 @@ class HomeserverTestCase(TestCase):
|
||||||
|
|
||||||
if hasattr(self, "user_id"):
|
if hasattr(self, "user_id"):
|
||||||
if self.hijack_auth:
|
if self.hijack_auth:
|
||||||
|
assert self.helper.auth_user_id is not None
|
||||||
|
|
||||||
# We need a valid token ID to satisfy foreign key constraints.
|
# We need a valid token ID to satisfy foreign key constraints.
|
||||||
token_id = self.get_success(
|
token_id = self.get_success(
|
||||||
|
@ -288,6 +308,7 @@ class HomeserverTestCase(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_user_by_access_token(token=None, allow_guest=False):
|
async def get_user_by_access_token(token=None, allow_guest=False):
|
||||||
|
assert self.helper.auth_user_id is not None
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.helper.auth_user_id),
|
"user": UserID.from_string(self.helper.auth_user_id),
|
||||||
"token_id": token_id,
|
"token_id": token_id,
|
||||||
|
@ -295,6 +316,7 @@ class HomeserverTestCase(TestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_user_by_req(request, allow_guest=False, rights="access"):
|
async def get_user_by_req(request, allow_guest=False, rights="access"):
|
||||||
|
assert self.helper.auth_user_id is not None
|
||||||
return create_requester(
|
return create_requester(
|
||||||
UserID.from_string(self.helper.auth_user_id),
|
UserID.from_string(self.helper.auth_user_id),
|
||||||
token_id,
|
token_id,
|
||||||
|
@ -311,7 +333,7 @@ class HomeserverTestCase(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.needs_threadpool:
|
if self.needs_threadpool:
|
||||||
self.reactor.threadpool = ThreadPool()
|
self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
|
||||||
self.addCleanup(self.reactor.threadpool.stop)
|
self.addCleanup(self.reactor.threadpool.stop)
|
||||||
self.reactor.threadpool.start()
|
self.reactor.threadpool.start()
|
||||||
|
|
||||||
|
@ -426,7 +448,7 @@ class HomeserverTestCase(TestCase):
|
||||||
federation_auth_origin: Optional[bytes] = None,
|
federation_auth_origin: Optional[bytes] = None,
|
||||||
content_is_form: bool = False,
|
content_is_form: bool = False,
|
||||||
await_result: bool = True,
|
await_result: bool = True,
|
||||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
|
||||||
client_ip: str = "127.0.0.1",
|
client_ip: str = "127.0.0.1",
|
||||||
) -> FakeChannel:
|
) -> FakeChannel:
|
||||||
"""
|
"""
|
||||||
|
@ -511,30 +533,36 @@ class HomeserverTestCase(TestCase):
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def pump(self, by=0.0):
|
def pump(self, by: float = 0.0) -> None:
|
||||||
"""
|
"""
|
||||||
Pump the reactor enough that Deferreds will fire.
|
Pump the reactor enough that Deferreds will fire.
|
||||||
"""
|
"""
|
||||||
self.reactor.pump([by] * 100)
|
self.reactor.pump([by] * 100)
|
||||||
|
|
||||||
def get_success(self, d, by=0.0):
|
def get_success(
|
||||||
deferred: Deferred[TV] = ensureDeferred(d)
|
self,
|
||||||
|
d: Awaitable[TV],
|
||||||
|
by: float = 0.0,
|
||||||
|
) -> TV:
|
||||||
|
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
|
||||||
self.pump(by=by)
|
self.pump(by=by)
|
||||||
return self.successResultOf(deferred)
|
return self.successResultOf(deferred)
|
||||||
|
|
||||||
def get_failure(self, d, exc):
|
def get_failure(
|
||||||
|
self, d: Awaitable[Any], exc: Type[_ExcType]
|
||||||
|
) -> _TypedFailure[_ExcType]:
|
||||||
"""
|
"""
|
||||||
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
|
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
|
||||||
"""
|
"""
|
||||||
deferred: Deferred[Any] = ensureDeferred(d)
|
deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type]
|
||||||
self.pump()
|
self.pump()
|
||||||
return self.failureResultOf(deferred, exc)
|
return self.failureResultOf(deferred, exc)
|
||||||
|
|
||||||
def get_success_or_raise(self, d, by=0.0):
|
def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
|
||||||
"""Drive deferred to completion and return result or raise exception
|
"""Drive deferred to completion and return result or raise exception
|
||||||
on failure.
|
on failure.
|
||||||
"""
|
"""
|
||||||
deferred: Deferred[TV] = ensureDeferred(d)
|
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
|
||||||
|
|
||||||
results: list = []
|
results: list = []
|
||||||
deferred.addBoth(results.append)
|
deferred.addBoth(results.append)
|
||||||
|
@ -642,11 +670,11 @@ class HomeserverTestCase(TestCase):
|
||||||
|
|
||||||
def login(
|
def login(
|
||||||
self,
|
self,
|
||||||
username,
|
username: str,
|
||||||
password,
|
password: str,
|
||||||
device_id=None,
|
device_id: Optional[str] = None,
|
||||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
|
||||||
):
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Log in a user, and get an access token. Requires the Login API be
|
Log in a user, and get an access token. Requires the Login API be
|
||||||
registered.
|
registered.
|
||||||
|
@ -668,18 +696,22 @@ class HomeserverTestCase(TestCase):
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
def create_and_send_event(
|
def create_and_send_event(
|
||||||
self, room_id, user, soft_failed=False, prev_event_ids=None
|
self,
|
||||||
):
|
room_id: str,
|
||||||
|
user: UserID,
|
||||||
|
soft_failed: bool = False,
|
||||||
|
prev_event_ids: Optional[List[str]] = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Create and send an event.
|
Create and send an event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
soft_failed (bool): Whether to create a soft failed event or not
|
soft_failed: Whether to create a soft failed event or not
|
||||||
prev_event_ids (list[str]|None): Explicitly set the prev events,
|
prev_event_ids: Explicitly set the prev events,
|
||||||
or if None just use the default
|
or if None just use the default
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The new event's ID.
|
The new event's ID.
|
||||||
"""
|
"""
|
||||||
event_creator = self.hs.get_event_creation_handler()
|
event_creator = self.hs.get_event_creation_handler()
|
||||||
requester = create_requester(user)
|
requester = create_requester(user)
|
||||||
|
@ -706,7 +738,7 @@ class HomeserverTestCase(TestCase):
|
||||||
|
|
||||||
return event.event_id
|
return event.event_id
|
||||||
|
|
||||||
def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
|
def inject_room_member(self, room: str, user: str, membership: str) -> None:
|
||||||
"""
|
"""
|
||||||
Inject a membership event into a room.
|
Inject a membership event into a room.
|
||||||
|
|
||||||
|
@ -766,7 +798,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
||||||
path: str,
|
path: str,
|
||||||
content: Optional[JsonDict] = None,
|
content: Optional[JsonDict] = None,
|
||||||
await_result: bool = True,
|
await_result: bool = True,
|
||||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
|
||||||
client_ip: str = "127.0.0.1",
|
client_ip: str = "127.0.0.1",
|
||||||
) -> FakeChannel:
|
) -> FakeChannel:
|
||||||
"""Make an inbound signed federation request to this server
|
"""Make an inbound signed federation request to this server
|
||||||
|
@ -799,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
||||||
self.site,
|
self.site,
|
||||||
method=method,
|
method=method,
|
||||||
path=path,
|
path=path,
|
||||||
content=content,
|
content=content or "",
|
||||||
shorthand=False,
|
shorthand=False,
|
||||||
await_result=await_result,
|
await_result=await_result,
|
||||||
custom_headers=custom_headers,
|
custom_headers=custom_headers,
|
||||||
|
@ -878,9 +910,6 @@ def override_config(extra_config):
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
TV = TypeVar("TV")
|
|
||||||
|
|
||||||
|
|
||||||
def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
|
def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
|
||||||
"""A test decorator which will skip the decorated test unless a condition is set
|
"""A test decorator which will skip the decorated test unless a condition is set
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue