Merge pull request #6196 from matrix-org/erikj/await
Move rest/admin to use async/await.pull/6217/head
						commit
						d98029ea89
					
				|  | @ -0,0 +1 @@ | |||
| Port synapse.rest.admin module to use async/await. | ||||
|  | @ -23,8 +23,6 @@ import re | |||
| from six import text_type | ||||
| from six.moves import http_client | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| import synapse | ||||
| from synapse.api.constants import Membership, UserTypes | ||||
| from synapse.api.errors import Codes, NotFoundError, SynapseError | ||||
|  | @ -46,6 +44,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet | |||
| from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet | ||||
| from synapse.rest.admin.users import UserAdminServlet | ||||
| from synapse.types import UserID, create_requester | ||||
| from synapse.util.async_helpers import maybe_awaitable | ||||
| from synapse.util.versionstring import get_version_string | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
|  | @ -59,15 +58,14 @@ class UsersRestServlet(RestServlet): | |||
|         self.auth = hs.get_auth() | ||||
|         self.handlers = hs.get_handlers() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, user_id): | ||||
|     async def on_GET(self, request, user_id): | ||||
|         target_user = UserID.from_string(user_id) | ||||
|         yield assert_requester_is_admin(self.auth, request) | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
| 
 | ||||
|         if not self.hs.is_mine(target_user): | ||||
|             raise SynapseError(400, "Can only users a local user") | ||||
| 
 | ||||
|         ret = yield self.handlers.admin_handler.get_users() | ||||
|         ret = await self.handlers.admin_handler.get_users() | ||||
| 
 | ||||
|         return 200, ret | ||||
| 
 | ||||
|  | @ -122,8 +120,7 @@ class UserRegisterServlet(RestServlet): | |||
|         self.nonces[nonce] = int(self.reactor.seconds()) | ||||
|         return 200, {"nonce": nonce} | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
|     async def on_POST(self, request): | ||||
|         self._clear_old_nonces() | ||||
| 
 | ||||
|         if not self.hs.config.registration_shared_secret: | ||||
|  | @ -204,14 +201,14 @@ class UserRegisterServlet(RestServlet): | |||
| 
 | ||||
|         register = RegisterRestServlet(self.hs) | ||||
| 
 | ||||
|         user_id = yield register.registration_handler.register_user( | ||||
|         user_id = await register.registration_handler.register_user( | ||||
|             localpart=body["username"].lower(), | ||||
|             password=body["password"], | ||||
|             admin=bool(admin), | ||||
|             user_type=user_type, | ||||
|         ) | ||||
| 
 | ||||
|         result = yield register._create_registration_details(user_id, body) | ||||
|         result = await register._create_registration_details(user_id, body) | ||||
|         return 200, result | ||||
| 
 | ||||
| 
 | ||||
|  | @ -223,19 +220,18 @@ class WhoisRestServlet(RestServlet): | |||
|         self.auth = hs.get_auth() | ||||
|         self.handlers = hs.get_handlers() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, user_id): | ||||
|     async def on_GET(self, request, user_id): | ||||
|         target_user = UserID.from_string(user_id) | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         auth_user = requester.user | ||||
| 
 | ||||
|         if target_user != auth_user: | ||||
|             yield assert_user_is_admin(self.auth, auth_user) | ||||
|             await assert_user_is_admin(self.auth, auth_user) | ||||
| 
 | ||||
|         if not self.hs.is_mine(target_user): | ||||
|             raise SynapseError(400, "Can only whois a local user") | ||||
| 
 | ||||
|         ret = yield self.handlers.admin_handler.get_whois(target_user) | ||||
|         ret = await self.handlers.admin_handler.get_whois(target_user) | ||||
| 
 | ||||
|         return 200, ret | ||||
| 
 | ||||
|  | @ -255,9 +251,8 @@ class PurgeHistoryRestServlet(RestServlet): | |||
|         self.store = hs.get_datastore() | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request, room_id, event_id): | ||||
|         yield assert_requester_is_admin(self.auth, request) | ||||
|     async def on_POST(self, request, room_id, event_id): | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
| 
 | ||||
|         body = parse_json_object_from_request(request, allow_empty_body=True) | ||||
| 
 | ||||
