Cancel the processing of key query requests when they time out. (#13680)

pull/13738/head
reivilibre 2022-09-07 11:03:32 +00:00 committed by GitHub
parent c2fe48a6ff
commit d3d9ca156e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 110 additions and 20 deletions

View File

@ -0,0 +1 @@
Cancel the processing of key query requests when they time out.

View File

@ -38,6 +38,7 @@ from synapse.logging.opentracing import (
trace, trace,
) )
from synapse.types import Requester, create_requester from synapse.types import Requester, create_requester
from synapse.util.cancellation import cancellable
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -118,6 +119,7 @@ class Auth:
errcode=Codes.NOT_JOINED, errcode=Codes.NOT_JOINED,
) )
@cancellable
async def get_user_by_req( async def get_user_by_req(
self, self,
request: SynapseRequest, request: SynapseRequest,
@ -166,6 +168,7 @@ class Auth:
parent_span.set_tag("appservice_id", requester.app_service.id) parent_span.set_tag("appservice_id", requester.app_service.id)
return requester return requester
@cancellable
async def _wrapped_get_user_by_req( async def _wrapped_get_user_by_req(
self, self,
request: SynapseRequest, request: SynapseRequest,
@ -281,6 +284,7 @@ class Auth:
403, "Application service has not registered this user (%s)" % user_id 403, "Application service has not registered this user (%s)" % user_id
) )
@cancellable
async def _get_appservice_user(self, request: Request) -> Optional[Requester]: async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
""" """
Given a request, reads the request parameters to determine: Given a request, reads the request parameters to determine:
@ -523,6 +527,7 @@ class Auth:
return bool(query_params) or bool(auth_headers) return bool(query_params) or bool(auth_headers)
@staticmethod @staticmethod
@cancellable
def get_access_token_from_request(request: Request) -> str: def get_access_token_from_request(request: Request) -> str:
"""Extracts the access_token from the request. """Extracts the access_token from the request.

View File

@ -52,6 +52,7 @@ from synapse.types import (
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.cancellation import cancellable
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@ -124,6 +125,7 @@ class DeviceWorkerHandler:
return device return device
@cancellable
async def get_device_changes_in_shared_rooms( async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken self, user_id: str, room_ids: Collection[str], from_token: StreamToken
) -> Collection[str]: ) -> Collection[str]:
@ -163,6 +165,7 @@ class DeviceWorkerHandler:
@trace @trace
@measure_func("device.get_user_ids_changed") @measure_func("device.get_user_ids_changed")
@cancellable
async def get_user_ids_changed( async def get_user_ids_changed(
self, user_id: str, from_token: StreamToken self, user_id: str, from_token: StreamToken
) -> JsonDict: ) -> JsonDict:

View File

@ -37,7 +37,8 @@ from synapse.types import (
get_verify_key_from_cross_signing_key, get_verify_key_from_cross_signing_key,
) )
from synapse.util import json_decoder, unwrapFirstError from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer, delay_cancellation
from synapse.util.cancellation import cancellable
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING: if TYPE_CHECKING:
@ -91,6 +92,7 @@ class E2eKeysHandler:
) )
@trace @trace
@cancellable
async def query_devices( async def query_devices(
self, self,
query_body: JsonDict, query_body: JsonDict,
@ -208,22 +210,26 @@ class E2eKeysHandler:
r[user_id] = remote_queries[user_id] r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache # Now fetch any devices that we don't have in our cache
# TODO It might make sense to propagate cancellations into the
# deferreds which are querying remote homeservers.
await make_deferred_yieldable( await make_deferred_yieldable(
defer.gatherResults( delay_cancellation(
[ defer.gatherResults(
run_in_background( [
self._query_devices_for_destination, run_in_background(
results, self._query_devices_for_destination,
cross_signing_keys, results,
failures, cross_signing_keys,
destination, failures,
queries, destination,
timeout, queries,
) timeout,
for destination, queries in remote_queries_not_in_cache.items() )
], for destination, queries in remote_queries_not_in_cache.items()
consumeErrors=True, ],
).addErrback(unwrapFirstError) consumeErrors=True,
).addErrback(unwrapFirstError)
)
) )
ret = {"device_keys": results, "failures": failures} ret = {"device_keys": results, "failures": failures}
@ -347,6 +353,7 @@ class E2eKeysHandler:
return return
@cancellable
async def get_cross_signing_keys_from_cache( async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str] self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]: ) -> Dict[str, Dict[str, dict]]:
@ -393,6 +400,7 @@ class E2eKeysHandler:
} }
@trace @trace
@cancellable
async def query_local_devices( async def query_local_devices(
self, query: Mapping[str, Optional[List[str]]] self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]: ) -> Dict[str, Dict[str, dict]]:

View File

