Add cancellation support to `@cached` and `@cachedList` decorators (#12183)
These decorators mostly support cancellation already. Add cancellation tests and fix use of finished logging contexts by delaying cancellation, as suggested by @erikjohnston. Signed-off-by: Sean Quah <seanq@element.io>pull/12240/head
parent
605d161d7d
commit
2fcf4b3f6c
|
@ -0,0 +1 @@
|
|||
Add cancellation support to `@cached` and `@cachedList` decorators.
|
|
@ -41,6 +41,7 @@ from twisted.python.failure import Failure
|
|||
|
||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async_helpers import delay_cancellation
|
||||
from synapse.util.caches.deferred_cache import DeferredCache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
|
@ -350,6 +351,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
|||
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
|
||||
ret = cache.set(cache_key, ret, callback=invalidate_callback)
|
||||
|
||||
# We started a new call to `self.orig`, so we must always wait for it to
|
||||
# complete. Otherwise we might mark our current logging context as
|
||||
# finished while `self.orig` is still using it in the background.
|
||||
ret = delay_cancellation(ret)
|
||||
|
||||
return make_deferred_yieldable(ret)
|
||||
|
||||
wrapped = cast(_CachedFunction, _wrapped)
|
||||
|
@ -510,6 +516,11 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|||
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
|
||||
lambda _: results, unwrapFirstError
|
||||
)
|
||||
if missing:
|
||||
# We started a new call to `self.orig`, so we must always wait for it to
|
||||
# complete. Otherwise we might mark our current logging context as
|
||||
# finished while `self.orig` is still using it in the background.
|
||||
d = delay_cancellation(d)
|
||||
return make_deferred_yieldable(d)
|
||||
else:
|
||||
return defer.succeed(results)
|
||||
|
|
|
@ -17,7 +17,7 @@ from typing import Set
|
|||
from unittest import mock
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.defer import CancelledError, Deferred
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.logging.context import (
|
||||
|
@ -28,7 +28,7 @@ from synapse.logging.context import (
|
|||
make_deferred_yieldable,
|
||||
)
|
||||
from synapse.util.caches import descriptors
|
||||
from synapse.util.caches.descriptors import cached, lru_cache
|
||||
from synapse.util.caches.descriptors import cached, cachedList, lru_cache
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import get_awaitable_result
|
||||
|
@ -493,6 +493,74 @@ class DescriptorTestCase(unittest.TestCase):
|
|||
obj.invalidate()
|
||||
top_invalidate.assert_called_once()
|
||||
|
||||
def test_cancel(self):
|
||||
"""Test that cancelling a lookup does not cancel other lookups"""
|
||||
complete_lookup: "Deferred[None]" = Deferred()
|
||||
|
||||
class Cls:
|
||||
@cached()
|
||||
async def fn(self, arg1):
|
||||
await complete_lookup
|
||||
return str(arg1)
|
||||
|
||||
obj = Cls()
|
||||
|
||||
d1 = obj.fn(123)
|
||||
d2 = obj.fn(123)
|
||||
self.assertFalse(d1.called)
|
||||
self.assertFalse(d2.called)
|
||||
|
||||
# Cancel `d1`, which is the lookup that caused `fn` to run.
|
||||
d1.cancel()
|
||||
|
||||
# `d2` should complete normally.
|
||||
complete_lookup.callback(None)
|
||||
self.failureResultOf(d1, CancelledError)
|
||||
self.assertEqual(d2.result, "123")
|
||||
|
||||
def test_cancel_logcontexts(self):
|
||||
"""Test that cancellation does not break logcontexts.
|
||||
|
||||
* The `CancelledError` must be raised with the correct logcontext.
|
||||
* The inner lookup must not resume with a finished logcontext.
|
||||
* The inner lookup must not restore a finished logcontext when done.
|
||||
"""
|
||||
complete_lookup: "Deferred[None]" = Deferred()
|
||||
|
||||
class Cls:
|
||||
inner_context_was_finished = False
|
||||
|
||||
@cached()
|
||||
async def fn(self, arg1):
|
||||
await make_deferred_yieldable(complete_lookup)
|
||||
self.inner_context_was_finished = current_context().finished
|
||||
return str(arg1)
|
||||
|
||||
obj = Cls()
|
||||
|
||||
async def do_lookup():
|
||||
with LoggingContext("c1") as c1:
|
||||
try:
|
||||
await obj.fn(123)
|
||||
self.fail("No CancelledError thrown")
|
||||
except CancelledError:
|
||||
self.assertEqual(
|
||||
current_context(),
|
||||
c1,
|
||||
"CancelledError was not raised with the correct logcontext",
|
||||
)
|
||||
# suppress the error and succeed
|
||||
|
||||
d = defer.ensureDeferred(do_lookup())
|
||||
d.cancel()
|
||||
|
||||
complete_lookup.callback(None)
|
||||
self.successResultOf(d)
|
||||
self.assertFalse(
|
||||
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
|
||||
)
|
||||
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||
|
||||
|
||||
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
|
||||
"""More tests for @cached
|
||||
|
@ -865,3 +933,78 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
|||
obj.fn.invalidate((10, 2))
|
||||
invalidate0.assert_called_once()
|
||||
invalidate1.assert_called_once()
|
||||
|
||||
def test_cancel(self):
|
||||
"""Test that cancelling a lookup does not cancel other lookups"""
|
||||
complete_lookup: "Deferred[None]" = Deferred()
|
||||
|
||||
class Cls:
|
||||
@cached()
|
||||
def fn(self, arg1):
|
||||
pass
|
||||
|
||||
@cachedList(cached_method_name="fn", list_name="args")
|
||||
async def list_fn(self, args):
|
||||
await complete_lookup
|
||||
return {arg: str(arg) for arg in args}
|
||||
|
||||
obj = Cls()
|
||||
|
||||
d1 = obj.list_fn([123, 456])
|
||||
d2 = obj.list_fn([123, 456, 789])
|
||||
self.assertFalse(d1.called)
|
||||
self.assertFalse(d2.called)
|
||||
|
||||
d1.cancel()
|
||||
|
||||
# `d2` should complete normally.
|
||||
complete_lookup.callback(None)
|
||||
self.failureResultOf(d1, CancelledError)
|
||||
self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
|
||||
|
||||
def test_cancel_logcontexts(self):
|
||||
"""Test that cancellation does not break logcontexts.
|
||||
|
||||
* The `CancelledError` must be raised with the correct logcontext.
|
||||
* The inner lookup must not resume with a finished logcontext.
|
||||
* The inner lookup must not restore a finished logcontext when done.
|
||||
"""
|
||||
complete_lookup: "Deferred[None]" = Deferred()
|
||||
|
||||
class Cls:
|
||||
inner_context_was_finished = False
|
||||
|
||||
@cached()
|
||||
def fn(self, arg1):
|
||||
pass
|
||||
|
||||
@cachedList(cached_method_name="fn", list_name="args")
|
||||
async def list_fn(self, args):
|
||||
await make_deferred_yieldable(complete_lookup)
|
||||
self.inner_context_was_finished = current_context().finished
|
||||
return {arg: str(arg) for arg in args}
|
||||
|
||||
obj = Cls()
|
||||
|
||||
async def do_lookup():
|
||||
with LoggingContext("c1") as c1:
|
||||
try:
|
||||
await obj.list_fn([123])
|
||||
self.fail("No CancelledError thrown")
|
||||
except CancelledError:
|
||||
self.assertEqual(
|
||||
current_context(),
|
||||
c1,
|
||||
"CancelledError was not raised with the correct logcontext",
|
||||
)
|
||||
# suppress the error and succeed
|
||||
|
||||
d = defer.ensureDeferred(do_lookup())
|
||||
d.cancel()
|
||||
|
||||
complete_lookup.callback(None)
|
||||
self.successResultOf(d)
|
||||
self.assertFalse(
|
||||
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
|
||||
)
|
||||
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||
|
|
Loading…
Reference in New Issue