|  | @ -270,12 +265,12 @@ class PurgeHistoryRestServlet(RestServlet): | |||
|             event_id = body.get("purge_up_to_event_id") | ||||
| 
 | ||||
|         if event_id is not None: | ||||
|             event = yield self.store.get_event(event_id) | ||||
|             event = await self.store.get_event(event_id) | ||||
| 
 | ||||
|             if event.room_id != room_id: | ||||
|                 raise SynapseError(400, "Event is for wrong room.") | ||||
| 
 | ||||
|             token = yield self.store.get_topological_token_for_event(event_id) | ||||
|             token = await self.store.get_topological_token_for_event(event_id) | ||||
| 
 | ||||
|             logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) | ||||
|         elif "purge_up_to_ts" in body: | ||||
|  | @ -285,12 +280,10 @@ class PurgeHistoryRestServlet(RestServlet): | |||
|                     400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON | ||||
|                 ) | ||||
| 
 | ||||
|             stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts)) | ||||
|             stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) | ||||
| 
 | ||||
|             r = ( | ||||
|                 yield self.store.get_room_event_after_stream_ordering( | ||||
|                     room_id, stream_ordering | ||||
|                 ) | ||||
|             r = await self.store.get_room_event_after_stream_ordering( | ||||
|                 room_id, stream_ordering | ||||
|             ) | ||||
|             if not r: | ||||
|                 logger.warn( | ||||
|  | @ -318,7 +311,7 @@ class PurgeHistoryRestServlet(RestServlet): | |||
|                 errcode=Codes.BAD_JSON, | ||||
|             ) | ||||
| 
 | ||||
|         purge_id = yield self.pagination_handler.start_purge_history( | ||||
|         purge_id = self.pagination_handler.start_purge_history( | ||||
|             room_id, token, delete_local_events=delete_local_events | ||||
|         ) | ||||
| 
 | ||||
|  | @ -339,9 +332,8 @@ class PurgeHistoryStatusRestServlet(RestServlet): | |||
|         self.pagination_handler = hs.get_pagination_handler() | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, purge_id): | ||||
|         yield assert_requester_is_admin(self.auth, request) | ||||
|     async def on_GET(self, request, purge_id): | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
| 
 | ||||
|         purge_status = self.pagination_handler.get_purge_status(purge_id) | ||||
|         if purge_status is None: | ||||
|  | @ -357,9 +349,8 @@ class DeactivateAccountRestServlet(RestServlet): | |||
|         self._deactivate_account_handler = hs.get_deactivate_account_handler() | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request, target_user_id): | ||||
|         yield assert_requester_is_admin(self.auth, request) | ||||
|     async def on_POST(self, request, target_user_id): | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
|         body = parse_json_object_from_request(request, allow_empty_body=True) | ||||
|         erase = body.get("erase", False) | ||||
|         if not isinstance(erase, bool): | ||||
|  | @ -371,7 +362,7 @@ class DeactivateAccountRestServlet(RestServlet): | |||
| 
 | ||||
|         UserID.from_string(target_user_id) | ||||
| 
 | ||||
|         result = yield self._deactivate_account_handler.deactivate_account( | ||||
|         result = await self._deactivate_account_handler.deactivate_account( | ||||
|             target_user_id, erase | ||||
|         ) | ||||
|         if result: | ||||
|  | @ -405,10 +396,9 @@ class ShutdownRoomRestServlet(RestServlet): | |||
|         self.room_member_handler = hs.get_room_member_handler() | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request, room_id): | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         yield assert_user_is_admin(self.auth, requester.user) | ||||
|     async def on_POST(self, request, room_id): | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         await assert_user_is_admin(self.auth, requester.user) | ||||
| 
 | ||||
|         content = parse_json_object_from_request(request) | ||||
|         assert_params_in_dict(content, ["new_room_user_id"]) | ||||
|  | @ -419,7 +409,7 @@ class ShutdownRoomRestServlet(RestServlet): | |||
|         message = content.get("message", self.DEFAULT_MESSAGE) | ||||
|         room_name = content.get("room_name", "Content Violation Notification") | ||||
| 
 | ||||