@ -27,9 +27,9 @@ from synapse.http.servlet import (
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag from synapse.logging.opentracing import log_kv, set_tag
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict, StreamToken from synapse.types import JsonDict, StreamToken
from synapse.util.cancellation import cancellable
from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -156,6 +156,7 @@ class KeyQueryServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@cancellable
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -199,6 +200,7 @@ class KeyChangesServlet(RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
@cancellable
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)

View File

@ -36,6 +36,7 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialStateEventsTracker, PartialStateEventsTracker,
) )
from synapse.types import MutableStateMap, StateMap from synapse.types import MutableStateMap, StateMap
from synapse.util.cancellation import cancellable
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -229,6 +230,7 @@ class StateStorageController:
@trace @trace
@tag_args @tag_args
@cancellable
async def get_state_ids_for_events( async def get_state_ids_for_events(
self, self,
event_ids: Collection[str], event_ids: Collection[str],
@ -350,6 +352,7 @@ class StateStorageController:
@trace @trace
@tag_args @tag_args
@cancellable
async def get_state_group_for_events( async def get_state_group_for_events(
self, self,
event_ids: Collection[str], event_ids: Collection[str],
@ -398,6 +401,7 @@ class StateStorageController:
event_id, room_id, prev_group, delta_ids, current_state_ids event_id, room_id, prev_group, delta_ids, current_state_ids
) )
@cancellable
async def get_current_state_ids( async def get_current_state_ids(
self, self,
room_id: str, room_id: str,

View File

@ -53,6 +53,7 @@ from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr from synapse.util.stringutils import shortstr
@ -668,6 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
... ...
@trace @trace
@cancellable
async def get_user_devices_from_cache( async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, Optional[str]]] self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
@ -743,6 +745,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key) return self._device_list_stream_cache.get_all_entities_changed(from_key)
@cancellable
async def get_users_whose_devices_changed( async def get_users_whose_devices_changed(
self, self,
from_key: int, from_key: int,
@ -1221,6 +1224,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
desc="get_min_device_lists_changes_in_room", desc="get_min_device_lists_changes_in_room",
) )
@cancellable
async def get_device_list_changes_in_rooms( async def get_device_list_changes_in_rooms(
self, room_ids: Collection[str], from_id: int self, room_ids: Collection[str], from_id: int
) -> Optional[Set[str]]: ) -> Optional[Set[str]]:

View File

@ -50,6 +50,7 @@ from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
if TYPE_CHECKING: if TYPE_CHECKING:
@ -135,6 +136,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return now_stream_id, [] return now_stream_id, []
@trace @trace
@cancellable
async def get_e2e_device_keys_for_cs_api( async def get_e2e_device_keys_for_cs_api(
self, query_list: List[Tuple[str, Optional[str]]] self, query_list: List[Tuple[str, Optional[str]]]
) -> Dict[str, Dict[str, JsonDict]]: ) -> Dict[str, Dict[str, JsonDict]]:
@ -197,6 +199,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
... ...
@trace @trace
@cancellable
async def get_e2e_device_keys_and_signatures( async def get_e2e_device_keys_and_signatures(
self, self,
query_list: Collection[Tuple[str, Optional[str]]], query_list: Collection[Tuple[str, Optional[str]]],
@ -887,6 +890,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return keys return keys
@cancellable
async def get_e2e_cross_signing_keys_bulk( async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None self, user_ids: List[str], from_user_id: Optional[str] = None
) -> Dict[str, Optional[Dict[str, JsonDict]]]: ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
@ -902,7 +906,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
keys were not found, either their user ID will not be in the dict, keys were not found, either their user ID will not be in the dict,
or their user ID will map to None. or their user ID will map to None.
""" """
result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id: if from_user_id:

View File

@ -48,6 +48,7 @@ from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
if TYPE_CHECKING: if TYPE_CHECKING:
@ -976,6 +977,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return int(min_depth) if min_depth is not None else None return int(min_depth) if min_depth is not None else None
@cancellable
async def get_forward_extremities_for_room_at_stream_ordering( async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int self, room_id: str, stream_ordering: int
) -> List[str]: ) -> List[str]:

View File

@ -81,6 +81,7 @@ from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import AsyncLruCache from synapse.util.caches.lrucache import AsyncLruCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -339,6 +340,7 @@ class EventsWorkerStore(SQLBaseStore):
) -> Optional[EventBase]: ) -> Optional[EventBase]:
... ...
@cancellable
async def get_event( async def get_event(
self, self,
event_id: str, event_id: str,
@ -433,6 +435,7 @@ class EventsWorkerStore(SQLBaseStore):
@trace @trace
@tag_args @tag_args
@cancellable
async def get_events_as_list( async def get_events_as_list(
self, self,
event_ids: Collection[str], event_ids: Collection[str],
@ -584,6 +587,7 @@ class EventsWorkerStore(SQLBaseStore):
return events return events
@cancellable
async def _get_events_from_cache_or_db( async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Dict[str, EventCacheEntry]: ) -> Dict[str, EventCacheEntry]:

View File

@ -55,6 +55,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -770,6 +771,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
_get_users_server_still_shares_room_with_txn, _get_users_server_still_shares_room_with_txn,
) )
@cancellable
async def get_rooms_for_user( async def get_rooms_for_user(
self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
) -> FrozenSet[str]: ) -> FrozenSet[str]:

View File

@ -36,6 +36,7 @@ from synapse.storage.state import StateFilter
from synapse.types import JsonDict, JsonMapping, StateMap from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
if TYPE_CHECKING: if TYPE_CHECKING:
@ -281,6 +282,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
) )
# FIXME: how should this be cached? # FIXME: how should this be cached?
@cancellable
async def get_partial_filtered_current_state_ids( async def get_partial_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]: ) -> StateMap[str]:

