Add missing type hints to tests. (#15027)

pull/15030/head
Patrick Cloke 2023-02-08 14:52:37 -05:00 committed by GitHub
parent 55e4d27b36
commit 4eed7b2ede
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 70 additions and 76 deletions

1
changelog.d/15027.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -69,27 +69,9 @@ disallow_untyped_defs = False
[mypy-tests.server_notices.test_resource_limits_server_notices] [mypy-tests.server_notices.test_resource_limits_server_notices]
disallow_untyped_defs = False disallow_untyped_defs = False
[mypy-tests.test_distributor]
disallow_untyped_defs = False
[mypy-tests.test_event_auth]
disallow_untyped_defs = False
[mypy-tests.test_federation] [mypy-tests.test_federation]
disallow_untyped_defs = False disallow_untyped_defs = False
[mypy-tests.test_mau]
disallow_untyped_defs = False
[mypy-tests.test_rust]
disallow_untyped_defs = False
[mypy-tests.test_test_utils]
disallow_untyped_defs = False
[mypy-tests.test_types]
disallow_untyped_defs = False
[mypy-tests.test_utils.*] [mypy-tests.test_utils.*]
disallow_untyped_defs = False disallow_untyped_defs = False

View File

@ -21,10 +21,10 @@ from . import unittest
class DistributorTestCase(unittest.TestCase): class DistributorTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.dist = Distributor() self.dist = Distributor()
def test_signal_dispatch(self): def test_signal_dispatch(self) -> None:
self.dist.declare("alert") self.dist.declare("alert")
observer = Mock() observer = Mock()
@ -33,7 +33,7 @@ class DistributorTestCase(unittest.TestCase):
self.dist.fire("alert", 1, 2, 3) self.dist.fire("alert", 1, 2, 3)
observer.assert_called_with(1, 2, 3) observer.assert_called_with(1, 2, 3)
def test_signal_catch(self): def test_signal_catch(self) -> None:
self.dist.declare("alarm") self.dist.declare("alarm")
observers = [Mock() for i in (1, 2)] observers = [Mock() for i in (1, 2)]
@ -51,7 +51,7 @@ class DistributorTestCase(unittest.TestCase):
self.assertEqual(mock_logger.warning.call_count, 1) self.assertEqual(mock_logger.warning.call_count, 1)
self.assertIsInstance(mock_logger.warning.call_args[0][0], str) self.assertIsInstance(mock_logger.warning.call_args[0][0], str)
def test_signal_prereg(self): def test_signal_prereg(self) -> None:
observer = Mock() observer = Mock()
self.dist.observe("flare", observer) self.dist.observe("flare", observer)
@ -60,8 +60,8 @@ class DistributorTestCase(unittest.TestCase):
observer.assert_called_with(4, 5) observer.assert_called_with(4, 5)
def test_signal_undeclared(self): def test_signal_undeclared(self) -> None:
def code(): def code() -> None:
self.dist.fire("notification") self.dist.fire("notification")
self.assertRaises(KeyError, code) self.assertRaises(KeyError, code)

View File

@ -31,13 +31,13 @@ from tests.test_utils import get_awaitable_result
class _StubEventSourceStore: class _StubEventSourceStore:
"""A stub implementation of the EventSourceStore""" """A stub implementation of the EventSourceStore"""
def __init__(self): def __init__(self) -> None:
self._store: Dict[str, EventBase] = {} self._store: Dict[str, EventBase] = {}
def add_event(self, event: EventBase): def add_event(self, event: EventBase) -> None:
self._store[event.event_id] = event self._store[event.event_id] = event
def add_events(self, events: Iterable[EventBase]): def add_events(self, events: Iterable[EventBase]) -> None:
for event in events: for event in events:
self._store[event.event_id] = event self._store[event.event_id] = event
@ -59,7 +59,7 @@ class _StubEventSourceStore:
class EventAuthTestCase(unittest.TestCase): class EventAuthTestCase(unittest.TestCase):
def test_rejected_auth_events(self): def test_rejected_auth_events(self) -> None:
""" """
Events that refer to rejected events in their auth events are rejected Events that refer to rejected events in their auth events are rejected
""" """
@ -109,7 +109,7 @@ class EventAuthTestCase(unittest.TestCase):
) )
) )
def test_create_event_with_prev_events(self): def test_create_event_with_prev_events(self) -> None:
"""A create event with prev_events should be rejected """A create event with prev_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@ -150,7 +150,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event) event_auth.check_state_independent_auth_rules(event_store, bad_event)
) )
def test_duplicate_auth_events(self): def test_duplicate_auth_events(self) -> None:
"""Events with duplicate auth_events should be rejected """Events with duplicate auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@ -196,7 +196,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event2) event_auth.check_state_independent_auth_rules(event_store, bad_event2)
) )
def test_unexpected_auth_events(self): def test_unexpected_auth_events(self) -> None:
"""Events with excess auth_events should be rejected """Events with excess auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@ -236,7 +236,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event) event_auth.check_state_independent_auth_rules(event_store, bad_event)
) )
def test_random_users_cannot_send_state_before_first_pl(self): def test_random_users_cannot_send_state_before_first_pl(self) -> None:
""" """
Check that, before the first PL lands, the creator is the only user Check that, before the first PL lands, the creator is the only user
that can send a state event. that can send a state event.
@ -263,7 +263,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events, auth_events,
) )
def test_state_default_level(self): def test_state_default_level(self) -> None:
""" """
Check that users above the state_default level can send state and Check that users above the state_default level can send state and
those below cannot those below cannot
@ -298,7 +298,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events, auth_events,
) )
def test_alias_event(self): def test_alias_event(self) -> None:
"""Alias events have special behavior up through room version 6.""" """Alias events have special behavior up through room version 6."""
creator = "@creator:example.com" creator = "@creator:example.com"
other = "@other:example.com" other = "@other:example.com"
@ -333,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events, auth_events,
) )
def test_msc2432_alias_event(self): def test_msc2432_alias_event(self) -> None:
"""After MSC2432, alias events have no special behavior.""" """After MSC2432, alias events have no special behavior."""
creator = "@creator:example.com" creator = "@creator:example.com"
other = "@other:example.com" other = "@other:example.com"
@ -366,7 +366,9 @@ class EventAuthTestCase(unittest.TestCase):
) )
@parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)]) @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)])
def test_notifications(self, room_version: RoomVersion, allow_modification: bool): def test_notifications(
self, room_version: RoomVersion, allow_modification: bool
) -> None:
""" """
Notifications power levels get checked due to MSC2209. Notifications power levels get checked due to MSC2209.
""" """
@ -395,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase):
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
event_auth.check_state_dependent_auth_rules(pl_event, auth_events) event_auth.check_state_dependent_auth_rules(pl_event, auth_events)
def test_join_rules_public(self): def test_join_rules_public(self) -> None:
""" """
Test joining a public room. Test joining a public room.
""" """
@ -460,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events.values(), auth_events.values(),
) )
def test_join_rules_invite(self): def test_join_rules_invite(self) -> None:
""" """
Test joining an invite only room. Test joining an invite only room.
""" """
@ -835,7 +837,7 @@ def _power_levels_event(
) )
def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase: def _alias_event(room_version: RoomVersion, sender: str, **kwargs: Any) -> EventBase:
data = { data = {
"room_id": TEST_ROOM_ID, "room_id": TEST_ROOM_ID,
**_maybe_get_event_id_dict_for_room_version(room_version), **_maybe_get_event_id_dict_for_room_version(room_version),

View File

@ -14,12 +14,17 @@
"""Tests REST events for /rooms paths.""" """Tests REST events for /rooms paths."""
from typing import List from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.rest.client import register, sync from synapse.rest.client import register, sync
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.unittest import override_config from tests.unittest import override_config
@ -30,7 +35,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
servlets = [register.register_servlets, sync.register_servlets] servlets = [register.register_servlets, sync.register_servlets]
def default_config(self): def default_config(self) -> JsonDict:
config = default_config("test") config = default_config("test")
config.update( config.update(
@ -53,10 +58,12 @@ class TestMauLimit(unittest.HomeserverTestCase):
return config return config
def prepare(self, reactor, clock, homeserver): def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.store = homeserver.get_datastores().main self.store = homeserver.get_datastores().main
def test_simple_deny_mau(self): def test_simple_deny_mau(self) -> None:
# Create and sync so that the MAU counts get updated # Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1") token1 = self.create_user("kermit1")
self.do_sync_for_user(token1) self.do_sync_for_user(token1)
@ -75,7 +82,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403) self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_as_ignores_mau(self): def test_as_ignores_mau(self) -> None:
"""Test that application services can still create users when the MAU """Test that application services can still create users when the MAU
limit has been reached. This only works when application service limit has been reached. This only works when application service
user ip tracking is disabled. user ip tracking is disabled.
@ -113,7 +120,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.create_user("as_kermit4", token=as_token, appservice=True) self.create_user("as_kermit4", token=as_token, appservice=True)
def test_allowed_after_a_month_mau(self): def test_allowed_after_a_month_mau(self) -> None:
# Create and sync so that the MAU counts get updated # Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1") token1 = self.create_user("kermit1")
self.do_sync_for_user(token1) self.do_sync_for_user(token1)
@ -132,7 +139,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.do_sync_for_user(token3) self.do_sync_for_user(token3)
@override_config({"mau_trial_days": 1}) @override_config({"mau_trial_days": 1})
def test_trial_delay(self): def test_trial_delay(self) -> None:
# We should be able to register more than the limit initially # We should be able to register more than the limit initially
token1 = self.create_user("kermit1") token1 = self.create_user("kermit1")
self.do_sync_for_user(token1) self.do_sync_for_user(token1)
@ -165,7 +172,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@override_config({"mau_trial_days": 1}) @override_config({"mau_trial_days": 1})
def test_trial_users_cant_come_back(self): def test_trial_users_cant_come_back(self) -> None:
self.hs.config.server.mau_trial_days = 1 self.hs.config.server.mau_trial_days = 1
# We should be able to register more than the limit initially # We should be able to register more than the limit initially
@ -216,7 +223,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# max_mau_value should not matter # max_mau_value should not matter
{"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True} {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
) )
def test_tracked_but_not_limited(self): def test_tracked_but_not_limited(self) -> None:
# Simply being able to create 2 users indicates that the # Simply being able to create 2 users indicates that the
# limit was not reached. # limit was not reached.
token1 = self.create_user("kermit1") token1 = self.create_user("kermit1")
@ -236,10 +243,10 @@ class TestMauLimit(unittest.HomeserverTestCase):
"mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2}, "mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2},
} }
) )
def test_as_trial_days(self): def test_as_trial_days(self) -> None:
user_tokens: List[str] = [] user_tokens: List[str] = []
def advance_time_and_sync(): def advance_time_and_sync() -> None:
self.reactor.advance(24 * 60 * 61) self.reactor.advance(24 * 60 * 61)
for token in user_tokens: for token in user_tokens:
self.do_sync_for_user(token) self.do_sync_for_user(token)
@ -300,7 +307,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
}, },
) )
def create_user(self, localpart, token=None, appservice=False): def create_user(
self, localpart: str, token: Optional[str] = None, appservice: bool = False
) -> str:
request_data = { request_data = {
"username": localpart, "username": localpart,
"password": "monkey", "password": "monkey",
@ -326,7 +335,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
return access_token return access_token
def do_sync_for_user(self, token): def do_sync_for_user(self, token: str) -> None:
channel = self.make_request("GET", "/sync", access_token=token) channel = self.make_request("GET", "/sync", access_token=token)
if channel.code != 200: if channel.code != 200:

View File

@ -6,6 +6,6 @@ from tests import unittest
class RustTestCase(unittest.TestCase): class RustTestCase(unittest.TestCase):
"""Basic tests to ensure that we can call into Rust code.""" """Basic tests to ensure that we can call into Rust code."""
def test_basic(self): def test_basic(self) -> None:
result = sum_as_string(1, 2) result = sum_as_string(1, 2)
self.assertEqual("3", result) self.assertEqual("3", result)

View File

@ -17,25 +17,25 @@ from tests.utils import MockClock
class MockClockTestCase(unittest.TestCase): class MockClockTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.clock = MockClock() self.clock = MockClock()
def test_advance_time(self): def test_advance_time(self) -> None:
start_time = self.clock.time() start_time = self.clock.time()
self.clock.advance_time(20) self.clock.advance_time(20)
self.assertEqual(20, self.clock.time() - start_time) self.assertEqual(20, self.clock.time() - start_time)
def test_later(self): def test_later(self) -> None:
invoked = [0, 0] invoked = [0, 0]
def _cb0(): def _cb0() -> None:
invoked[0] = 1 invoked[0] = 1
self.clock.call_later(10, _cb0) self.clock.call_later(10, _cb0)
def _cb1(): def _cb1() -> None:
invoked[1] = 1 invoked[1] = 1
self.clock.call_later(20, _cb1) self.clock.call_later(20, _cb1)
@ -51,15 +51,15 @@ class MockClockTestCase(unittest.TestCase):
self.assertTrue(invoked[1]) self.assertTrue(invoked[1])
def test_cancel_later(self): def test_cancel_later(self) -> None:
invoked = [0, 0] invoked = [0, 0]
def _cb0(): def _cb0() -> None:
invoked[0] = 1 invoked[0] = 1
t0 = self.clock.call_later(10, _cb0) t0 = self.clock.call_later(10, _cb0)
def _cb1(): def _cb1() -> None:
invoked[1] = 1 invoked[1] = 1
self.clock.call_later(20, _cb1) self.clock.call_later(20, _cb1)

View File

@ -43,34 +43,34 @@ class IsMineIDTests(unittest.HomeserverTestCase):
class UserIDTestCase(unittest.HomeserverTestCase): class UserIDTestCase(unittest.HomeserverTestCase):
def test_parse(self): def test_parse(self) -> None:
user = UserID.from_string("@1234abcd:test") user = UserID.from_string("@1234abcd:test")
self.assertEqual("1234abcd", user.localpart) self.assertEqual("1234abcd", user.localpart)
self.assertEqual("test", user.domain) self.assertEqual("test", user.domain)
self.assertEqual(True, self.hs.is_mine(user)) self.assertEqual(True, self.hs.is_mine(user))
def test_parse_rejects_empty_id(self): def test_parse_rejects_empty_id(self) -> None:
with self.assertRaises(SynapseError): with self.assertRaises(SynapseError):
UserID.from_string("") UserID.from_string("")
def test_parse_rejects_missing_sigil(self): def test_parse_rejects_missing_sigil(self) -> None:
with self.assertRaises(SynapseError): with self.assertRaises(SynapseError):
UserID.from_string("alice:example.com") UserID.from_string("alice:example.com")
def test_parse_rejects_missing_separator(self): def test_parse_rejects_missing_separator(self) -> None:
with self.assertRaises(SynapseError): with self.assertRaises(SynapseError):
UserID.from_string("@alice.example.com") UserID.from_string("@alice.example.com")
def test_validation_rejects_missing_domain(self): def test_validation_rejects_missing_domain(self) -> None:
self.assertFalse(UserID.is_valid("@alice:")) self.assertFalse(UserID.is_valid("@alice:"))
def test_build(self): def test_build(self) -> None:
user = UserID("5678efgh", "my.domain") user = UserID("5678efgh", "my.domain")
self.assertEqual(user.to_string(), "@5678efgh:my.domain") self.assertEqual(user.to_string(), "@5678efgh:my.domain")
def test_compare(self): def test_compare(self) -> None:
userA = UserID.from_string("@userA:my.domain") userA = UserID.from_string("@userA:my.domain")
userAagain = UserID.from_string("@userA:my.domain") userAagain = UserID.from_string("@userA:my.domain")
userB = UserID.from_string("@userB:my.domain") userB = UserID.from_string("@userB:my.domain")
@ -80,43 +80,43 @@ class UserIDTestCase(unittest.HomeserverTestCase):
class RoomAliasTestCase(unittest.HomeserverTestCase): class RoomAliasTestCase(unittest.HomeserverTestCase):
def test_parse(self): def test_parse(self) -> None:
room = RoomAlias.from_string("#channel:test") room = RoomAlias.from_string("#channel:test")
self.assertEqual("channel", room.localpart) self.assertEqual("channel", room.localpart)
self.assertEqual("test", room.domain) self.assertEqual("test", room.domain)
self.assertEqual(True, self.hs.is_mine(room)) self.assertEqual(True, self.hs.is_mine(room))
def test_build(self): def test_build(self) -> None:
room = RoomAlias("channel", "my.domain") room = RoomAlias("channel", "my.domain")
self.assertEqual(room.to_string(), "#channel:my.domain") self.assertEqual(room.to_string(), "#channel:my.domain")
def test_validate(self): def test_validate(self) -> None:
id_string = "#test:domain,test" id_string = "#test:domain,test"
self.assertFalse(RoomAlias.is_valid(id_string)) self.assertFalse(RoomAlias.is_valid(id_string))
class MapUsernameTestCase(unittest.TestCase): class MapUsernameTestCase(unittest.TestCase):
def testPassThrough(self): def test_pass_througuh(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234") self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
def testUpperCase(self): def test_upper_case(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234") self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
self.assertEqual( self.assertEqual(
map_username_to_mxid_localpart("tEST_1234", case_sensitive=True), map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
"t_e_s_t__1234", "t_e_s_t__1234",
) )
def testSymbols(self): def test_symbols(self) -> None:
self.assertEqual( self.assertEqual(
map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234" map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
) )
def testLeadingUnderscore(self): def test_leading_underscore(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234") self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
def testNonAscii(self): def test_non_ascii(self) -> None:
# this should work with either a unicode or a bytes # this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast") self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast") self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")