|         info = yield self._room_creation_handler.create_room( | ||||
|         info = await self._room_creation_handler.create_room( | ||||
|             room_creator_requester, | ||||
|             config={ | ||||
|                 "preset": "public_chat", | ||||
|  | @ -438,9 +428,9 @@ class ShutdownRoomRestServlet(RestServlet): | |||
| 
 | ||||
|         # This will work even if the room is already blocked, but that is | ||||
|         # desirable in case the first attempt at blocking the room failed below. | ||||
|         yield self.store.block_room(room_id, requester_user_id) | ||||
|         await self.store.block_room(room_id, requester_user_id) | ||||
| 
 | ||||
|         users = yield self.state.get_current_users_in_room(room_id) | ||||
|         users = await self.state.get_current_users_in_room(room_id) | ||||
|         kicked_users = [] | ||||
|         failed_to_kick_users = [] | ||||
|         for user_id in users: | ||||
|  | @ -451,7 +441,7 @@ class ShutdownRoomRestServlet(RestServlet): | |||
| 
 | ||||
|             try: | ||||
|                 target_requester = create_requester(user_id) | ||||
|                 yield self.room_member_handler.update_membership( | ||||
|                 await self.room_member_handler.update_membership( | ||||
|                     requester=target_requester, | ||||
|                     target=target_requester.user, | ||||
|                     room_id=room_id, | ||||
|  | @ -461,9 +451,9 @@ class ShutdownRoomRestServlet(RestServlet): | |||
|                     require_consent=False, | ||||
|                 ) | ||||
| 
 | ||||
|                 yield self.room_member_handler.forget(target_requester.user, room_id) | ||||
|                 await self.room_member_handler.forget(target_requester.user, room_id) | ||||
| 
 | ||||
|                 yield self.room_member_handler.update_membership( | ||||
|                 await self.room_member_handler.update_membership( | ||||
|                     requester=target_requester, | ||||
|                     target=target_requester.user, | ||||
|                     room_id=new_room_id, | ||||
|  | @ -480,7 +470,7 @@ class ShutdownRoomRestServlet(RestServlet): | |||
|                 ) | ||||
|                 failed_to_kick_users.append(user_id) | ||||
| 
 | ||||
|         yield self.event_creation_handler.create_and_send_nonmember_event( | ||||
|         await self.event_creation_handler.create_and_send_nonmember_event( | ||||
|             room_creator_requester, | ||||
|             { | ||||
|                 "type": "m.room.message", | ||||
|  | @ -491,9 +481,11 @@ class ShutdownRoomRestServlet(RestServlet): | |||
|             ratelimit=False, | ||||
|         ) | ||||
| 
 | ||||
|         aliases_for_room = yield self.store.get_aliases_for_room(room_id) | ||||
|         aliases_for_room = await maybe_awaitable( | ||||
|             self.store.get_aliases_for_room(room_id) | ||||
|         ) | ||||
| 
 | ||||
|         yield self.store.update_aliases_for_room( | ||||
|         await self.store.update_aliases_for_room( | ||||
|             room_id, new_room_id, requester_user_id | ||||
|         ) | ||||
| 
 | ||||
|  | @ -532,13 +524,12 @@ class ResetPasswordRestServlet(RestServlet): | |||
|         self.auth = hs.get_auth() | ||||
|         self._set_password_handler = hs.get_set_password_handler() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request, target_user_id): | ||||
|     async def on_POST(self, request, target_user_id): | ||||
|         """Post request to allow an administrator reset password for a user. | ||||
|         This needs user to have administrator access in Synapse. | ||||
|         """ | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         yield assert_user_is_admin(self.auth, requester.user) | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         await assert_user_is_admin(self.auth, requester.user) | ||||
| 
 | ||||
|         UserID.from_string(target_user_id) | ||||
| 
 | ||||
|  | @ -546,7 +537,7 @@ class ResetPasswordRestServlet(RestServlet): | |||
|         assert_params_in_dict(params, ["new_password"]) | ||||
|         new_password = params["new_password"] | ||||
| 
 | ||||
|         yield self._set_password_handler.set_password( | ||||
|         await self._set_password_handler.set_password( | ||||
|             target_user_id, new_password, requester | ||||
|         ) | ||||
|         return 200, {} | ||||
|  | @ -572,12 +563,11 @@ class GetUsersPaginatedRestServlet(RestServlet): | |||
|         self.auth = hs.get_auth() | ||||
|         self.handlers = hs.get_handlers() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, target_user_id): | ||||
|     async def on_GET(self, request, target_user_id): | ||||
|         """Get request to get specific number of users from Synapse. | ||||
|         This needs user to have administrator access in Synapse. | ||||
|         """ | ||||
|         yield assert_requester_is_admin(self.auth, request) | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
| 
 | ||||
|         target_user = UserID.from_string(target_user_id) | ||||
| 
 | ||||
|  | @ -590,11 +580,10 @@ class GetUsersPaginatedRestServlet(RestServlet): | |||
| 
 | ||||
|         logger.info("limit: %s, start: %s", limit, start) | ||||
| 
 | ||||
|         ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) | ||||
|         ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit) | ||||
|         return 200, ret | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request, target_user_id): | ||||
|     async def on_POST(self, request, target_user_id): | ||||
|         """Post request to get specific number of users from Synapse.. | ||||
|         This needs user to have administrator access in Synapse. | ||||
|         Example: | ||||
|  | @ -608,7 +597,7 @@ class GetUsersPaginatedRestServlet(RestServlet): | |||
|         Returns: | ||||
|             200 OK with json object {list[dict[str, Any]], count} or empty object. | ||||
|         """ | ||||
|         yield assert_requester_is_admin(self.auth, request) | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
|         UserID.from_string(target_user_id) | ||||
| 
 | ||||
|         order = "name"  # order by name in user table | ||||
|  | @ -618,7 +607,7 @@ class GetUsersPaginatedRestServlet(RestServlet): | |||
|         start = params["start"] | ||||
|         logger.info("limit: %s, start: %s", limit, start) | ||||
| 
 | ||||
|         ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) | ||||
|         ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit) | ||||
|         return 200, ret | ||||
| 
 | ||||
| 
 | ||||
|  | @ -641,13 +630,12 @@ class SearchUsersRestServlet(RestServlet): | |||
|         self.auth = hs.get_auth() | ||||
|         self.handlers = hs.get_handlers() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, target_user_id): | ||||
|     async def on_GET(self, request, target_user_id): | ||||
|         """Get request to search user table for specific users according to | ||||
|         search term. | ||||
|         This needs user to have a administrator access in Synapse. | ||||
|         """ | ||||
|         yield assert_requester_is_admin(self.auth, request) | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
| 
 | ||||
|         target_user = UserID.from_string(target_user_id) | ||||
| 
 | ||||
|  | @ -661,7 +649,7 @@ class SearchUsersRestServlet(RestServlet): | |||
|         term = parse_string(request, "term", required=True) | ||||
|         logger.info("term: %s ", term) | ||||
| 
 | ||||
|         ret = yield self.handlers.admin_handler.search_users(term) | ||||
|         ret = await self.handlers.admin_handler.search_users(term) | ||||
|         return 200, ret | ||||
| 
 | ||||
| 
 | ||||
|  | @ -676,15 +664,14 @@ class DeleteGroupAdminRestServlet(RestServlet): | |||
|         self.is_mine_id = hs.is_mine_id | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request, group_id): | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         yield assert_user_is_admin(self.auth, requester.user) | ||||
|     async def on_POST(self, request, group_id): | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         await assert_user_is_admin(self.auth, requester.user) | ||||
| 
 | ||||
