Add type hints for account validity handler (#8620)
This also fixes a bug by fixing handling of an account which doesn't expire.pull/8664/head
							parent
							
								
									66e6801c3e
								
							
						
					
					
						commit
						10f45d85bb
					
				|  | @ -0,0 +1 @@ | |||
| Fix a bug where the account validity endpoint would silently fail if the user ID did not have an expiration time. It now returns a 400 error. | ||||
							
								
								
									
										1
									
								
								mypy.ini
								
								
								
								
							
							
						
						
									
										1
									
								
								mypy.ini
								
								
								
								
							|  | @ -17,6 +17,7 @@ files = | |||
|   synapse/federation, | ||||
|   synapse/handlers/_base.py, | ||||
|   synapse/handlers/account_data.py, | ||||
|   synapse/handlers/account_validity.py, | ||||
|   synapse/handlers/appservice.py, | ||||
|   synapse/handlers/auth.py, | ||||
|   synapse/handlers/cas_handler.py, | ||||
|  |  | |||
|  | @ -18,19 +18,22 @@ import email.utils | |||
| import logging | ||||
| from email.mime.multipart import MIMEMultipart | ||||
| from email.mime.text import MIMEText | ||||
| from typing import List | ||||
| from typing import TYPE_CHECKING, List | ||||
| 
 | ||||
| from synapse.api.errors import StoreError | ||||
| from synapse.api.errors import StoreError, SynapseError | ||||
| from synapse.logging.context import make_deferred_yieldable | ||||
| from synapse.metrics.background_process_metrics import wrap_as_background_process | ||||
| from synapse.types import UserID | ||||
| from synapse.util import stringutils | ||||
| 
 | ||||
| if TYPE_CHECKING: | ||||
|     from synapse.app.homeserver import HomeServer | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class AccountValidityHandler: | ||||
|     def __init__(self, hs): | ||||
|     def __init__(self, hs: "HomeServer"): | ||||
|         self.hs = hs | ||||
|         self.config = hs.config | ||||
|         self.store = self.hs.get_datastore() | ||||
|  | @ -67,7 +70,7 @@ class AccountValidityHandler: | |||
|                 self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) | ||||
| 
 | ||||
|     @wrap_as_background_process("send_renewals") | ||||
|     async def _send_renewal_emails(self): | ||||
|     async def _send_renewal_emails(self) -> None: | ||||
|         """Gets the list of users whose account is expiring in the amount of time | ||||
|         configured in the ``renew_at`` parameter from the ``account_validity`` | ||||
|         configuration, and sends renewal emails to all of these users as long as they | ||||
|  | @ -81,11 +84,25 @@ class AccountValidityHandler: | |||
|                     user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] | ||||
|                 ) | ||||
| 
 | ||||
|     async def send_renewal_email_to_user(self, user_id: str): | ||||
|     async def send_renewal_email_to_user(self, user_id: str) -> None: | ||||
|         """ | ||||
|         Send a renewal email for a specific user. | ||||
| 
 | ||||
|         Args: | ||||
|             user_id: The user ID to send a renewal email for. | ||||
| 
 | ||||
|         Raises: | ||||
|             SynapseError if the user is not set to renew. | ||||
|         """ | ||||
|         expiration_ts = await self.store.get_expiration_ts_for_user(user_id) | ||||
| 
 | ||||
|         # If this user isn't set to be expired, raise an error. | ||||
|         if expiration_ts is None: | ||||
|             raise SynapseError(400, "User has no expiration time: %s" % (user_id,)) | ||||
| 
 | ||||
|         await self._send_renewal_email(user_id, expiration_ts) | ||||
| 
 | ||||
|     async def _send_renewal_email(self, user_id: str, expiration_ts: int): | ||||
|     async def _send_renewal_email(self, user_id: str, expiration_ts: int) -> None: | ||||
|         """Sends out a renewal email to every email address attached to the given user | ||||
|         with a unique link allowing them to renew their account. | ||||
| 
 | ||||
|  |  | |||
|  | @ -131,7 +131,7 @@ class ProfileHandler(BaseHandler): | |||
|             profile = await self.store.get_from_remote_profile_cache(user_id) | ||||
|             return profile or {} | ||||
| 
 | ||||
|     async def get_displayname(self, target_user: UserID) -> str: | ||||
|     async def get_displayname(self, target_user: UserID) -> Optional[str]: | ||||
|         if self.hs.is_mine(target_user): | ||||
|             try: | ||||
|                 displayname = await self.store.get_profile_displayname( | ||||
|  | @ -218,7 +218,7 @@ class ProfileHandler(BaseHandler): | |||
| 
 | ||||
|         await self._update_join_states(requester, target_user) | ||||
| 
 | ||||
|     async def get_avatar_url(self, target_user: UserID) -> str: | ||||
|     async def get_avatar_url(self, target_user: UserID) -> Optional[str]: | ||||
|         if self.hs.is_mine(target_user): | ||||
|             try: | ||||
|                 avatar_url = await self.store.get_profile_avatar_url( | ||||
|  |  | |||
|  | @ -39,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore): | |||
|             avatar_url=profile["avatar_url"], display_name=profile["displayname"] | ||||
|         ) | ||||
| 
 | ||||
|     async def get_profile_displayname(self, user_localpart: str) -> str: | ||||
|     async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: | ||||
|         return await self.db_pool.simple_select_one_onecol( | ||||
|             table="profiles", | ||||
|             keyvalues={"user_id": user_localpart}, | ||||
|  | @ -47,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore): | |||
|             desc="get_profile_displayname", | ||||
|         ) | ||||
| 
 | ||||
|     async def get_profile_avatar_url(self, user_localpart: str) -> str: | ||||
|     async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: | ||||
|         return await self.db_pool.simple_select_one_onecol( | ||||
|             table="profiles", | ||||
|             keyvalues={"user_id": user_localpart}, | ||||
|  |  | |||
|  | @ -240,13 +240,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
|             desc="get_renewal_token_for_user", | ||||
|         ) | ||||
| 
 | ||||
|     async def get_users_expiring_soon(self) -> List[Dict[str, int]]: | ||||
|     async def get_users_expiring_soon(self) -> List[Dict[str, Any]]: | ||||
|         """Selects users whose account will expire in the [now, now + renew_at] time | ||||
|         window (see configuration for account_validity for information on what renew_at | ||||
|         refers to). | ||||
| 
 | ||||
|         Returns: | ||||
|             A list of dictionaries mapping user ID to expiration time (in milliseconds). | ||||
|             A list of dictionaries, each with a user ID and expiration time (in milliseconds). | ||||
|         """ | ||||
| 
 | ||||
|         def select_users_txn(txn, now_ms, renew_at): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Patrick Cloke
						Patrick Cloke