diff --git a/changelog.d/7951.misc b/changelog.d/7951.misc new file mode 100644 index 0000000000..cbba4fa826 --- /dev/null +++ b/changelog.d/7951.misc @@ -0,0 +1 @@ +Convert groups and visibility code to async / await. diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index dab13c243f..e674bf44a2 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -41,8 +41,6 @@ from typing import Tuple from signedjson.sign import sign_json -from twisted.internet import defer - from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id @@ -72,8 +70,9 @@ class GroupAttestationSigning(object): self.server_name = hs.hostname self.signing_key = hs.signing_key - @defer.inlineCallbacks - def verify_attestation(self, attestation, group_id, user_id, server_name=None): + async def verify_attestation( + self, attestation, group_id, user_id, server_name=None + ): """Verifies that the given attestation matches the given parameters. An optional server_name can be supplied to explicitly set which server's @@ -102,7 +101,7 @@ class GroupAttestationSigning(object): if valid_until_ms < now: raise SynapseError(400, "Attestation expired") - yield self.keyring.verify_json_for_server( + await self.keyring.verify_json_for_server( server_name, attestation, now, "Group attestation" ) @@ -142,8 +141,7 @@ class GroupAttestionRenewer(object): self._start_renew_attestations, 30 * 60 * 1000 ) - @defer.inlineCallbacks - def on_renew_attestation(self, group_id, user_id, content): + async def on_renew_attestation(self, group_id, user_id, content): """When a remote updates an attestation """ attestation = content["attestation"] @@ -151,11 +149,11 @@ class GroupAttestionRenewer(object): if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): raise SynapseError(400, "Neither user not group are on this server") - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( attestation, user_id=user_id, group_id=group_id ) - yield self.store.update_remote_attestion(group_id, user_id, attestation) + await self.store.update_remote_attestion(group_id, user_id, attestation) return {} @@ -172,8 +170,7 @@ class GroupAttestionRenewer(object): now + UPDATE_ATTESTATION_TIME_MS ) - @defer.inlineCallbacks - def _renew_attestation(group_user: Tuple[str, str]): + async def _renew_attestation(group_user: Tuple[str, str]): group_id, user_id = group_user try: if not self.is_mine_id(group_id): @@ -186,16 +183,16 @@ class GroupAttestionRenewer(object): user_id, group_id, ) - yield self.store.remove_attestation_renewal(group_id, user_id) + await self.store.remove_attestation_renewal(group_id, user_id) return attestation = self.attestations.create_attestation(group_id, user_id) - yield self.transport_client.renew_group_attestation( + await self.transport_client.renew_group_attestation( destination, group_id, user_id, content={"attestation": attestation} ) - yield self.store.update_attestation_renewal( + await self.store.update_attestation_renewal( group_id, user_id, attestation ) except (RequestSendFailed, HttpResponseException) as e: diff --git a/synapse/visibility.py b/synapse/visibility.py index 0f042c5696..e3da7744d2 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -16,8 +16,6 @@ import logging import operator -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event from synapse.storage import Storage @@ -39,8 +37,7 @@ MEMBERSHIP_PRIORITY = ( ) -@defer.inlineCallbacks -def filter_events_for_client( +async def filter_events_for_client( storage: Storage, user_id, events, @@ -67,19 +64,19 @@ def filter_events_for_client( also be called to check whether a user can see the state at a given point. Returns: - Deferred[list[synapse.events.EventBase]] + list[synapse.events.EventBase] """ # Filter out events that have been soft failed so that we don't relay them # to clients. events = [e for e in events if not e.internal_metadata.is_soft_failed()] types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) - event_id_to_state = yield storage.state.get_state_for_events( + event_id_to_state = await storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( + ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id ) @@ -90,7 +87,7 @@ def filter_events_for_client( else [] ) - erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased((e.sender for e in events)) if filter_send_to_client: room_ids = {e.room_id for e in events} @@ -99,7 +96,7 @@ def filter_events_for_client( for room_id in room_ids: retention_policies[ room_id - ] = yield storage.main.get_retention_policy_for_room(room_id) + ] = await storage.main.get_retention_policy_for_room(room_id) def allowed(event): """ @@ -254,8 +251,7 @@ def filter_events_for_client( return list(filtered_events) -@defer.inlineCallbacks -def filter_events_for_server( +async def filter_events_for_server( storage: Storage, server_name, events, @@ -277,7 +273,7 @@ def filter_events_for_server( backfill or not. Returns - Deferred[list[FrozenEvent]] + list[FrozenEvent] """ def is_sender_erased(event, erased_senders): @@ -321,7 +317,7 @@ def filter_events_for_server( # Lets check to see if all the events have a history visibility # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). - event_to_state_ids = yield storage.state.get_state_ids_for_events( + event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""),) @@ -339,14 +335,14 @@ def filter_events_for_server( if not visibility_ids: all_open = True else: - event_map = yield storage.main.get_events(visibility_ids) + event_map = await storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") for e in event_map.values() ) if not check_history_visibility_only: - erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -375,7 +371,7 @@ def filter_events_for_server( # first, for each event we're wanting to return, get the event_ids # of the history vis and membership state at those events. - event_to_state_ids = yield storage.state.get_state_ids_for_events( + event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) @@ -405,7 +401,7 @@ def filter_events_for_server( return False return state_key[idx + 1 :] == server_name - event_map = yield storage.main.get_events( + event_map = await storage.main.get_events( [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])] ) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index b371efc0df..a7a36174ea 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): evt = yield self.inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) # the result should be 5 redacted events, and 5 unredacted events. @@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") # ... and the filtering happens. - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) for i in range(0, len(events_to_filter)): @@ -265,8 +265,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): storage.main = test_store storage.state = test_store - filtered = yield filter_events_for_server( - test_store, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(test_store, "test_server", events_to_filter) ) logger.info("Filtering took %f seconds", time.time() - start)