Update `delay_cancellation` to accept any awaitable (#12468)
This will mainly be useful when dealing with module callbacks, which are all typed as returning `Awaitable`s instead of coroutines or `Deferred`s. Signed-off-by: Sean Quah <seanq@element.io>pull/12529/head
parent
b82fff66df
commit
a50fb411b3
|
@ -0,0 +1 @@
|
|||
Update `delay_cancellation` to accept any awaitable, rather than just `Deferred`s.
|
|
@ -41,7 +41,6 @@ from prometheus_client import Histogram
|
|||
from typing_extensions import Literal
|
||||
|
||||
from twisted.enterprise import adbapi
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.config.database import DatabaseConnectionConfig
|
||||
|
@ -794,7 +793,7 @@ class DatabasePool:
|
|||
# We also wait until everything above is done before releasing the
|
||||
# `CancelledError`, so that logging contexts won't get used after they have been
|
||||
# finished.
|
||||
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
|
||||
return await delay_cancellation(_runInteraction())
|
||||
|
||||
async def runWithConnection(
|
||||
self,
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import collections
|
||||
import inspect
|
||||
import itertools
|
||||
|
@ -25,6 +26,7 @@ from typing import (
|
|||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generic,
|
||||
Hashable,
|
||||
|
@ -701,27 +703,57 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
|||
return new_deferred
|
||||
|
||||
|
||||
def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
||||
"""Delay cancellation of a `Deferred` until it resolves.
|
||||
@overload
|
||||
def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]":
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
|
||||
...
|
||||
|
||||
|
||||
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
|
||||
"""Delay cancellation of a coroutine or `Deferred` awaitable until it resolves.
|
||||
|
||||
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
|
||||
resolve with a `CancelledError` until the original `Deferred` resolves.
|
||||
resolve with a `CancelledError` until the original awaitable resolves.
|
||||
|
||||
Args:
|
||||
deferred: The `Deferred` to protect against cancellation. May optionally follow
|
||||
the Synapse logcontext rules.
|
||||
deferred: The coroutine or `Deferred` to protect against cancellation. May
|
||||
optionally follow the Synapse logcontext rules.
|
||||
|
||||
Returns:
|
||||
A new `Deferred`, which will contain the result of the original `Deferred`.
|
||||
The new `Deferred` will not propagate cancellation through to the original.
|
||||
When cancelled, the new `Deferred` will wait until the original `Deferred`
|
||||
resolves before failing with a `CancelledError`.
|
||||
A new `Deferred`, which will contain the result of the original coroutine or
|
||||
`Deferred`. The new `Deferred` will not propagate cancellation through to the
|
||||
original coroutine or `Deferred`.
|
||||
|
||||
The new `Deferred` will follow the Synapse logcontext rules if `deferred`
|
||||
When cancelled, the new `Deferred` will wait until the original coroutine or
|
||||
`Deferred` resolves before failing with a `CancelledError`.
|
||||
|
||||
The new `Deferred` will follow the Synapse logcontext rules if `awaitable`
|
||||
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
|
||||
wrapped with `make_deferred_yieldable`.
|
||||
"""
|
||||
|
||||
# First, convert the awaitable into a `Deferred`.
|
||||
if isinstance(awaitable, defer.Deferred):
|
||||
deferred = awaitable
|
||||
elif asyncio.iscoroutine(awaitable):
|
||||
# Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
|
||||
# type-checking, but we'd need Twisted >= 21.2.
|
||||
deferred = defer.ensureDeferred(awaitable)
|
||||
else:
|
||||
# We have no idea what to do with this awaitable.
|
||||
# We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from
|
||||
# `make_awaitable`, and let the caller `await` it normally.
|
||||
return awaitable
|
||||
|
||||
def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
|
||||
# before the new deferred is cancelled, we `pause` it to stop the cancellation
|
||||
# propagating. we then `unpause` it once the wrapped deferred completes, to
|
||||
|
|
|
@ -382,7 +382,7 @@ class StopCancellationTests(TestCase):
|
|||
class DelayCancellationTests(TestCase):
|
||||
"""Tests for the `delay_cancellation` function."""
|
||||
|
||||
def test_cancellation(self):
|
||||
def test_deferred_cancellation(self):
|
||||
"""Test that cancellation of the new `Deferred` waits for the original."""
|
||||
deferred: "Deferred[str]" = Deferred()
|
||||
wrapper_deferred = delay_cancellation(deferred)
|
||||
|
@ -403,6 +403,35 @@ class DelayCancellationTests(TestCase):
|
|||
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
|
||||
self.failureResultOf(wrapper_deferred, CancelledError)
|
||||
|
||||
def test_coroutine_cancellation(self):
|
||||
"""Test that cancellation of the new `Deferred` waits for the original."""
|
||||
blocking_deferred: "Deferred[None]" = Deferred()
|
||||
completion_deferred: "Deferred[None]" = Deferred()
|
||||
|
||||
async def task():
|
||||
await blocking_deferred
|
||||
completion_deferred.callback(None)
|
||||
# Raise an exception. Twisted should consume it, otherwise unwanted
|
||||
# tracebacks will be printed in logs.
|
||||
raise ValueError("abc")
|
||||
|
||||
wrapper_deferred = delay_cancellation(task())
|
||||
|
||||
# Cancel the new `Deferred`.
|
||||
wrapper_deferred.cancel()
|
||||
self.assertNoResult(wrapper_deferred)
|
||||
self.assertFalse(
|
||||
blocking_deferred.called, "Cancellation was propagated too deep"
|
||||
)
|
||||
self.assertFalse(completion_deferred.called)
|
||||
|
||||
# Unblock the task.
|
||||
blocking_deferred.callback(None)
|
||||
self.assertTrue(completion_deferred.called)
|
||||
|
||||
# Now that the original coroutine has failed, we should get a `CancelledError`.
|
||||
self.failureResultOf(wrapper_deferred, CancelledError)
|
||||
|
||||
def test_suppresses_second_cancellation(self):
|
||||
"""Test that a second cancellation is suppressed.
|
||||
|
||||
|
@ -451,7 +480,7 @@ class DelayCancellationTests(TestCase):
|
|||
async def outer():
|
||||
with LoggingContext("c") as c:
|
||||
try:
|
||||
await delay_cancellation(defer.ensureDeferred(inner()))
|
||||
await delay_cancellation(inner())
|
||||
self.fail("`CancelledError` was not raised")
|
||||
except CancelledError:
|
||||
self.assertEqual(c, current_context())
|
||||
|
|
Loading…
Reference in New Issue