|         if not self.is_mine_id(group_id): | ||||
|             raise SynapseError(400, "Can only delete local groups") | ||||
| 
 | ||||
|         yield self.group_server.delete_group(group_id, requester.user.to_string()) | ||||
|         await self.group_server.delete_group(group_id, requester.user.to_string()) | ||||
|         return 200, {} | ||||
| 
 | ||||
| 
 | ||||
|  | @ -700,16 +687,15 @@ class AccountValidityRenewServlet(RestServlet): | |||
|         self.account_activity_handler = hs.get_account_validity_handler() | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
|         yield assert_requester_is_admin(self.auth, request) | ||||
|     async def on_POST(self, request): | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
| 
 | ||||
|         body = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         if "user_id" not in body: | ||||
|             raise SynapseError(400, "Missing property 'user_id' in the request body") | ||||
| 
 | ||||
|         expiration_ts = yield self.account_activity_handler.renew_account_for_user( | ||||
|         expiration_ts = await self.account_activity_handler.renew_account_for_user( | ||||
|             body["user_id"], | ||||
|             body.get("expiration_ts"), | ||||
|             not body.get("enable_renewal_emails", True), | ||||
|  |  | |||
|  | @ -15,8 +15,6 @@ | |||
| 
 | ||||
| import re | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.errors import AuthError | ||||
| 
 | ||||
| 
 | ||||
|  | @ -42,8 +40,7 @@ def historical_admin_path_patterns(path_regex): | |||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| @defer.inlineCallbacks | ||||
| def assert_requester_is_admin(auth, request): | ||||
| async def assert_requester_is_admin(auth, request): | ||||
|     """Verify that the requester is an admin user | ||||
| 
 | ||||
|     WARNING: MAKE SURE YOU YIELD ON THE RESULT! | ||||
|  | @ -58,12 +55,11 @@ def assert_requester_is_admin(auth, request): | |||
|     Raises: | ||||
|         AuthError if the requester is not an admin | ||||
|     """ | ||||
|     requester = yield auth.get_user_by_req(request) | ||||
|     yield assert_user_is_admin(auth, requester.user) | ||||
|     requester = await auth.get_user_by_req(request) | ||||
|     await assert_user_is_admin(auth, requester.user) | ||||
| 
 | ||||
| 
 | ||||
| @defer.inlineCallbacks | ||||
| def assert_user_is_admin(auth, user_id): | ||||
| async def assert_user_is_admin(auth, user_id): | ||||
|     """Verify that the given user is an admin user | ||||
| 
 | ||||
|     WARNING: MAKE SURE YOU YIELD ON THE RESULT! | ||||
|  | @ -79,6 +75,6 @@ def assert_user_is_admin(auth, user_id): | |||
|         AuthError if the user is not an admin | ||||
|     """ | ||||
| 
 | ||||
|     is_admin = yield auth.is_server_admin(user_id) | ||||
|     is_admin = await auth.is_server_admin(user_id) | ||||
|     if not is_admin: | ||||
|         raise AuthError(403, "You are not a server admin") | ||||
|  |  | |||
|  | @ -16,8 +16,6 @@ | |||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.errors import AuthError | ||||
| from synapse.http.servlet import RestServlet, parse_integer | ||||
| from synapse.rest.admin._base import ( | ||||
|  | @ -40,12 +38,11 @@ class QuarantineMediaInRoom(RestServlet): | |||
|         self.store = hs.get_datastore() | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request, room_id): | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         yield assert_user_is_admin(self.auth, requester.user) | ||||
|     async def on_POST(self, request, room_id): | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         await assert_user_is_admin(self.auth, requester.user) | ||||
| 
 | ||||
|         num_quarantined = yield self.store.quarantine_media_ids_in_room( | ||||
|         num_quarantined = await self.store.quarantine_media_ids_in_room( | ||||
|             room_id, requester.user.to_string() | ||||
|         ) | ||||
| 
 | ||||
|  | @ -62,14 +59,13 @@ class ListMediaInRoom(RestServlet): | |||
|         self.store = hs.get_datastore() | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, room_id): | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         is_admin = yield self.auth.is_server_admin(requester.user) | ||||
|     async def on_GET(self, request, room_id): | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         is_admin = await self.auth.is_server_admin(requester.user) | ||||
|         if not is_admin: | ||||
|             raise AuthError(403, "You are not a server admin") | ||||
| 
 | ||||
|         local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id) | ||||
|         local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) | ||||
| 
 | ||||
