Misc typing fixes for tests, part 2 of N (#11330)
parent
e72135b9d3
commit
0dda1a7968
|
@ -0,0 +1 @@
|
||||||
|
Improve type annotations in Synapse's test suite.
|
|
@ -193,7 +193,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": True})
|
@override_config({"limit_usage_by_mau": True})
|
||||||
def test_get_or_create_user_mau_not_blocked(self):
|
def test_get_or_create_user_mau_not_blocked(self):
|
||||||
self.store.count_monthly_users = Mock(
|
# Type ignore: mypy doesn't like us assigning to methods.
|
||||||
|
self.store.count_monthly_users = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
|
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
|
||||||
)
|
)
|
||||||
# Ensure does not throw exception
|
# Ensure does not throw exception
|
||||||
|
@ -201,7 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": True})
|
@override_config({"limit_usage_by_mau": True})
|
||||||
def test_get_or_create_user_mau_blocked(self):
|
def test_get_or_create_user_mau_blocked(self):
|
||||||
self.store.get_monthly_active_count = Mock(
|
# Type ignore: mypy doesn't like us assigning to methods.
|
||||||
|
self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(self.lots_of_users)
|
return_value=make_awaitable(self.lots_of_users)
|
||||||
)
|
)
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
|
@ -209,7 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
ResourceLimitError,
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.store.get_monthly_active_count = Mock(
|
# Type ignore: mypy doesn't like us assigning to methods.
|
||||||
|
self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(self.hs.config.server.max_mau_value)
|
return_value=make_awaitable(self.hs.config.server.max_mau_value)
|
||||||
)
|
)
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
|
|
|
@ -28,11 +28,12 @@ from typing import (
|
||||||
MutableMapping,
|
MutableMapping,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
overload,
|
||||||
)
|
)
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from twisted.web.server import Site
|
from twisted.web.server import Site
|
||||||
|
@ -55,6 +56,32 @@ class RestHelper:
|
||||||
site = attr.ib(type=Site)
|
site = attr.ib(type=Site)
|
||||||
auth_user_id = attr.ib()
|
auth_user_id = attr.ib()
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def create_room_as(
|
||||||
|
self,
|
||||||
|
room_creator: Optional[str] = ...,
|
||||||
|
is_public: Optional[bool] = ...,
|
||||||
|
room_version: Optional[str] = ...,
|
||||||
|
tok: Optional[str] = ...,
|
||||||
|
expect_code: Literal[200] = ...,
|
||||||
|
extra_content: Optional[Dict] = ...,
|
||||||
|
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
|
||||||
|
) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def create_room_as(
|
||||||
|
self,
|
||||||
|
room_creator: Optional[str] = ...,
|
||||||
|
is_public: Optional[bool] = ...,
|
||||||
|
room_version: Optional[str] = ...,
|
||||||
|
tok: Optional[str] = ...,
|
||||||
|
expect_code: int = ...,
|
||||||
|
extra_content: Optional[Dict] = ...,
|
||||||
|
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
|
||||||
|
) -> Optional[str]:
|
||||||
|
...
|
||||||
|
|
||||||
def create_room_as(
|
def create_room_as(
|
||||||
self,
|
self,
|
||||||
room_creator: Optional[str] = None,
|
room_creator: Optional[str] = None,
|
||||||
|
@ -64,7 +91,7 @@ class RestHelper:
|
||||||
expect_code: int = 200,
|
expect_code: int = 200,
|
||||||
extra_content: Optional[Dict] = None,
|
extra_content: Optional[Dict] = None,
|
||||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||||
) -> str:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Create a room.
|
Create a room.
|
||||||
|
|
||||||
|
@ -107,6 +134,8 @@ class RestHelper:
|
||||||
|
|
||||||
if expect_code == 200:
|
if expect_code == 200:
|
||||||
return channel.json_body["room_id"]
|
return channel.json_body["room_id"]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
|
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
|
||||||
self.change_membership(
|
self.change_membership(
|
||||||
|
@ -176,7 +205,7 @@ class RestHelper:
|
||||||
extra_data: Optional[dict] = None,
|
extra_data: Optional[dict] = None,
|
||||||
tok: Optional[str] = None,
|
tok: Optional[str] = None,
|
||||||
expect_code: int = 200,
|
expect_code: int = 200,
|
||||||
expect_errcode: str = None,
|
expect_errcode: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Send a membership state event into a room.
|
Send a membership state event into a room.
|
||||||
|
@ -260,9 +289,7 @@ class RestHelper:
|
||||||
txn_id=None,
|
txn_id=None,
|
||||||
tok=None,
|
tok=None,
|
||||||
expect_code=200,
|
expect_code=200,
|
||||||
custom_headers: Optional[
|
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
|
||||||
] = None,
|
|
||||||
):
|
):
|
||||||
if txn_id is None:
|
if txn_id is None:
|
||||||
txn_id = "m%s" % (str(time.time()))
|
txn_id = "m%s" % (str(time.time()))
|
||||||
|
@ -509,7 +536,7 @@ class RestHelper:
|
||||||
went.
|
went.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cookies = {}
|
cookies: Dict[str, str] = {}
|
||||||
|
|
||||||
# if we're doing a ui auth, hit the ui auth redirect endpoint
|
# if we're doing a ui auth, hit the ui auth redirect endpoint
|
||||||
if ui_auth_session_id:
|
if ui_auth_session_id:
|
||||||
|
@ -631,7 +658,13 @@ class RestHelper:
|
||||||
|
|
||||||
# hit the redirect url again with the right Host header, which should now issue
|
# hit the redirect url again with the right Host header, which should now issue
|
||||||
# a cookie and redirect to the SSO provider.
|
# a cookie and redirect to the SSO provider.
|
||||||
location = channel.headers.getRawHeaders("Location")[0]
|
def get_location(channel: FakeChannel) -> str:
|
||||||
|
location_values = channel.headers.getRawHeaders("Location")
|
||||||
|
# Keep mypy happy by asserting that location_values is nonempty
|
||||||
|
assert location_values
|
||||||
|
return location_values[0]
|
||||||
|
|
||||||
|
location = get_location(channel)
|
||||||
parts = urllib.parse.urlsplit(location)
|
parts = urllib.parse.urlsplit(location)
|
||||||
channel = make_request(
|
channel = make_request(
|
||||||
self.hs.get_reactor(),
|
self.hs.get_reactor(),
|
||||||
|
@ -645,7 +678,7 @@ class RestHelper:
|
||||||
|
|
||||||
assert channel.code == 302
|
assert channel.code == 302
|
||||||
channel.extract_cookies(cookies)
|
channel.extract_cookies(cookies)
|
||||||
return channel.headers.getRawHeaders("Location")[0]
|
return get_location(channel)
|
||||||
|
|
||||||
def initiate_sso_ui_auth(
|
def initiate_sso_ui_auth(
|
||||||
self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
|
self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
|
||||||
|
|
|
@ -24,6 +24,7 @@ from typing import (
|
||||||
MutableMapping,
|
MutableMapping,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
Type,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -226,7 +227,7 @@ def make_request(
|
||||||
path: Union[bytes, str],
|
path: Union[bytes, str],
|
||||||
content: Union[bytes, str, JsonDict] = b"",
|
content: Union[bytes, str, JsonDict] = b"",
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
request: Request = SynapseRequest,
|
request: Type[Request] = SynapseRequest,
|
||||||
shorthand: bool = True,
|
shorthand: bool = True,
|
||||||
federation_auth_origin: Optional[bytes] = None,
|
federation_auth_origin: Optional[bytes] = None,
|
||||||
content_is_form: bool = False,
|
content_is_form: bool = False,
|
||||||
|
|
|
@ -44,6 +44,7 @@ from twisted.python.threadpool import ThreadPool
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse import events
|
from synapse import events
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
|
@ -95,16 +96,13 @@ def around(target):
|
||||||
return _around
|
return _around
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class TestCase(unittest.TestCase):
|
class TestCase(unittest.TestCase):
|
||||||
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
|
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
|
||||||
attributes on both itself and its individual test methods, to override the
|
attributes on both itself and its individual test methods, to override the
|
||||||
root logger's logging level while that test (case|method) runs."""
|
root logger's logging level while that test (case|method) runs."""
|
||||||
|
|
||||||
def __init__(self, methodName, *args, **kwargs):
|
def __init__(self, methodName: str):
|
||||||
super().__init__(methodName, *args, **kwargs)
|
super().__init__(methodName)
|
||||||
|
|
||||||
method = getattr(self, methodName)
|
method = getattr(self, methodName)
|
||||||
|
|
||||||
|
@ -220,16 +218,16 @@ class HomeserverTestCase(TestCase):
|
||||||
Attributes:
|
Attributes:
|
||||||
servlets: List of servlet registration function.
|
servlets: List of servlet registration function.
|
||||||
user_id (str): The user ID to assume if auth is hijacked.
|
user_id (str): The user ID to assume if auth is hijacked.
|
||||||
hijack_auth (bool): Whether to hijack auth to return the user specified
|
hijack_auth: Whether to hijack auth to return the user specified
|
||||||
in user_id.
|
in user_id.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hijack_auth = True
|
hijack_auth: ClassVar[bool] = True
|
||||||
needs_threadpool = False
|
needs_threadpool: ClassVar[bool] = False
|
||||||
servlets: ClassVar[List[RegisterServletsFunc]] = []
|
servlets: ClassVar[List[RegisterServletsFunc]] = []
|
||||||
|
|
||||||
def __init__(self, methodName, *args, **kwargs):
|
def __init__(self, methodName: str):
|
||||||
super().__init__(methodName, *args, **kwargs)
|
super().__init__(methodName)
|
||||||
|
|
||||||
# see if we have any additional config for this test
|
# see if we have any additional config for this test
|
||||||
method = getattr(self, methodName)
|
method = getattr(self, methodName)
|
||||||
|
@ -301,9 +299,10 @@ class HomeserverTestCase(TestCase):
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hs.get_auth().get_user_by_req = get_user_by_req
|
# Type ignore: mypy doesn't like us assigning to methods.
|
||||||
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
|
self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
|
||||||
self.hs.get_auth().get_access_token_from_request = Mock(
|
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
|
||||||
|
self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
|
||||||
return_value="1234"
|
return_value="1234"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -417,7 +416,7 @@ class HomeserverTestCase(TestCase):
|
||||||
path: Union[bytes, str],
|
path: Union[bytes, str],
|
||||||
content: Union[bytes, str, JsonDict] = b"",
|
content: Union[bytes, str, JsonDict] = b"",
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
request: Type[T] = SynapseRequest,
|
request: Type[Request] = SynapseRequest,
|
||||||
shorthand: bool = True,
|
shorthand: bool = True,
|
||||||
federation_auth_origin: Optional[bytes] = None,
|
federation_auth_origin: Optional[bytes] = None,
|
||||||
content_is_form: bool = False,
|
content_is_form: bool = False,
|
||||||
|
@ -596,7 +595,7 @@ class HomeserverTestCase(TestCase):
|
||||||
nonce_str += b"\x00notadmin"
|
nonce_str += b"\x00notadmin"
|
||||||
|
|
||||||
want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
|
want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
|
||||||
want_mac = want_mac.hexdigest()
|
want_mac_digest = want_mac.hexdigest()
|
||||||
|
|
||||||
body = json.dumps(
|
body = json.dumps(
|
||||||
{
|
{
|
||||||
|
@ -605,7 +604,7 @@ class HomeserverTestCase(TestCase):
|
||||||
"displayname": displayname,
|
"displayname": displayname,
|
||||||
"password": password,
|
"password": password,
|
||||||
"admin": admin,
|
"admin": admin,
|
||||||
"mac": want_mac,
|
"mac": want_mac_digest,
|
||||||
"inhibit_login": True,
|
"inhibit_login": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue