Add missing type hints for tests.unittest. (#13397)
parent
502f075e96
commit
922b771337
|
@ -0,0 +1 @@
|
|||
Adding missing type hints to tests.
|
|
@ -481,17 +481,13 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
|||
|
||||
return config
|
||||
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
||||
) -> HomeServer:
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
|
||||
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
|
||||
|
||||
self.denied_user_id = self.register_user("denied", "pass")
|
||||
self.denied_access_token = self.login("denied", "pass")
|
||||
|
||||
return hs
|
||||
|
||||
def test_denied_without_publication_permission(self) -> None:
|
||||
"""
|
||||
Try to create a room, register an alias for it, and publish it,
|
||||
|
@ -575,9 +571,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
|||
|
||||
servlets = [directory.register_servlets, room.register_servlets]
|
||||
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
||||
) -> HomeServer:
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
channel = self.make_request(
|
||||
|
@ -588,8 +582,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
|||
self.room_list_handler = hs.get_room_list_handler()
|
||||
self.directory_handler = hs.get_directory_handler()
|
||||
|
||||
return hs
|
||||
|
||||
def test_disabling_room_list(self) -> None:
|
||||
self.room_list_handler.enable_room_list_search = True
|
||||
self.directory_handler.enable_room_list_search = True
|
||||
|
|
|
@ -1060,6 +1060,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
|||
participated, bundled_aggregations.get("current_user_participated")
|
||||
)
|
||||
# The latest thread event has some fields that don't matter.
|
||||
self.assertIn("latest_event", bundled_aggregations)
|
||||
self.assert_dict(
|
||||
{
|
||||
"content": {
|
||||
|
@ -1072,7 +1073,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
|||
"sender": self.user2_id,
|
||||
"type": "m.room.test",
|
||||
},
|
||||
bundled_aggregations.get("latest_event"),
|
||||
bundled_aggregations["latest_event"],
|
||||
)
|
||||
|
||||
return assert_thread
|
||||
|
@ -1112,6 +1113,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
|||
self.assertEqual(2, bundled_aggregations.get("count"))
|
||||
self.assertTrue(bundled_aggregations.get("current_user_participated"))
|
||||
# The latest thread event has some fields that don't matter.
|
||||
self.assertIn("latest_event", bundled_aggregations)
|
||||
self.assert_dict(
|
||||
{
|
||||
"content": {
|
||||
|
@ -1124,7 +1126,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
|||
"sender": self.user_id,
|
||||
"type": "m.room.test",
|
||||
},
|
||||
bundled_aggregations.get("latest_event"),
|
||||
bundled_aggregations["latest_event"],
|
||||
)
|
||||
# Check the unsigned field on the latest event.
|
||||
self.assert_dict(
|
||||
|
|
|
@ -496,7 +496,7 @@ class RoomStateTestCase(RoomBase):
|
|||
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||
self.assertCountEqual(
|
||||
[state_event["type"] for state_event in channel.json_body],
|
||||
[state_event["type"] for state_event in channel.json_list],
|
||||
{
|
||||
"m.room.create",
|
||||
"m.room.power_levels",
|
||||
|
|
|
@ -25,6 +25,7 @@ from typing import (
|
|||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
|
@ -121,7 +122,15 @@ class FakeChannel:
|
|||
|
||||
@property
|
||||
def json_body(self) -> JsonDict:
|
||||
return json.loads(self.text_body)
|
||||
body = json.loads(self.text_body)
|
||||
assert isinstance(body, dict)
|
||||
return body
|
||||
|
||||
@property
|
||||
def json_list(self) -> List[JsonDict]:
|
||||
body = json.loads(self.text_body)
|
||||
assert isinstance(body, list)
|
||||
return body
|
||||
|
||||
@property
|
||||
def text_body(self) -> str:
|
||||
|
|
|
@ -28,6 +28,7 @@ from typing import (
|
|||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
|
@ -39,7 +40,7 @@ from unittest.mock import Mock, patch
|
|||
import canonicaljson
|
||||
import signedjson.key
|
||||
import unpaddedbase64
|
||||
from typing_extensions import Protocol
|
||||
from typing_extensions import Concatenate, ParamSpec, Protocol
|
||||
|
||||
from twisted.internet.defer import Deferred, ensureDeferred
|
||||
from twisted.python.failure import Failure
|
||||
|
@ -67,7 +68,7 @@ from synapse.logging.context import (
|
|||
from synapse.rest import RegisterServletsFunc
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.keys import FetchKeyResult
|
||||
from synapse.types import JsonDict, UserID, create_requester
|
||||
from synapse.types import JsonDict, Requester, UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
|
||||
|
@ -88,6 +89,10 @@ setup_logging()
|
|||
TV = TypeVar("TV")
|
||||
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
S = TypeVar("S")
|
||||
|
||||
|
||||
class _TypedFailure(Generic[_ExcType], Protocol):
|
||||
"""Extension to twisted.Failure, where the 'value' has a certain type."""
|
||||
|
@ -97,7 +102,7 @@ class _TypedFailure(Generic[_ExcType], Protocol):
|
|||
...
|
||||
|
||||
|
||||
def around(target):
|
||||
def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
|
||||
"""A CLOS-style 'around' modifier, which wraps the original method of the
|
||||
given instance with another piece of code.
|
||||
|
||||
|
@ -106,11 +111,11 @@ def around(target):
|
|||
return orig(*args, **kwargs)
|
||||
"""
|
||||
|
||||
def _around(code):
|
||||
def _around(code: Callable[Concatenate[S, P], R]) -> None:
|
||||
name = code.__name__
|
||||
orig = getattr(target, name)
|
||||
|
||||
def new(*args, **kwargs):
|
||||
def new(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
return code(orig, *args, **kwargs)
|
||||
|
||||
setattr(target, name, new)
|
||||
|
@ -131,7 +136,7 @@ class TestCase(unittest.TestCase):
|
|||
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
|
||||
|
||||
@around(self)
|
||||
def setUp(orig):
|
||||
def setUp(orig: Callable[[], R]) -> R:
|
||||
# if we're not starting in the sentinel logcontext, then to be honest
|
||||
# all future bets are off.
|
||||
if current_context():
|
||||
|
@ -144,7 +149,7 @@ class TestCase(unittest.TestCase):
|
|||
if level is not None and old_level != level:
|
||||
|
||||
@around(self)
|
||||
def tearDown(orig):
|
||||
def tearDown(orig: Callable[[], R]) -> R:
|
||||
ret = orig()
|
||||
logging.getLogger().setLevel(old_level)
|
||||
return ret
|
||||
|
@ -158,7 +163,7 @@ class TestCase(unittest.TestCase):
|
|||
return orig()
|
||||
|
||||
@around(self)
|
||||
def tearDown(orig):
|
||||
def tearDown(orig: Callable[[], R]) -> R:
|
||||
ret = orig()
|
||||
# force a GC to workaround problems with deferreds leaking logcontexts when
|
||||
# they are GCed (see the logcontext docs)
|
||||
|
@ -167,7 +172,7 @@ class TestCase(unittest.TestCase):
|
|||
|
||||
return ret
|
||||
|
||||
def assertObjectHasAttributes(self, attrs, obj):
|
||||
def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
|
||||
"""Asserts that the given object has each of the attributes given, and
|
||||
that the value of each matches according to assertEqual."""
|
||||
for key in attrs.keys():
|
||||
|
@ -178,12 +183,12 @@ class TestCase(unittest.TestCase):
|
|||
except AssertionError as e:
|
||||
raise (type(e))(f"Assert error for '.{key}':") from e
|
||||
|
||||
def assert_dict(self, required, actual):
|
||||
def assert_dict(self, required: dict, actual: dict) -> None:
|
||||
"""Does a partial assert of a dict.
|
||||
|
||||
Args:
|
||||
required (dict): The keys and value which MUST be in 'actual'.
|
||||
actual (dict): The test result. Extra keys will not be checked.
|
||||
required: The keys and value which MUST be in 'actual'.
|
||||
actual: The test result. Extra keys will not be checked.
|
||||
"""
|
||||
for key in required:
|
||||
self.assertEqual(
|
||||
|
@ -191,31 +196,31 @@ class TestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
|
||||
def DEBUG(target):
|
||||
def DEBUG(target: TV) -> TV:
|
||||
"""A decorator to set the .loglevel attribute to logging.DEBUG.
|
||||
Can apply to either a TestCase or an individual test method."""
|
||||
target.loglevel = logging.DEBUG
|
||||
target.loglevel = logging.DEBUG # type: ignore[attr-defined]
|
||||
return target
|
||||
|
||||
|
||||
def INFO(target):
|
||||
def INFO(target: TV) -> TV:
|
||||
"""A decorator to set the .loglevel attribute to logging.INFO.
|
||||
Can apply to either a TestCase or an individual test method."""
|
||||
target.loglevel = logging.INFO
|
||||
target.loglevel = logging.INFO # type: ignore[attr-defined]
|
||||
return target
|
||||
|
||||
|
||||
def logcontext_clean(target):
|
||||
def logcontext_clean(target: TV) -> TV:
|
||||
"""A decorator which marks the TestCase or method as 'logcontext_clean'
|
||||
|
||||
... ie, any logcontext errors should cause a test failure
|
||||
"""
|
||||
|
||||
def logcontext_error(msg):
|
||||
def logcontext_error(msg: str) -> NoReturn:
|
||||
raise AssertionError("logcontext error: %s" % (msg))
|
||||
|
||||
patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
|
||||
return patcher(target)
|
||||
return patcher(target) # type: ignore[call-overload]
|
||||
|
||||
|
||||
class HomeserverTestCase(TestCase):
|
||||
|
@ -255,7 +260,7 @@ class HomeserverTestCase(TestCase):
|
|||
method = getattr(self, methodName)
|
||||
self._extra_config = getattr(method, "_extra_config", None)
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
"""
|
||||
Set up the TestCase by calling the homeserver constructor, optionally
|
||||
hijacking the authentication system to return a fixed user, and then
|
||||
|
@ -306,7 +311,9 @@ class HomeserverTestCase(TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
async def get_user_by_access_token(token=None, allow_guest=False):
|
||||
async def get_user_by_access_token(
|
||||
token: Optional[str] = None, allow_guest: bool = False
|
||||
) -> JsonDict:
|
||||
assert self.helper.auth_user_id is not None
|
||||
return {
|
||||
"user": UserID.from_string(self.helper.auth_user_id),
|
||||
|
@ -314,7 +321,11 @@ class HomeserverTestCase(TestCase):
|
|||
"is_guest": False,
|
||||
}
|
||||
|
||||
async def get_user_by_req(request, allow_guest=False):
|
||||
async def get_user_by_req(
|
||||
request: SynapseRequest,
|
||||
allow_guest: bool = False,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
assert self.helper.auth_user_id is not None
|
||||
return create_requester(
|
||||
UserID.from_string(self.helper.auth_user_id),
|
||||
|
@ -339,11 +350,11 @@ class HomeserverTestCase(TestCase):
|
|||
if hasattr(self, "prepare"):
|
||||
self.prepare(self.reactor, self.clock, self.hs)
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
# Reset to not use frozen dicts.
|
||||
events.USE_FROZEN_DICTS = False
|
||||
|
||||
def wait_on_thread(self, deferred, timeout=10):
|
||||
def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
|
||||
"""
|
||||
Wait until a Deferred is done, where it's waiting on a real thread.
|
||||
"""
|
||||
|
@ -374,7 +385,7 @@ class HomeserverTestCase(TestCase):
|
|||
clock (synapse.util.Clock): The Clock, associated with the reactor.
|
||||
|
||||
Returns:
|
||||
A homeserver (synapse.server.HomeServer) suitable for testing.
|
||||
A homeserver suitable for testing.
|
||||
|
||||
Function to be overridden in subclasses.
|
||||
"""
|
||||
|
@ -408,7 +419,7 @@ class HomeserverTestCase(TestCase):
|
|||
"/_synapse/admin": servlet_resource,
|
||||
}
|
||||
|
||||
def default_config(self):
|
||||
def default_config(self) -> JsonDict:
|
||||
"""
|
||||
Get a default HomeServer config dict.
|
||||
"""
|
||||
|
@ -421,7 +432,9 @@ class HomeserverTestCase(TestCase):
|
|||
|
||||
return config
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
) -> None:
|
||||
"""
|
||||
Prepare for the test. This involves things like mocking out parts of
|
||||
the homeserver, or building test data common across the whole test
|
||||
|
@ -519,7 +532,7 @@ class HomeserverTestCase(TestCase):
|
|||
config_obj.parse_config_dict(config, "", "")
|
||||
kwargs["config"] = config_obj
|
||||
|
||||
async def run_bg_updates():
|
||||
async def run_bg_updates() -> None:
|
||||
with LoggingContext("run_bg_updates"):
|
||||
self.get_success(stor.db_pool.updates.run_background_updates(False))
|
||||
|
||||
|
@ -538,11 +551,7 @@ class HomeserverTestCase(TestCase):
|
|||
"""
|
||||
self.reactor.pump([by] * 100)
|
||||
|
||||
def get_success(
|
||||
self,
|
||||
d: Awaitable[TV],
|
||||
by: float = 0.0,
|
||||
) -> TV:
|
||||
def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
|
||||
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
|
||||
self.pump(by=by)
|
||||
return self.successResultOf(deferred)
|
||||
|
@ -755,7 +764,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
|||
OTHER_SERVER_NAME = "other.example.com"
|
||||
OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
super().prepare(reactor, clock, hs)
|
||||
|
||||
# poke the other server's signing key into the key store, so that we don't
|
||||
|
@ -879,7 +888,7 @@ def _auth_header_for_request(
|
|||
)
|
||||
|
||||
|
||||
def override_config(extra_config):
|
||||
def override_config(extra_config: JsonDict) -> Callable[[TV], TV]:
|
||||
"""A decorator which can be applied to test functions to give additional HS config
|
||||
|
||||
For use
|
||||
|
@ -892,12 +901,13 @@ def override_config(extra_config):
|
|||
...
|
||||
|
||||
Args:
|
||||
extra_config(dict): Additional config settings to be merged into the default
|
||||
extra_config: Additional config settings to be merged into the default
|
||||
config dict before instantiating the test homeserver.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
func._extra_config = extra_config
|
||||
def decorator(func: TV) -> TV:
|
||||
# This attribute is being defined.
|
||||
func._extra_config = extra_config # type: ignore[attr-defined]
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
|
Loading…
Reference in New Issue