Misc typing fixes for tests, part 2 of N (#11330)
							parent
							
								
									e72135b9d3
								
							
						
					
					
						commit
						0dda1a7968
					
				|  | @ -0,0 +1 @@ | |||
| Improve type annotations in Synapse's test suite. | ||||
|  | @ -193,7 +193,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|     @override_config({"limit_usage_by_mau": True}) | ||||
|     def test_get_or_create_user_mau_not_blocked(self): | ||||
|         self.store.count_monthly_users = Mock( | ||||
|         # Type ignore: mypy doesn't like us assigning to methods. | ||||
|         self.store.count_monthly_users = Mock(  # type: ignore[assignment] | ||||
|             return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) | ||||
|         ) | ||||
|         # Ensure does not throw exception | ||||
|  | @ -201,7 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|     @override_config({"limit_usage_by_mau": True}) | ||||
|     def test_get_or_create_user_mau_blocked(self): | ||||
|         self.store.get_monthly_active_count = Mock( | ||||
|         # Type ignore: mypy doesn't like us assigning to methods. | ||||
|         self.store.get_monthly_active_count = Mock(  # type: ignore[assignment] | ||||
|             return_value=make_awaitable(self.lots_of_users) | ||||
|         ) | ||||
|         self.get_failure( | ||||
|  | @ -209,7 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
|             ResourceLimitError, | ||||
|         ) | ||||
| 
 | ||||
|         self.store.get_monthly_active_count = Mock( | ||||
|         # Type ignore: mypy doesn't like us assigning to methods. | ||||
|         self.store.get_monthly_active_count = Mock(  # type: ignore[assignment] | ||||
|             return_value=make_awaitable(self.hs.config.server.max_mau_value) | ||||
|         ) | ||||
|         self.get_failure( | ||||
|  |  | |||
|  | @ -28,11 +28,12 @@ from typing import ( | |||
|     MutableMapping, | ||||
|     Optional, | ||||
|     Tuple, | ||||
|     Union, | ||||
|     overload, | ||||
| ) | ||||
| from unittest.mock import patch | ||||
| 
 | ||||
| import attr | ||||
| from typing_extensions import Literal | ||||
| 
 | ||||
| from twisted.web.resource import Resource | ||||
| from twisted.web.server import Site | ||||
|  | @ -55,6 +56,32 @@ class RestHelper: | |||
|     site = attr.ib(type=Site) | ||||
|     auth_user_id = attr.ib() | ||||
| 
 | ||||
|     @overload | ||||
|     def create_room_as( | ||||
|         self, | ||||
|         room_creator: Optional[str] = ..., | ||||
|         is_public: Optional[bool] = ..., | ||||
|         room_version: Optional[str] = ..., | ||||
|         tok: Optional[str] = ..., | ||||
|         expect_code: Literal[200] = ..., | ||||
|         extra_content: Optional[Dict] = ..., | ||||
|         custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., | ||||
|     ) -> str: | ||||
|         ... | ||||
| 
 | ||||
|     @overload | ||||
|     def create_room_as( | ||||
|         self, | ||||
|         room_creator: Optional[str] = ..., | ||||
|         is_public: Optional[bool] = ..., | ||||
|         room_version: Optional[str] = ..., | ||||
|         tok: Optional[str] = ..., | ||||
|         expect_code: int = ..., | ||||
|         extra_content: Optional[Dict] = ..., | ||||
|         custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., | ||||
|     ) -> Optional[str]: | ||||
|         ... | ||||
| 
 | ||||
|     def create_room_as( | ||||
|         self, | ||||
|         room_creator: Optional[str] = None, | ||||
|  | @ -64,7 +91,7 @@ class RestHelper: | |||
|         expect_code: int = 200, | ||||
|         extra_content: Optional[Dict] = None, | ||||
|         custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, | ||||
|     ) -> str: | ||||
|     ) -> Optional[str]: | ||||
|         """ | ||||
|         Create a room. | ||||
| 
 | ||||
|  | @ -107,6 +134,8 @@ class RestHelper: | |||
| 
 | ||||
|         if expect_code == 200: | ||||
|             return channel.json_body["room_id"] | ||||
|         else: | ||||
|             return None | ||||
| 
 | ||||
|     def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): | ||||
|         self.change_membership( | ||||
|  | @ -176,7 +205,7 @@ class RestHelper: | |||
|         extra_data: Optional[dict] = None, | ||||
|         tok: Optional[str] = None, | ||||
|         expect_code: int = 200, | ||||
|         expect_errcode: str = None, | ||||
|         expect_errcode: Optional[str] = None, | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Send a membership state event into a room. | ||||
|  | @ -260,9 +289,7 @@ class RestHelper: | |||
|         txn_id=None, | ||||
|         tok=None, | ||||
|         expect_code=200, | ||||
|         custom_headers: Optional[ | ||||
|             Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] | ||||
|         ] = None, | ||||
|         custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, | ||||
|     ): | ||||
|         if txn_id is None: | ||||
|             txn_id = "m%s" % (str(time.time())) | ||||
|  | @ -509,7 +536,7 @@ class RestHelper: | |||
|             went. | ||||
|         """ | ||||
| 
 | ||||
|         cookies = {} | ||||
|         cookies: Dict[str, str] = {} | ||||
| 
 | ||||
|         # if we're doing a ui auth, hit the ui auth redirect endpoint | ||||
|         if ui_auth_session_id: | ||||
|  | @ -631,7 +658,13 @@ class RestHelper: | |||
| 
 | ||||
|         # hit the redirect url again with the right Host header, which should now issue | ||||
|         # a cookie and redirect to the SSO provider. | ||||
|         location = channel.headers.getRawHeaders("Location")[0] | ||||
|         def get_location(channel: FakeChannel) -> str: | ||||
|             location_values = channel.headers.getRawHeaders("Location") | ||||
|             # Keep mypy happy by asserting that location_values is nonempty | ||||
|             assert location_values | ||||
|             return location_values[0] | ||||
| 
 | ||||
|         location = get_location(channel) | ||||
|         parts = urllib.parse.urlsplit(location) | ||||
|         channel = make_request( | ||||
|             self.hs.get_reactor(), | ||||
|  | @ -645,7 +678,7 @@ class RestHelper: | |||
| 
 | ||||
|         assert channel.code == 302 | ||||
|         channel.extract_cookies(cookies) | ||||
|         return channel.headers.getRawHeaders("Location")[0] | ||||
|         return get_location(channel) | ||||
| 
 | ||||
|     def initiate_sso_ui_auth( | ||||
|         self, ui_auth_session_id: str, cookies: MutableMapping[str, str] | ||||
|  |  | |||
|  | @ -24,6 +24,7 @@ from typing import ( | |||
|     MutableMapping, | ||||
|     Optional, | ||||
|     Tuple, | ||||
|     Type, | ||||
|     Union, | ||||
| ) | ||||
| 
 | ||||
|  | @ -226,7 +227,7 @@ def make_request( | |||
|     path: Union[bytes, str], | ||||
|     content: Union[bytes, str, JsonDict] = b"", | ||||
|     access_token: Optional[str] = None, | ||||
|     request: Request = SynapseRequest, | ||||
|     request: Type[Request] = SynapseRequest, | ||||
|     shorthand: bool = True, | ||||
|     federation_auth_origin: Optional[bytes] = None, | ||||
|     content_is_form: bool = False, | ||||
|  |  | |||
|  | @ -44,6 +44,7 @@ from twisted.python.threadpool import ThreadPool | |||
| from twisted.test.proto_helpers import MemoryReactor | ||||
| from twisted.trial import unittest | ||||
| from twisted.web.resource import Resource | ||||
| from twisted.web.server import Request | ||||
| 
 | ||||
| from synapse import events | ||||
| from synapse.api.constants import EventTypes, Membership | ||||
|  | @ -95,16 +96,13 @@ def around(target): | |||
|     return _around | ||||
| 
 | ||||
| 
 | ||||
| T = TypeVar("T") | ||||
| 
 | ||||
| 
 | ||||
| class TestCase(unittest.TestCase): | ||||
|     """A subclass of twisted.trial's TestCase which looks for 'loglevel' | ||||
|     attributes on both itself and its individual test methods, to override the | ||||
|     root logger's logging level while that test (case|method) runs.""" | ||||
| 
 | ||||
|     def __init__(self, methodName, *args, **kwargs): | ||||
|         super().__init__(methodName, *args, **kwargs) | ||||
|     def __init__(self, methodName: str): | ||||
|         super().__init__(methodName) | ||||
| 
 | ||||
|         method = getattr(self, methodName) | ||||
| 
 | ||||
|  | @ -220,16 +218,16 @@ class HomeserverTestCase(TestCase): | |||
|     Attributes: | ||||
|         servlets: List of servlet registration function. | ||||
|         user_id (str): The user ID to assume if auth is hijacked. | ||||
|         hijack_auth (bool): Whether to hijack auth to return the user specified | ||||
|         hijack_auth: Whether to hijack auth to return the user specified | ||||
|         in user_id. | ||||
|     """ | ||||
| 
 | ||||
|     hijack_auth = True | ||||
|     needs_threadpool = False | ||||
|     hijack_auth: ClassVar[bool] = True | ||||
|     needs_threadpool: ClassVar[bool] = False | ||||
|     servlets: ClassVar[List[RegisterServletsFunc]] = [] | ||||
| 
 | ||||
|     def __init__(self, methodName, *args, **kwargs): | ||||
|         super().__init__(methodName, *args, **kwargs) | ||||
|     def __init__(self, methodName: str): | ||||
|         super().__init__(methodName) | ||||
| 
 | ||||
|         # see if we have any additional config for this test | ||||
|         method = getattr(self, methodName) | ||||
|  | @ -301,9 +299,10 @@ class HomeserverTestCase(TestCase): | |||
|                         None, | ||||
|                     ) | ||||
| 
 | ||||
|                 self.hs.get_auth().get_user_by_req = get_user_by_req | ||||
|                 self.hs.get_auth().get_user_by_access_token = get_user_by_access_token | ||||
|                 self.hs.get_auth().get_access_token_from_request = Mock( | ||||
|                 # Type ignore: mypy doesn't like us assigning to methods. | ||||
|                 self.hs.get_auth().get_user_by_req = get_user_by_req  # type: ignore[assignment] | ||||
|                 self.hs.get_auth().get_user_by_access_token = get_user_by_access_token  # type: ignore[assignment] | ||||
|                 self.hs.get_auth().get_access_token_from_request = Mock(  # type: ignore[assignment] | ||||
|                     return_value="1234" | ||||
|                 ) | ||||
| 
 | ||||
|  | @ -417,7 +416,7 @@ class HomeserverTestCase(TestCase): | |||
|         path: Union[bytes, str], | ||||
|         content: Union[bytes, str, JsonDict] = b"", | ||||
|         access_token: Optional[str] = None, | ||||
|         request: Type[T] = SynapseRequest, | ||||
|         request: Type[Request] = SynapseRequest, | ||||
|         shorthand: bool = True, | ||||
|         federation_auth_origin: Optional[bytes] = None, | ||||
|         content_is_form: bool = False, | ||||
|  | @ -596,7 +595,7 @@ class HomeserverTestCase(TestCase): | |||
|             nonce_str += b"\x00notadmin" | ||||
| 
 | ||||
|         want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) | ||||
|         want_mac = want_mac.hexdigest() | ||||
|         want_mac_digest = want_mac.hexdigest() | ||||
| 
 | ||||
|         body = json.dumps( | ||||
|             { | ||||
|  | @ -605,7 +604,7 @@ class HomeserverTestCase(TestCase): | |||
|                 "displayname": displayname, | ||||
|                 "password": password, | ||||
|                 "admin": admin, | ||||
|                 "mac": want_mac, | ||||
|                 "mac": want_mac_digest, | ||||
|                 "inhibit_login": True, | ||||
|             } | ||||
|         ) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 David Robertson
						David Robertson