Merge pull request #3384 from matrix-org/rav/rewrite_cachedlist_decorator
Rewrite cache list decoratorpull/3529/merge
commit
cab782c17e
|
@ -0,0 +1 @@
|
||||||
|
Rewrite cache list decorator
|
|
@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
# If we're passed a cache_context then we'll want to call its
|
||||||
# whenever we are invalidated
|
# invalidate() whenever we are invalidated
|
||||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
|
||||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
||||||
list_args = arg_dict[self.list_name]
|
list_args = arg_dict[self.list_name]
|
||||||
|
|
||||||
# cached is a dict arg -> deferred, where deferred results in a
|
|
||||||
# 2-tuple (`arg`, `result`)
|
|
||||||
results = {}
|
results = {}
|
||||||
cached_defers = {}
|
|
||||||
missing = []
|
def update_results_dict(res, arg):
|
||||||
|
results[arg] = res
|
||||||
|
|
||||||
|
# list of deferreds to wait for
|
||||||
|
cached_defers = []
|
||||||
|
|
||||||
|
missing = set()
|
||||||
|
|
||||||
# If the cache takes a single arg then that is used as the key,
|
# If the cache takes a single arg then that is used as the key,
|
||||||
# otherwise a tuple is used.
|
# otherwise a tuple is used.
|
||||||
if num_args == 1:
|
if num_args == 1:
|
||||||
def cache_get(arg):
|
def arg_to_cache_key(arg):
|
||||||
return cache.get(arg, callback=invalidate_callback)
|
return arg
|
||||||
else:
|
else:
|
||||||
key = list(keyargs)
|
keylist = list(keyargs)
|
||||||
|
|
||||||
def cache_get(arg):
|
def arg_to_cache_key(arg):
|
||||||
key[self.list_pos] = arg
|
keylist[self.list_pos] = arg
|
||||||
return cache.get(tuple(key), callback=invalidate_callback)
|
return tuple(keylist)
|
||||||
|
|
||||||
for arg in list_args:
|
for arg in list_args:
|
||||||
try:
|
try:
|
||||||
res = cache_get(arg)
|
res = cache.get(arg_to_cache_key(arg),
|
||||||
|
callback=invalidate_callback)
|
||||||
if not isinstance(res, ObservableDeferred):
|
if not isinstance(res, ObservableDeferred):
|
||||||
results[arg] = res
|
results[arg] = res
|
||||||
elif not res.has_succeeded():
|
elif not res.has_succeeded():
|
||||||
res = res.observe()
|
res = res.observe()
|
||||||
res.addCallback(lambda r, arg: (arg, r), arg)
|
res.addCallback(update_results_dict, arg)
|
||||||
cached_defers[arg] = res
|
cached_defers.append(res)
|
||||||
else:
|
else:
|
||||||
results[arg] = res.get_result()
|
results[arg] = res.get_result()
|
||||||
except KeyError:
|
except KeyError:
|
||||||
missing.append(arg)
|
missing.add(arg)
|
||||||
|
|
||||||
if missing:
|
if missing:
|
||||||
args_to_call = dict(arg_dict)
|
# we need an observable deferred for each entry in the list,
|
||||||
args_to_call[self.list_name] = missing
|
# which we put in the cache. Each deferred resolves with the
|
||||||
|
# relevant result for that key.
|
||||||
|
deferreds_map = {}
|
||||||
|
for arg in missing:
|
||||||
|
deferred = defer.Deferred()
|
||||||
|
deferreds_map[arg] = deferred
|
||||||
|
key = arg_to_cache_key(arg)
|
||||||
|
observable = ObservableDeferred(deferred)
|
||||||
|
cache.set(key, observable, callback=invalidate_callback)
|
||||||
|
|
||||||
ret_d = defer.maybeDeferred(
|
def complete_all(res):
|
||||||
|
# the wrapped function has completed. It returns a
|
||||||
|
# a dict. We can now resolve the observable deferreds in
|
||||||
|
# the cache and update our own result map.
|
||||||
|
for e in missing:
|
||||||
|
val = res.get(e, None)
|
||||||
|
deferreds_map[e].callback(val)
|
||||||
|
results[e] = val
|
||||||
|
|
||||||
|
def errback(f):
|
||||||
|
# the wrapped function has failed. Invalidate any cache
|
||||||
|
# entries we're supposed to be populating, and fail
|
||||||
|
# their deferreds.
|
||||||
|
for e in missing:
|
||||||
|
key = arg_to_cache_key(e)
|
||||||
|
cache.invalidate(key)
|
||||||
|
deferreds_map[e].errback(f)
|
||||||
|
|
||||||
|
# return the failure, to propagate to our caller.
|
||||||
|
return f
|
||||||
|
|
||||||
|
args_to_call = dict(arg_dict)
|
||||||
|
args_to_call[self.list_name] = list(missing)
|
||||||
|
|
||||||
|
cached_defers.append(defer.maybeDeferred(
|
||||||
logcontext.preserve_fn(self.function_to_call),
|
logcontext.preserve_fn(self.function_to_call),
|
||||||
**args_to_call
|
**args_to_call
|
||||||
)
|
).addCallbacks(complete_all, errback))
|
||||||
|
|
||||||
ret_d = ObservableDeferred(ret_d)
|
|
||||||
|
|
||||||
# We need to create deferreds for each arg in the list so that
|
|
||||||
# we can insert the new deferred into the cache.
|
|
||||||
for arg in missing:
|
|
||||||
observer = ret_d.observe()
|
|
||||||
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
|
||||||
|
|
||||||
observer = ObservableDeferred(observer)
|
|
||||||
|
|
||||||
if num_args == 1:
|
|
||||||
cache.set(
|
|
||||||
arg, observer,
|
|
||||||
callback=invalidate_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
def invalidate(f, key):
|
|
||||||
cache.invalidate(key)
|
|
||||||
return f
|
|
||||||
observer.addErrback(invalidate, arg)
|
|
||||||
else:
|
|
||||||
key = list(keyargs)
|
|
||||||
key[self.list_pos] = arg
|
|
||||||
cache.set(
|
|
||||||
tuple(key), observer,
|
|
||||||
callback=invalidate_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
def invalidate(f, key):
|
|
||||||
cache.invalidate(key)
|
|
||||||
return f
|
|
||||||
observer.addErrback(invalidate, tuple(key))
|
|
||||||
|
|
||||||
res = observer.observe()
|
|
||||||
res.addCallback(lambda r, arg: (arg, r), arg)
|
|
||||||
|
|
||||||
cached_defers[arg] = res
|
|
||||||
|
|
||||||
if cached_defers:
|
if cached_defers:
|
||||||
def update_results_dict(res):
|
d = defer.gatherResults(
|
||||||
results.update(res)
|
cached_defers,
|
||||||
return results
|
|
||||||
|
|
||||||
return logcontext.make_deferred_yieldable(defer.gatherResults(
|
|
||||||
list(cached_defers.values()),
|
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addCallback(update_results_dict).addErrback(
|
).addCallbacks(
|
||||||
|
lambda _: results,
|
||||||
unwrapFirstError
|
unwrapFirstError
|
||||||
))
|
)
|
||||||
|
return logcontext.make_deferred_yieldable(d)
|
||||||
else:
|
else:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
|
||||||
cache.
|
cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache (Cache): The underlying cache to use.
|
cached_method_name (str): The name of the single-item lookup method.
|
||||||
|
This is only used to find the cache to use.
|
||||||
list_name (str): The name of the argument that is the list to use to
|
list_name (str): The name of the argument that is the list to use to
|
||||||
do batch lookups in the cache.
|
do batch lookups in the cache.
|
||||||
num_args (int): Number of arguments to use as the key in the cache
|
num_args (int): Number of arguments to use as the key in the cache
|
||||||
|
|
|
@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
|
||||||
r = yield obj.fn(2, 3)
|
r = yield obj.fn(2, 3)
|
||||||
self.assertEqual(r, 'chips')
|
self.assertEqual(r, 'chips')
|
||||||
obj.mock.assert_not_called()
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cache(self):
|
||||||
|
class Cls(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1, arg2):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
|
||||||
|
def list_fn(self, args1, arg2):
|
||||||
|
assert (
|
||||||
|
logcontext.LoggingContext.current_context().request == "c1"
|
||||||
|
)
|
||||||
|
# we want this to behave like an asynchronous function
|
||||||
|
yield run_on_reactor()
|
||||||
|
assert (
|
||||||
|
logcontext.LoggingContext.current_context().request == "c1"
|
||||||
|
)
|
||||||
|
defer.returnValue(self.mock(args1, arg2))
|
||||||
|
|
||||||
|
with logcontext.LoggingContext() as c1:
|
||||||
|
c1.request = "c1"
|
||||||
|
obj = Cls()
|
||||||
|
obj.mock.return_value = {10: 'fish', 20: 'chips'}
|
||||||
|
d1 = obj.list_fn([10, 20], 2)
|
||||||
|
self.assertEqual(
|
||||||
|
logcontext.LoggingContext.current_context(),
|
||||||
|
logcontext.LoggingContext.sentinel,
|
||||||
|
)
|
||||||
|
r = yield d1
|
||||||
|
self.assertEqual(
|
||||||
|
logcontext.LoggingContext.current_context(),
|
||||||
|
c1
|
||||||
|
)
|
||||||
|
obj.mock.assert_called_once_with([10, 20], 2)
|
||||||
|
self.assertEqual(r, {10: 'fish', 20: 'chips'})
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = {30: 'peas'}
|
||||||
|
r = yield obj.list_fn([20, 30], 2)
|
||||||
|
obj.mock.assert_called_once_with([30], 2)
|
||||||
|
self.assertEqual(r, {20: 'chips', 30: 'peas'})
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# all the values should now be cached
|
||||||
|
r = yield obj.fn(10, 2)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
r = yield obj.fn(20, 2)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
r = yield obj.fn(30, 2)
|
||||||
|
self.assertEqual(r, 'peas')
|
||||||
|
r = yield obj.list_fn([10, 20, 30], 2)
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_invalidate(self):
|
||||||
|
"""Make sure that invalidation callbacks are called."""
|
||||||
|
class Cls(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1, arg2):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
|
||||||
|
def list_fn(self, args1, arg2):
|
||||||
|
# we want this to behave like an asynchronous function
|
||||||
|
yield run_on_reactor()
|
||||||
|
defer.returnValue(self.mock(args1, arg2))
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
invalidate0 = mock.Mock()
|
||||||
|
invalidate1 = mock.Mock()
|
||||||
|
|
||||||
|
# cache miss
|
||||||
|
obj.mock.return_value = {10: 'fish', 20: 'chips'}
|
||||||
|
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
|
||||||
|
obj.mock.assert_called_once_with([10, 20], 2)
|
||||||
|
self.assertEqual(r1, {10: 'fish', 20: 'chips'})
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# cache hit
|
||||||
|
r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
self.assertEqual(r2, {10: 'fish', 20: 'chips'})
|
||||||
|
|
||||||
|
invalidate0.assert_not_called()
|
||||||
|
invalidate1.assert_not_called()
|
||||||
|
|
||||||
|
# now if we invalidate the keys, both invalidations should get called
|
||||||
|
obj.fn.invalidate((10, 2))
|
||||||
|
invalidate0.assert_called_once()
|
||||||
|
invalidate1.assert_called_once()
|
||||||
|
|
Loading…
Reference in New Issue