From e3417a06e23c532e6502bdcdcaedac826e231d69 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 22 May 2017 15:04:42 +0100 Subject: [PATCH] Update list cache to handle one arg case We update the normal cache descriptors to handle caches with a single argument specially so that the key wasn't a 1-tuple. We need to update the cache list to be aware of this. --- synapse/util/caches/descriptors.py | 48 ++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 48dcbafeef..77a0d8e35d 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -404,6 +404,7 @@ class CacheDescriptor(_CacheDescriptorBase): wrapped.invalidate_all = cache.invalidate_all wrapped.cache = cache + wrapped.num_args = self.num_args obj.__dict__[self.orig.__name__] = wrapped @@ -451,8 +452,9 @@ class CacheListDescriptor(_CacheDescriptorBase): ) def __get__(self, obj, objtype=None): - - cache = getattr(obj, self.cached_method_name).cache + cached_method = getattr(obj, self.cached_method_name) + cache = cached_method.cache + num_args = cached_method.num_args @functools.wraps(self.orig) def wrapped(*args, **kwargs): @@ -470,11 +472,14 @@ class CacheListDescriptor(_CacheDescriptorBase): cached_defers = {} missing = [] for arg in list_args: - key = list(keyargs) - key[self.list_pos] = arg - try: - res = cache.get(tuple(key), callback=invalidate_callback) + if num_args == 1: + res = cache.get(arg, callback=invalidate_callback) + else: + key = list(keyargs) + key[self.list_pos] = arg + res = cache.get(tuple(key), callback=invalidate_callback) + if not isinstance(res, ObservableDeferred): results[arg] = res elif not res.has_succeeded(): @@ -505,17 +510,28 @@ class CacheListDescriptor(_CacheDescriptorBase): observer = ObservableDeferred(observer) - key = list(keyargs) - key[self.list_pos] = arg - cache.set( - tuple(key), observer, - callback=invalidate_callback - ) + if num_args == 1: + cache.set( + arg, observer, + callback=invalidate_callback + ) - def invalidate(f, key): - cache.invalidate(key) - return f - observer.addErrback(invalidate, tuple(key)) + def invalidate(f, key): + cache.invalidate(key) + return f + observer.addErrback(invalidate, arg) + else: + key = list(keyargs) + key[self.list_pos] = arg + cache.set( + tuple(key), observer, + callback=invalidate_callback + ) + + def invalidate(f, key): + cache.invalidate(key) + return f + observer.addErrback(invalidate, tuple(key)) res = observer.observe() res.addCallback(lambda r, arg: (arg, r), arg)