View File

@ -72,6 +72,7 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import PersistedEventPosition, RoomStreamToken from synapse.types import PersistedEventPosition, RoomStreamToken
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -597,6 +598,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key return ret, key
@cancellable
async def get_membership_changes_for_user( async def get_membership_changes_for_user(
self, self,
user_id: str, user_id: str,

View File

@ -31,6 +31,7 @@ from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.cancellation import cancellable
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -156,6 +157,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"get_state_group_delta", _get_state_group_delta_txn "get_state_group_delta", _get_state_group_delta_txn
) )
@cancellable
async def _get_state_groups_from_groups( async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter self, groups: List[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]: ) -> Dict[int, StateMap[str]]:
@ -235,6 +237,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types return state_filter.filter_state(state_dict_ids), not missing_types
@cancellable
async def _get_state_for_groups( async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]: ) -> Dict[int, MutableStateMap[str]]:

View File

@ -24,6 +24,7 @@ from synapse.logging.opentracing import trace_with_opname
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.room import RoomWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.cancellation import cancellable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -60,6 +61,7 @@ class PartialStateEventsTracker:
o.callback(None) o.callback(None)
@trace_with_opname("PartialStateEventsTracker.await_full_state") @trace_with_opname("PartialStateEventsTracker.await_full_state")
@cancellable
async def await_full_state(self, event_ids: Collection[str]) -> None: async def await_full_state(self, event_ids: Collection[str]) -> None:
"""Wait for all the given events to have full state. """Wait for all the given events to have full state.
@ -154,6 +156,7 @@ class PartialCurrentStateTracker:
o.callback(None) o.callback(None)
@trace_with_opname("PartialCurrentStateTracker.await_full_state") @trace_with_opname("PartialCurrentStateTracker.await_full_state")
@cancellable
async def await_full_state(self, room_id: str) -> None: async def await_full_state(self, room_id: str) -> None:
# We add the deferred immediately so that the DB call to check for # We add the deferred immediately so that the DB call to check for
# partial state doesn't race when we unpartial the room. # partial state doesn't race when we unpartial the room.

View File

@ -52,6 +52,7 @@ from twisted.internet.interfaces import (
) )
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.util.cancellation import cancellable
from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING: if TYPE_CHECKING:
@ -699,7 +700,11 @@ class StreamToken:
START: ClassVar["StreamToken"] START: ClassVar["StreamToken"]
@classmethod @classmethod
@cancellable
async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
"""
Creates a RoomStreamToken from its textual representation.
"""
try: try:
keys = string.split(cls._SEPARATOR) keys = string.split(cls._SEPARATOR)
while len(keys) < len(attr.fields(cls)): while len(keys) < len(attr.fields(cls)):

View File

@ -140,6 +140,8 @@ def make_request_with_cancellation_test(
method: str, method: str,
path: str, path: str,
content: Union[bytes, str, JsonDict] = b"", content: Union[bytes, str, JsonDict] = b"",
*,
token: Optional[str] = None,
) -> FakeChannel: ) -> FakeChannel:
"""Performs a request repeatedly, disconnecting at successive `await`s, until """Performs a request repeatedly, disconnecting at successive `await`s, until
one completes. one completes.
@ -211,7 +213,13 @@ def make_request_with_cancellation_test(
with deferred_patch.patch(): with deferred_patch.patch():
# Start the request. # Start the request.
channel = make_request( channel = make_request(
reactor, site, method, path, content, await_result=False reactor,
site,
method,
path,
content,
await_result=False,
access_token=token,
) )
request = channel.request request = channel.request

View File

@ -19,6 +19,7 @@ from synapse.rest import admin
from synapse.rest.client import keys, login from synapse.rest.client import keys, login
from tests import unittest from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test
class KeyQueryTestCase(unittest.HomeserverTestCase): class KeyQueryTestCase(unittest.HomeserverTestCase):
@ -89,3 +90,31 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
Codes.BAD_JSON, Codes.BAD_JSON,
channel.result, channel.result,
) )
def test_key_query_cancellation(self) -> None:
"""
Tests that /keys/query is cancellable and does not swallow the
CancelledError.
"""
self.register_user("alice", "wonderland")
alice_token = self.login("alice", "wonderland")
bob = self.register_user("bob", "uncle")
channel = make_request_with_cancellation_test(
"test_key_query_cancellation",
self.reactor,
self.site,
"POST",
"/_matrix/client/r0/keys/query",
{
"device_keys": {
# Empty list means we request keys for all bob's devices
bob: [],
},
},
token=alice_token,
)
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertIn(bob, channel.json_body["device_keys"])