Cancel the processing of key query requests when they time out. (#13680)
parent c2fe48a6ff
commit d3d9ca156e
@@ -0,0 +1 @@
+Cancel the processing of key query requests when they time out.
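
Most of the hunks below do the same thing: they import `cancellable` from `synapse.util.cancellation` and apply it to the coroutines involved in serving a key query, so that the request machinery is allowed to cancel the work when the client disconnects or the request times out. As a rough sketch of the idea rather than Synapse's exact implementation: the decorator is only a marker, and the HTTP dispatch layer checks for the marker before turning a disconnect into a cancellation. The helper name `is_function_cancellable` below is illustrative.

from typing import Any, Awaitable, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Awaitable[Any]])


def cancellable(function: F) -> F:
    """Mark an async function as safe to cancel part-way through.

    The decorator does not change the function's behaviour; it only attaches
    a flag that the dispatch layer can inspect before deciding whether a
    client disconnect should be propagated as a CancelledError.
    """
    function.cancellable = True  # type: ignore[attr-defined]
    return function


def is_function_cancellable(function: Callable[..., Any]) -> bool:
    # Illustrative check performed by the request handler before cancelling.
    return bool(getattr(function, "cancellable", False))
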
@@ -38,6 +38,7 @@ from synapse.logging.opentracing import (
     trace,
 )
 from synapse.types import Requester, create_requester
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -118,6 +119,7 @@ class Auth:
                 errcode=Codes.NOT_JOINED,
             )
 
+    @cancellable
     async def get_user_by_req(
         self,
         request: SynapseRequest,
@@ -166,6 +168,7 @@ class Auth:
                     parent_span.set_tag("appservice_id", requester.app_service.id)
             return requester
 
+    @cancellable
     async def _wrapped_get_user_by_req(
         self,
         request: SynapseRequest,
@@ -281,6 +284,7 @@ class Auth:
                 403, "Application service has not registered this user (%s)" % user_id
             )
 
+    @cancellable
     async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
         """
         Given a request, reads the request parameters to determine:
@@ -523,6 +527,7 @@ class Auth:
         return bool(query_params) or bool(auth_headers)
 
     @staticmethod
+    @cancellable
     def get_access_token_from_request(request: Request) -> str:
         """Extracts the access_token from the request.
 
@@ -52,6 +52,7 @@ from synapse.types import (
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.cancellation import cancellable
 from synapse.util.metrics import measure_func
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -124,6 +125,7 @@ class DeviceWorkerHandler:
 
         return device
 
+    @cancellable
     async def get_device_changes_in_shared_rooms(
         self, user_id: str, room_ids: Collection[str], from_token: StreamToken
     ) -> Collection[str]:
@@ -163,6 +165,7 @@ class DeviceWorkerHandler:
 
     @trace
     @measure_func("device.get_user_ids_changed")
+    @cancellable
     async def get_user_ids_changed(
         self, user_id: str, from_token: StreamToken
     ) -> JsonDict:
@@ -37,7 +37,8 @@ from synapse.types import (
     get_verify_key_from_cross_signing_key,
 )
 from synapse.util import json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, delay_cancellation
+from synapse.util.cancellation import cancellable
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -91,6 +92,7 @@ class E2eKeysHandler:
         )
 
     @trace
+    @cancellable
     async def query_devices(
         self,
         query_body: JsonDict,
@@ -208,22 +210,26 @@ class E2eKeysHandler:
                 r[user_id] = remote_queries[user_id]
 
         # Now fetch any devices that we don't have in our cache
+        # TODO It might make sense to propagate cancellations into the
+        # deferreds which are querying remote homeservers.
         await make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self._query_devices_for_destination,
-                        results,
-                        cross_signing_keys,
-                        failures,
-                        destination,
-                        queries,
-                        timeout,
-                    )
-                    for destination, queries in remote_queries_not_in_cache.items()
-                ],
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError)
+            delay_cancellation(
+                defer.gatherResults(
+                    [
+                        run_in_background(
+                            self._query_devices_for_destination,
+                            results,
+                            cross_signing_keys,
+                            failures,
+                            destination,
+                            queries,
+                            timeout,
+                        )
+                        for destination, queries in remote_queries_not_in_cache.items()
+                    ],
+                    consumeErrors=True,
+                ).addErrback(unwrapFirstError)
+            )
         )
 
         ret = {"device_keys": results, "failures": failures}
@@ -347,6 +353,7 @@ class E2eKeysHandler:
 
             return
 
+    @cancellable
     async def get_cross_signing_keys_from_cache(
         self, query: Iterable[str], from_user_id: Optional[str]
     ) -> Dict[str, Dict[str, dict]]:
@@ -393,6 +400,7 @@ class E2eKeysHandler:
         }
 
     @trace
+    @cancellable
     async def query_local_devices(
         self, query: Mapping[str, Optional[List[str]]]
     ) -> Dict[str, Dict[str, dict]]:
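
The notable change in the middle hunk above is the `delay_cancellation` wrapper. `query_devices` is now cancellable, but the per-destination federation requests started with `run_in_background` should not be torn down half-way through (the new TODO notes that propagating cancellation into them is a possible follow-up). Wrapping the gathered deferred in `delay_cancellation` means that, if the caller is cancelled, the cancellation is held back until the remote queries have settled. A sketch of the pattern using the same helpers the hunk imports; the class and method names here are illustrative:

from typing import List

from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import delay_cancellation
from synapse.util.cancellation import cancellable


class RemoteFetcher:
    """Illustrative handler that fans work out to several destinations."""

    @cancellable
    async def fetch_all(self, destinations: List[str]) -> None:
        # One background deferred per destination; gatherResults collects them.
        # delay_cancellation ensures that, if our caller is cancelled, the
        # in-flight per-destination work runs to completion instead of being
        # interrupted part-way through.
        await make_deferred_yieldable(
            delay_cancellation(
                defer.gatherResults(
                    [
                        run_in_background(self._fetch_one, destination)
                        for destination in destinations
                    ],
                    consumeErrors=True,
                ).addErrback(unwrapFirstError)
            )
        )

    async def _fetch_one(self, destination: str) -> None:
        ...  # query the remote homeserver and record the result
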
@@ -27,9 +27,9 @@ from synapse.http.servlet import (
 )
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import log_kv, set_tag
+from synapse.rest.client._base import client_patterns, interactive_auth_handler
 from synapse.types import JsonDict, StreamToken
-
-from ._base import client_patterns, interactive_auth_handler
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -156,6 +156,7 @@ class KeyQueryServlet(RestServlet):
         self.auth = hs.get_auth()
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
 
+    @cancellable
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         user_id = requester.user.to_string()
@@ -199,6 +200,7 @@ class KeyChangesServlet(RestServlet):
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastores().main
 
+    @cancellable
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
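
Marking a servlet method with `@cancellable`, as the two hunks above do for `on_POST` and `on_GET`, is a promise that the handler can be abandoned at any `await` without leaving the server in an inconsistent state; read-only endpoints such as these key queries are the easy case. A minimal sketch of the pattern, with a made-up servlet that is not part of this change:

from typing import TYPE_CHECKING, Tuple

from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
    from synapse.server import HomeServer


class PingServlet(RestServlet):
    """Illustrative servlet, not part of this change."""

    PATTERNS = client_patterns("/ping$")

    def __init__(self, hs: "HomeServer"):
        super().__init__()
        self.auth = hs.get_auth()

    # The marker tells the HTTP layer this coroutine may be cancelled when
    # the client disconnects or times out. That is only safe if every await
    # point leaves the server in a consistent state, e.g. a read-only handler.
    @cancellable
    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
        await self.auth.get_user_by_req(request, allow_guest=True)
        return 200, {"pong": True}


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
    PingServlet(hs).register(http_server)
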
@@ -36,6 +36,7 @@ from synapse.storage.util.partial_state_events_tracker import (
     PartialStateEventsTracker,
 )
 from synapse.types import MutableStateMap, StateMap
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -229,6 +230,7 @@ class StateStorageController:
 
     @trace
     @tag_args
+    @cancellable
     async def get_state_ids_for_events(
         self,
         event_ids: Collection[str],
@@ -350,6 +352,7 @@ class StateStorageController:
 
     @trace
     @tag_args
+    @cancellable
     async def get_state_group_for_events(
         self,
         event_ids: Collection[str],
@@ -398,6 +401,7 @@ class StateStorageController:
             event_id, room_id, prev_group, delta_ids, current_state_ids
         )
 
+    @cancellable
     async def get_current_state_ids(
         self,
         room_id: str,
@@ -53,6 +53,7 @@ from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
 
@@ -668,6 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
         ...
 
     @trace
+    @cancellable
     async def get_user_devices_from_cache(
         self, query_list: List[Tuple[str, Optional[str]]]
     ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
@@ -743,6 +745,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
 
         return self._device_list_stream_cache.get_all_entities_changed(from_key)
 
+    @cancellable
     async def get_users_whose_devices_changed(
         self,
         from_key: int,
@@ -1221,6 +1224,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
             desc="get_min_device_lists_changes_in_room",
         )
 
+    @cancellable
     async def get_device_list_changes_in_rooms(
         self, room_ids: Collection[str], from_id: int
     ) -> Optional[Set[str]]:
@@ -50,6 +50,7 @@ from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -135,6 +136,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             return now_stream_id, []
 
     @trace
+    @cancellable
     async def get_e2e_device_keys_for_cs_api(
         self, query_list: List[Tuple[str, Optional[str]]]
     ) -> Dict[str, Dict[str, JsonDict]]:
@@ -197,6 +199,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         ...
 
     @trace
+    @cancellable
     async def get_e2e_device_keys_and_signatures(
         self,
         query_list: Collection[Tuple[str, Optional[str]]],
@@ -887,6 +890,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         return keys
 
+    @cancellable
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
     ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
@@ -902,7 +906,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             keys were not found, either their user ID will not be in the dict,
             or their user ID will map to None.
         """
-
         result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
 
         if from_user_id:
@@ -48,6 +48,7 @@ from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -976,6 +977,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         return int(min_depth) if min_depth is not None else None
 
+    @cancellable
     async def get_forward_extremities_for_room_at_stream_ordering(
         self, room_id: str, stream_ordering: int
     ) -> List[str]:
@@ -81,6 +81,7 @@ from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import AsyncLruCache
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -339,6 +340,7 @@ class EventsWorkerStore(SQLBaseStore):
     ) -> Optional[EventBase]:
         ...
 
+    @cancellable
     async def get_event(
         self,
         event_id: str,
@@ -433,6 +435,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     @trace
     @tag_args
+    @cancellable
     async def get_events_as_list(
         self,
         event_ids: Collection[str],
@@ -584,6 +587,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return events
 
+    @cancellable
     async def _get_events_from_cache_or_db(
         self, event_ids: Iterable[str], allow_rejected: bool = False
     ) -> Dict[str, EventCacheEntry]:
@@ -55,6 +55,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -770,6 +771,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             _get_users_server_still_shares_room_with_txn,
         )
 
+    @cancellable
     async def get_rooms_for_user(
         self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
     ) -> FrozenSet[str]:
@@ -36,6 +36,7 @@ from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, JsonMapping, StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -281,6 +282,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
     # FIXME: how should this be cached?
+    @cancellable
     async def get_partial_filtered_current_state_ids(
         self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
@@ -72,6 +72,7 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import PersistedEventPosition, RoomStreamToken
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -597,6 +598,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret, key
 
+    @cancellable
     async def get_membership_changes_for_user(
         self,
         user_id: str,
@@ -31,6 +31,7 @@ from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import MutableStateMap, StateKey, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -156,6 +157,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             "get_state_group_delta", _get_state_group_delta_txn
         )
 
+    @cancellable
     async def _get_state_groups_from_groups(
         self, groups: List[int], state_filter: StateFilter
     ) -> Dict[int, StateMap[str]]:
@@ -235,6 +237,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return state_filter.filter_state(state_dict_ids), not missing_types
 
+    @cancellable
     async def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
@@ -24,6 +24,7 @@ from synapse.logging.opentracing import trace_with_opname
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.util import unwrapFirstError
+from synapse.util.cancellation import cancellable
 
 logger = logging.getLogger(__name__)
 
@@ -60,6 +61,7 @@ class PartialStateEventsTracker:
             o.callback(None)
 
     @trace_with_opname("PartialStateEventsTracker.await_full_state")
+    @cancellable
     async def await_full_state(self, event_ids: Collection[str]) -> None:
         """Wait for all the given events to have full state.
 
@@ -154,6 +156,7 @@ class PartialCurrentStateTracker:
             o.callback(None)
 
     @trace_with_opname("PartialCurrentStateTracker.await_full_state")
+    @cancellable
     async def await_full_state(self, room_id: str) -> None:
         # We add the deferred immediately so that the DB call to check for
         # partial state doesn't race when we unpartial the room.
@@ -52,6 +52,7 @@ from twisted.internet.interfaces import (
 )
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.util.cancellation import cancellable
 from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
@@ -699,7 +700,11 @@ class StreamToken:
     START: ClassVar["StreamToken"]
 
     @classmethod
+    @cancellable
     async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
+        """
+        Creates a RoomStreamToken from its textual representation.
+        """
         try:
             keys = string.split(cls._SEPARATOR)
             while len(keys) < len(attr.fields(cls)):
@@ -140,6 +140,8 @@ def make_request_with_cancellation_test(
     method: str,
     path: str,
     content: Union[bytes, str, JsonDict] = b"",
+    *,
+    token: Optional[str] = None,
 ) -> FakeChannel:
     """Performs a request repeatedly, disconnecting at successive `await`s, until
     one completes.
@@ -211,7 +213,13 @@ def make_request_with_cancellation_test(
         with deferred_patch.patch():
            # Start the request.
             channel = make_request(
-                reactor, site, method, path, content, await_result=False
+                reactor,
+                site,
+                method,
+                path,
+                content,
+                await_result=False,
+                access_token=token,
             )
             request = channel.request
 
@@ -19,6 +19,7 @@ from synapse.rest import admin
 from synapse.rest.client import keys, login
 
 from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
 
 
 class KeyQueryTestCase(unittest.HomeserverTestCase):
@@ -89,3 +90,31 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
             Codes.BAD_JSON,
             channel.result,
         )
+
+    def test_key_query_cancellation(self) -> None:
+        """
+        Tests that /keys/query is cancellable and does not swallow the
+        CancelledError.
+        """
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+
+        bob = self.register_user("bob", "uncle")
+
+        channel = make_request_with_cancellation_test(
+            "test_key_query_cancellation",
+            self.reactor,
+            self.site,
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    # Empty list means we request keys for all bob's devices
+                    bob: [],
+                },
+            },
+            token=alice_token,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertIn(bob, channel.json_body["device_keys"])
|
Loading…
Reference in New Issue