Optionally include account validity in MSC3720 account status responses (#12266)
							parent
							
								
									e78d4f61fc
								
							
						
					
					
						commit
						5436b014f4
					
				|  | @ -0,0 +1 @@ | |||
| Optionally include account validity expiration information to experimental [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) account status responses. | ||||
|  | @ -676,6 +676,10 @@ class ServerConfig(Config): | |||
|         ): | ||||
|             raise ConfigError("'custom_template_directory' must be a string") | ||||
| 
 | ||||
|         self.use_account_validity_in_account_status: bool = ( | ||||
|             config.get("use_account_validity_in_account_status") or False | ||||
|         ) | ||||
| 
 | ||||
|     def has_tls_listener(self) -> bool: | ||||
|         return any(listener.tls for listener in self.listeners) | ||||
| 
 | ||||
|  |  | |||
|  | @ -26,6 +26,10 @@ class AccountHandler: | |||
|         self._main_store = hs.get_datastores().main | ||||
|         self._is_mine = hs.is_mine | ||||
|         self._federation_client = hs.get_federation_client() | ||||
|         self._use_account_validity_in_account_status = ( | ||||
|             hs.config.server.use_account_validity_in_account_status | ||||
|         ) | ||||
|         self._account_validity_handler = hs.get_account_validity_handler() | ||||
| 
 | ||||
|     async def get_account_statuses( | ||||
|         self, | ||||
|  | @ -106,6 +110,13 @@ class AccountHandler: | |||
|                 "deactivated": userinfo.is_deactivated, | ||||
|             } | ||||
| 
 | ||||
|             if self._use_account_validity_in_account_status: | ||||
|                 status[ | ||||
|                     "org.matrix.expired" | ||||
|                 ] = await self._account_validity_handler.is_user_expired( | ||||
|                     user_id.to_string() | ||||
|                 ) | ||||
| 
 | ||||
|         return status | ||||
| 
 | ||||
|     async def _get_remote_account_statuses( | ||||
|  |  | |||
|  | @ -31,7 +31,7 @@ from synapse.rest import admin | |||
| from synapse.rest.client import account, login, register, room | ||||
| from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource | ||||
| from synapse.server import HomeServer | ||||
| from synapse.types import JsonDict | ||||
| from synapse.types import JsonDict, UserID | ||||
| from synapse.util import Clock | ||||
| 
 | ||||
| from tests import unittest | ||||
|  | @ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): | |||
|             expected_failures=[users[2]], | ||||
|         ) | ||||
| 
 | ||||
|     @unittest.override_config( | ||||
|         { | ||||
|             "use_account_validity_in_account_status": True, | ||||
|         } | ||||
|     ) | ||||
|     def test_no_account_validity(self) -> None: | ||||
|         """Tests that if we decide to include account validity in the response but no | ||||
|         account validity 'is_user_expired' callback is provided, we default to marking all | ||||
|         users as not expired. | ||||
|         """ | ||||
|         user = self.register_user("someuser", "password") | ||||
| 
 | ||||
|         self._test_status( | ||||
|             users=[user], | ||||
|             expected_statuses={ | ||||
|                 user: { | ||||
|                     "exists": True, | ||||
|                     "deactivated": False, | ||||
|                     "org.matrix.expired": False, | ||||
|                 }, | ||||
|             }, | ||||
|             expected_failures=[], | ||||
|         ) | ||||
| 
 | ||||
|     @unittest.override_config( | ||||
|         { | ||||
|             "use_account_validity_in_account_status": True, | ||||
|         } | ||||
|     ) | ||||
|     def test_account_validity_expired(self) -> None: | ||||
|         """Test that if we decide to include account validity in the response and the user | ||||
|         is expired, we return the correct info. | ||||
|         """ | ||||
|         user = self.register_user("someuser", "password") | ||||
| 
 | ||||
|         async def is_expired(user_id: str) -> bool: | ||||
|             # We can't blindly say everyone is expired, otherwise the request to get the | ||||
|             # account status will fail. | ||||
|             return UserID.from_string(user_id).localpart == "someuser" | ||||
| 
 | ||||
|         self.hs.get_account_validity_handler()._is_user_expired_callbacks.append( | ||||
|             is_expired | ||||
|         ) | ||||
| 
 | ||||
|         self._test_status( | ||||
|             users=[user], | ||||
|             expected_statuses={ | ||||
|                 user: { | ||||
|                     "exists": True, | ||||
|                     "deactivated": False, | ||||
|                     "org.matrix.expired": True, | ||||
|                 }, | ||||
|             }, | ||||
|             expected_failures=[], | ||||
|         ) | ||||
| 
 | ||||
|     def _test_status( | ||||
|         self, | ||||
|         users: Optional[List[str]], | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Brendan Abolivier
						Brendan Abolivier