Merge pull request #212 from matrix-org/erikj/cache_deferreds

Make CacheDescriptor cache deferreds rather than the deferreds' values
pull/217/head
Erik Johnston 2015-08-07 19:28:05 +01:00
commit 06218ab125
3 changed files with 47 additions and 19 deletions

View File

@ -15,6 +15,7 @@
import logging
from synapse.api.errors import StoreError
from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
@ -131,6 +132,9 @@ class Cache(object):
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
@ -173,33 +177,49 @@ class CacheDescriptor(object):
)
@functools.wraps(self.orig)
@defer.inlineCallbacks
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
try:
cached_result = cache.get(*keyargs)
cached_result_d = cache.get(*keyargs)
observer = cached_result_d.observe()
if DEBUG_CACHES:
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
@defer.inlineCallbacks
def check_result(cached_result):
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = yield self.function_to_call(obj, *args, **kwargs)
ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
def onErr(f):
cache.invalidate(*keyargs)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=False)
cache.update(sequence, *(keyargs + [ret]))
defer.returnValue(ret)
return ret.observe()
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all

View File

@ -51,7 +51,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_observers", set())
def callback(r):
self._result = (True, r)
object.__setattr__(self, "_result", (True, r))
while self._observers:
try:
self._observers.pop().callback(r)
@ -60,7 +60,7 @@ class ObservableDeferred(object):
return r
def errback(f):
self._result = (False, f)
object.__setattr__(self, "_result", (False, f))
while self._observers:
try:
self._observers.pop().errback(f)
@ -97,3 +97,8 @@ class ObservableDeferred(object):
def __setattr__(self, name, value):
setattr(self._deferred, name, value)
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)

View File

@ -17,6 +17,8 @@
from tests import unittest
from twisted.internet import defer
from synapse.util.async import ObservableDeferred
from synapse.storage._base import Cache, cached
@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0]))
@defer.inlineCallbacks
def test_prefill(self):
callcount = [0]
d = defer.succeed(123)
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
return d
a = A()
a.func.prefill("foo", 123)
a.func.prefill("foo", ObservableDeferred(d))
self.assertEquals((yield a.func("foo")), 123)
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)