|         return 200, {"local": local_mxcs, "remote": remote_mxcs} | ||||
| 
 | ||||
|  | @ -81,14 +77,13 @@ class PurgeMediaCacheRestServlet(RestServlet): | |||
|         self.media_repository = hs.get_media_repository() | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
|         yield assert_requester_is_admin(self.auth, request) | ||||
|     async def on_POST(self, request): | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
| 
 | ||||
|         before_ts = parse_integer(request, "before_ts", required=True) | ||||
|         logger.info("before_ts: %r", before_ts) | ||||
| 
 | ||||
|         ret = yield self.media_repository.delete_old_remote_media(before_ts) | ||||
|         ret = await self.media_repository.delete_old_remote_media(before_ts) | ||||
| 
 | ||||
|         return 200, ret | ||||
| 
 | ||||
|  |  | |||
|  | @ -14,8 +14,6 @@ | |||
| # limitations under the License. | ||||
| import re | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.constants import EventTypes | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.http.servlet import ( | ||||
|  | @ -69,9 +67,8 @@ class SendServerNoticeServlet(RestServlet): | |||
|             self.__class__.__name__, | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request, txn_id=None): | ||||
|         yield assert_requester_is_admin(self.auth, request) | ||||
|     async def on_POST(self, request, txn_id=None): | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
|         body = parse_json_object_from_request(request) | ||||
|         assert_params_in_dict(body, ("user_id", "content")) | ||||
|         event_type = body.get("type", EventTypes.Message) | ||||
|  | @ -85,7 +82,7 @@ class SendServerNoticeServlet(RestServlet): | |||
|         if not self.hs.is_mine_id(user_id): | ||||
|             raise SynapseError(400, "Server notices can only be sent to local users") | ||||
| 
 | ||||
