Speed up cache size calculation

Instead of calculating the size of the cache repeatedly, which can take
a long time now that it can use a callback, instead cache the size and
update that on insertion and deletion.

This requires changing the cache descriptors to have two caches, one for
pending deferreds and the other for the actual values. There's no reason
to evict from the pending deferreds as they won't take up any more
memory.
pull/1815/head
Erik Johnston 2017-01-17 11:18:13 +00:00
parent f2f179dce2
commit f85b6ca494
7 changed files with 148 additions and 62 deletions

View File

@ -17,7 +17,7 @@ import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
from synapse.util.caches.treecache import TreeCache, popped_to_iterator
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
@ -42,11 +42,23 @@ _CacheSentinel = object()
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
def deferred_size(deferred):
if deferred.called:
return len(deferred.result)
else:
return 1
class CacheEntry(object):
__slots__ = [
"deferred", "sequence", "callbacks", "invalidated"
]
def __init__(self, deferred, sequence, callbacks):
self.deferred = deferred
self.sequence = sequence
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
class Cache(object):
@ -58,13 +70,16 @@ class Cache(object):
"sequence",
"thread",
"metrics",
"_pending_deferred_cache",
)
def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type,
size_callback=deferred_size if iterable else None,
size_callback=(lambda d: len(d.result)) if iterable else None,
)
self.name = name
@ -84,7 +99,15 @@ class Cache(object):
)
def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel, callback=callback)
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
if val.sequence == self.sequence:
val.callbacks.update(callbacks)
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
@ -96,15 +119,39 @@ class Cache(object):
else:
return default
def update(self, sequence, key, value, callback=None):
def set(self, key, value, callback=None):
callbacks = [callback] if callback else []
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value, callback=callback)
entry = CacheEntry(
deferred=value,
sequence=self.sequence,
callbacks=callbacks,
)
entry.callbacks.update(callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def shuffle(result):
if self.sequence == entry.sequence:
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
self.cache.set(key, entry.deferred, entry.callbacks)
else:
entry.invalidate()
else:
entry.invalidate()
return result
entry.deferred.addCallback(shuffle)
def prefill(self, key, value, callback=None):
self.cache.set(key, value, callback=callback)
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
self.check_thread()
@ -116,6 +163,10 @@ class Cache(object):
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
entry = self._pending_deferred_cache.pop(key, None)
if entry:
entry.invalidate()
self.cache.pop(key, None)
def invalidate_many(self, key):
@ -127,6 +178,12 @@ class Cache(object):
self.sequence += 1
self.cache.del_multi(key)
val = self._pending_deferred_cache.pop(key, None)
if val is not None:
entry_dict, _ = val
for entry in popped_to_iterator(entry_dict):
entry.invalidate()
def invalidate_all(self):
self.check_thread()
self.sequence += 1
@ -254,11 +311,6 @@ class CacheDescriptor(object):
return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = defer.maybeDeferred(
preserve_context_over_fn,
self.function_to_call,
@ -272,7 +324,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret, callback=invalidate_callback)
cache.set(cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
@ -370,7 +422,6 @@ class CacheListDescriptor(object):
missing.append(arg)
if missing:
sequence = cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
@ -393,8 +444,8 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
cache.update(
sequence, tuple(key), observer,
cache.set(
tuple(key), observer,
callback=invalidate_callback
)

View File

@ -23,7 +23,9 @@ import logging
logger = logging.getLogger(__name__)
DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
def __len__(self):
return len(self.value)
class DictionaryCache(object):
@ -32,7 +34,7 @@ class DictionaryCache(object):
"""
def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries)
self.cache = LruCache(max_size=max_entries, size_callback=len)
self.name = name
self.sequence = 0

View File

@ -56,6 +56,8 @@ class ExpiringCache(object):
self.iterable = iterable
self._size_estimate = 0
def start(self):
if not self._expiry_ms:
# Don't bother starting the loop if things never expire
@ -70,9 +72,14 @@ class ExpiringCache(object):
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
if self.iterable:
self._size_estimate += len(value)
# Evict if there are now too many items
while self._max_len and len(self) > self._max_len:
self._cache.popitem(last=False)
_key, value = self._cache.popitem(last=False)
if self.iterable:
self._size_estimate -= len(value.value)
def __getitem__(self, key):
try:
@ -109,7 +116,9 @@ class ExpiringCache(object):
keys_to_delete.add(key)
for k in keys_to_delete:
self._cache.pop(k)
value = self._cache.pop(k)
if self.iterable:
self._size_estimate -= len(value.value)
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
@ -118,7 +127,7 @@ class ExpiringCache(object):
def __len__(self):
if self.iterable:
return sum(len(value.value) for value in self._cache.itervalues())
return self._size_estimate
else:
return len(self._cache)

View File

@ -58,12 +58,6 @@ class LruCache(object):
lock = threading.Lock()
def cache_len():
if size_callback is not None:
return sum(size_callback(node.value) for node in cache.itervalues())
else:
return len(cache)
def evict():
while cache_len() > max_size:
todelete = list_root.prev_node
@ -78,6 +72,16 @@ class LruCache(object):
return inner
cached_cache_len = [0]
if size_callback is not None:
def cache_len():
return cached_cache_len[0]
else:
def cache_len():
return len(cache)
self.len = synchronized(cache_len)
def add_node(key, value, callbacks=set()):
prev_node = list_root
next_node = prev_node.next_node
@ -86,6 +90,9 @@ class LruCache(object):
next_node.prev_node = node
cache[key] = node
if size_callback:
cached_cache_len[0] += size_callback(node.value)
def move_node_to_front(node):
prev_node = node.prev_node
next_node = node.next_node
@ -104,23 +111,25 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
if size_callback:
cached_cache_len[0] -= size_callback(node.value)
for cb in node.callbacks:
cb()
node.callbacks.clear()
@synchronized
def cache_get(key, default=None, callback=None):
def cache_get(key, default=None, callbacks=[]):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
if callback:
node.callbacks.add(callback)
node.callbacks.update(callbacks)
return node.value
else:
return default
@synchronized
def cache_set(key, value, callback=None):
def cache_set(key, value, callbacks=[]):
node = cache.get(key, None)
if node is not None:
if value != node.value:
@ -128,17 +137,16 @@ class LruCache(object):
cb()
node.callbacks.clear()
if callback:
node.callbacks.add(callback)
if size_callback:
cached_cache_len[0] -= size_callback(node.value)
cached_cache_len[0] += size_callback(value)
node.callbacks.update(callbacks)
move_node_to_front(node)
node.value = value
else:
if callback:
callbacks = set([callback])
else:
callbacks = set()
add_node(key, value, callbacks)
add_node(key, value, set(callbacks))
evict()

View File

@ -65,12 +65,24 @@ class TreeCache(object):
return popped
def values(self):
return [e.value for e in self.root.values()]
return list(popped_to_iterator(self.root))
def __len__(self):
return self.size
def popped_to_iterator(d):
if isinstance(d, dict):
for value_d in d.itervalues():
for value in popped_to_iterator(value_d):
yield value
else:
if isinstance(d, _Entry):
yield d.value
else:
yield d
class _Entry(object):
__slots__ = ["value"]

View File

@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0]
class A(object):
@cached(max_entries=2)
@cached(max_entries=20) # HACK: This makes it 2 due to cache factor
def func(self, key):
callcount[0] += 1
return key
@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)

View File

@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
cache.get("key", callbacks=[m])
self.assertFalse(m.called)
cache.get("key", "value")
@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
cache.get("key", callbacks=[m])
self.assertFalse(m.called)
cache.get("key", callback=m)
cache.get("key", callbacks=[m])
self.assertFalse(m.called)
cache.set("key", "value2")
@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
cache.set("key", "value", [m])
self.assertFalse(m.called)
cache.set("key", "value")
@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
cache.set("key", "value", [m])
self.assertFalse(m.called)
cache.pop("key")
@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache)
cache.set(("a", "1"), "value", m1)
cache.set(("a", "2"), "value", m2)
cache.set(("b", "1"), "value", m3)
cache.set(("b", "2"), "value", m4)
cache.set(("a", "1"), "value", [m1])
cache.set(("a", "2"), "value", [m2])
cache.set(("b", "1"), "value", [m3])
cache.set(("b", "2"), "value", [m4])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m2 = Mock()
cache = LruCache(5)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
cache.set("key1", "value", [m1])
cache.set("key2", "value", [m2])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m3 = Mock(name="m3")
cache = LruCache(2)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
cache.set("key1", "value", [m1])
cache.set("key2", "value", [m2])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value", m3)
cache.set("key3", "value", [m3])
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
@ -227,7 +227,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key1", "value", m1)
cache.set("key1", "value", [m1])
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)