Move all the caches into their own package, synapse.util.caches

pull/219/head
Erik Johnston 2015-08-11 17:59:32 +01:00
parent 53a817518b
commit 2df8dd9b37
23 changed files with 408 additions and 352 deletions

View File

@ -23,7 +23,7 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError
from synapse.util.expiringcache import ExpiringCache
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
import synapse.metrics

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.expiringcache import ExpiringCache
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes

View File

@ -15,27 +15,22 @@
import logging
from synapse.api.errors import StoreError
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
from synapse.util.dictionary_cache import DictionaryCache
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics
from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
from collections import namedtuple, OrderedDict
from collections import namedtuple
import functools
import inspect
import sys
import time
import threading
DEBUG_CACHES = False
logger = logging.getLogger(__name__)
@ -51,330 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
caches_by_name = {}
cache_counter = metrics.register_cache(
"cache",
lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
labels=["name"],
)
_CacheSentinel = object()
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name
self.keylen = keylen
self.sequence = 0
self.thread = None
caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
return val
cache_counter.inc_misses(self.name)
if default is _CacheSentinel:
raise KeyError()
else:
return default
def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value)
def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[key] = value
def invalidate(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise ValueError("keyargs must be a tuple.")
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
"invalidate". This can be used to remove individual entries from the cache.
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
inlineCallbacks=False):
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
self.cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
cached_result_d = self.cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@defer.inlineCallbacks
def check_result(cached_result):
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, cache_key,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = self.cache.sequence
ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
def onErr(f):
self.cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
return ret.observe()
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
class CacheListDescriptor(object):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped fucntion.
"""
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
"""
Args:
orig (function)
cache (Cache)
list_name (str): Name of the argument which is the bulk lookup list
num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should
be wrapped by defer.inlineCallbacks
"""
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.num_args = num_args
self.list_name = list_name
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
self.sentinel = object()
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
% (self.list_name, cache.name,)
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
cached = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
try:
res = self.cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
except KeyError:
missing.append(arg)
if missing:
sequence = self.cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
self.function_to_call,
**args_to_call
)
ret_d = ObservableDeferred(ret_d)
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r[arg], arg)
observer = ObservableDeferred(observer)
key = list(keyargs)
key[self.list_pos] = arg
self.cache.update(sequence, tuple(key), observer)
def invalidate(f, key):
self.cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
return defer.gatherResults(
cached.values(),
consumeErrors=True,
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
def cached(max_entries=1000, num_args=1, lru=True):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
inlineCallbacks=True,
)
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
return lambda orig: CacheListDescriptor(
orig,
cache=cache,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
)
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from synapse.api.errors import SynapseError

View File

@ -15,7 +15,8 @@
from twisted.internet import defer
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from syutil.base64util import encode_base64
import logging

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from _base import SQLBaseStore, cachedInlineCallbacks
from _base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cachedInlineCallbacks
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer
import logging

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cachedInlineCallbacks
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer

View File

@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
class RegistrationStore(SQLBaseStore):

View File

@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore, cachedInlineCallbacks
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
import collections
import logging

View File

@ -17,7 +17,8 @@ from twisted.internet import defer
from collections import namedtuple
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from synapse.api.constants import Membership
from synapse.types import UserID

View File

@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached, cachedInlineCallbacks, cachedList
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import (
cached, cachedInlineCallbacks, cachedList
)
from twisted.internet import defer

View File

@ -35,7 +35,8 @@ what sort order was used:
from twisted.internet import defer
from ._base import SQLBaseStore, cachedInlineCallbacks
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from collections import namedtuple

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,359 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
import synapse.metrics
from twisted.internet import defer
from collections import OrderedDict
import functools
import inspect
import threading
logger = logging.getLogger(__name__)
DEBUG_CACHES = False
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
caches_by_name = {}
cache_counter = metrics.register_cache(
"cache",
lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
labels=["name"],
)
_CacheSentinel = object()
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name
self.keylen = keylen
self.sequence = 0
self.thread = None
caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
return val
cache_counter.inc_misses(self.name)
if default is _CacheSentinel:
raise KeyError()
else:
return default
def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value)
def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[key] = value
def invalidate(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise ValueError("keyargs must be a tuple.")
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
"invalidate". This can be used to remove individual entries from the cache.
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
inlineCallbacks=False):
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
self.cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
cached_result_d = self.cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@defer.inlineCallbacks
def check_result(cached_result):
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, cache_key,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = self.cache.sequence
ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
def onErr(f):
self.cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
return ret.observe()
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
class CacheListDescriptor(object):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped fucntion.
"""
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
"""
Args:
orig (function)
cache (Cache)
list_name (str): Name of the argument which is the bulk lookup list
num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should
be wrapped by defer.inlineCallbacks
"""
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.num_args = num_args
self.list_name = list_name
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
self.sentinel = object()
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
% (self.list_name, cache.name,)
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
cached = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
try:
res = self.cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
except KeyError:
missing.append(arg)
if missing:
sequence = self.cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
self.function_to_call,
**args_to_call
)
ret_d = ObservableDeferred(ret_d)
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r[arg], arg)
observer = ObservableDeferred(observer)
key = list(keyargs)
key[self.list_pos] = arg
self.cache.update(sequence, tuple(key), observer)
def invalidate(f, key):
self.cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
return defer.gatherResults(
cached.values(),
consumeErrors=True,
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
def cached(max_entries=1000, num_args=1, lru=True):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
inlineCallbacks=True,
)
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
return lambda orig: CacheListDescriptor(
orig,
cache=cache,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.lrucache import LruCache
from synapse.util.caches.lrucache import LruCache
from collections import namedtuple
import threading
import logging

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.util.async import ObservableDeferred
from synapse.storage._base import Cache, cached
from synapse.util.caches.descriptors import Cache, cached
class CacheTestCase(unittest.TestCase):

View File

@ -17,7 +17,7 @@
from twisted.internet import defer
from tests import unittest
from synapse.util.dictionary_cache import DictionaryCache
from synapse.util.caches.dictionary_cache import DictionaryCache
class DictCacheTestCase(unittest.TestCase):

View File

@ -16,7 +16,7 @@
from .. import unittest
from synapse.util.lrucache import LruCache
from synapse.util.caches.lrucache import LruCache
class LruCacheTestCase(unittest.TestCase):
@ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1
self.assertEquals(cache.pop("key"), 1)
self.assertEquals(cache.pop("key"), None)