|         event = yield self.snm.send_notice( | ||||
|         event = await self.snm.send_notice( | ||||
|             user_id=body["user_id"], | ||||
|             type=event_type, | ||||
|             state_key=state_key, | ||||
|  |  | |||
|  | @ -14,8 +14,6 @@ | |||
| # limitations under the License. | ||||
| import re | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.http.servlet import ( | ||||
|     RestServlet, | ||||
|  | @ -59,24 +57,22 @@ class UserAdminServlet(RestServlet): | |||
|         self.auth = hs.get_auth() | ||||
|         self.handlers = hs.get_handlers() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, user_id): | ||||
|         yield assert_requester_is_admin(self.auth, request) | ||||
|     async def on_GET(self, request, user_id): | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
| 
 | ||||
|         target_user = UserID.from_string(user_id) | ||||
| 
 | ||||
|         if not self.hs.is_mine(target_user): | ||||
|             raise SynapseError(400, "Only local users can be admins of this homeserver") | ||||
| 
 | ||||
|         is_admin = yield self.handlers.admin_handler.get_user_server_admin(target_user) | ||||
|         is_admin = await self.handlers.admin_handler.get_user_server_admin(target_user) | ||||
|         is_admin = bool(is_admin) | ||||
| 
 | ||||
|         return 200, {"admin": is_admin} | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_PUT(self, request, user_id): | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         yield assert_user_is_admin(self.auth, requester.user) | ||||
|     async def on_PUT(self, request, user_id): | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         await assert_user_is_admin(self.auth, requester.user) | ||||
|         auth_user = requester.user | ||||
| 
 | ||||
|         target_user = UserID.from_string(user_id) | ||||
|  | @ -93,7 +89,7 @@ class UserAdminServlet(RestServlet): | |||
|         if target_user == auth_user and not set_admin_to: | ||||
|             raise SynapseError(400, "You may not demote yourself.") | ||||
| 
 | ||||
|         yield self.handlers.admin_handler.set_user_server_admin( | ||||
|         await self.handlers.admin_handler.set_user_server_admin( | ||||
|             target_user, set_admin_to | ||||
|         ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -21,6 +21,8 @@ from typing import Dict, Sequence, Set, Union | |||
| 
 | ||||
| from six.moves import range | ||||
| 
 | ||||
| import attr | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| from twisted.internet.defer import CancelledError | ||||
| from twisted.python import failure | ||||
|  | @ -483,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): | |||
|     deferred.addCallbacks(success_cb, failure_cb) | ||||
| 
 | ||||
|     return new_d | ||||
| 
 | ||||
| 
 | ||||
| @attr.s(slots=True, frozen=True) | ||||
| class DoneAwaitable(object): | ||||
|     """Simple awaitable that returns the provided value. | ||||
|     """ | ||||
| 
 | ||||
|     value = attr.ib() | ||||
| 
 | ||||
|     def __await__(self): | ||||
|         return self | ||||
| 
 | ||||
|     def __iter__(self): | ||||
|         return self | ||||
| 
 | ||||
|     def __next__(self): | ||||
|         raise StopIteration(self.value) | ||||
| 
 | ||||
| 
 | ||||
| def maybe_awaitable(value): | ||||
|     """Convert a value to an awaitable if not already an awaitable. | ||||
|     """ | ||||
| 
 | ||||
|     if hasattr(value, "__await__"): | ||||
|         return value | ||||
| 
 | ||||
|     return DoneAwaitable(value) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Erik Johnston
						Erik Johnston