Implement and use an @lru_cache decorator (#8595)

We don't always need the full power of a DeferredCache.
erikj/release_script
Richard van der Hoff 2020-10-30 11:43:17 +00:00 committed by GitHub
parent fd7c743445
commit cbc82aa09f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 272 additions and 61 deletions

1
changelog.d/8595.misc Normal file
View File

@ -0,0 +1 @@
Implement and use an @lru_cache decorator.

View File

@ -15,8 +15,8 @@
# limitations under the License. # limitations under the License.
import logging import logging
from collections import namedtuple
import attr
from prometheus_client import Counter from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.api.constants import EventTypes, Membership, RelationTypes
@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
dict of user_id -> push_rules dict of user_id -> push_rules
""" """
room_id = event.room_id room_id = event.room_id
rules_for_room = await self._get_rules_for_room(room_id) rules_for_room = self._get_rules_for_room(room_id)
rules_by_user = await rules_for_room.get_rules(event, context) rules_by_user = await rules_for_room.get_rules(event, context)
@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
return rules_by_user return rules_by_user
@cached() @lru_cache()
def _get_rules_for_room(self, room_id): def _get_rules_for_room(self, room_id):
"""Get the current RulesForRoom object for the given room id """Get the current RulesForRoom object for the given room id
@ -275,12 +276,14 @@ class RulesForRoom:
the entire cache for the room. the entire cache for the room.
""" """
def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics): def __init__(
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
):
""" """
Args: Args:
hs (HomeServer) hs (HomeServer)
room_id (str) room_id (str)
rules_for_room_cache(Cache): The cache object that caches these rules_for_room_cache: The cache object that caches these
RoomsForUser objects. RoomsForUser objects.
room_push_rule_cache_metrics (CacheMetric) room_push_rule_cache_metrics (CacheMetric)
""" """
@ -489,13 +492,21 @@ class RulesForRoom:
self.state_group = state_group self.state_group = state_group
class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))): @attr.attrs(slots=True, frozen=True)
# We rely on _CacheContext implementing __eq__ and __hash__ sensibly, class _Invalidation:
# which namedtuple does for us (i.e. two _CacheContext are the same if # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
# their caches and keys match). This is important in particular to # which means that it it is stored on the bulk_get_push_rules cache entry. In order
# dedupe when we add callbacks to lru cache nodes, otherwise the number # to ensure that we don't accumulate lots of redunant callbacks on the cache entry,
# of callbacks would grow. # we need to ensure that two _Invalidation objects are "equal" if they refer to the
# same `cache` and `room_id`.
#
# attrs provides suitable __hash__ and __eq__ methods, provided we remember to
# set `frozen=True`.
cache = attr.ib(type=LruCache)
room_id = attr.ib(type=str)
def __call__(self): def __call__(self):
rules = self.cache.get_immediate(self.room_id, None, update_metrics=False) rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules: if rules:
rules.invalidate_all() rules.invalidate_all()

View File

@ -13,10 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import enum
import functools import functools
import inspect import inspect
import logging import logging
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from typing import (
Any,
Callable,
Generic,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from twisted.internet import defer from twisted.internet import defer
@ -24,6 +37,7 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
class _CacheDescriptorBase: class _CacheDescriptorBase:
def __init__(self, orig: _CachedFunction, num_args, cache_context=False): def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
self.orig = orig self.orig = orig
arg_spec = inspect.getfullargspec(orig) arg_spec = inspect.getfullargspec(orig)
@ -97,8 +111,107 @@ class _CacheDescriptorBase:
self.add_cache_context = cache_context self.add_cache_context = cache_context
self.cache_key_builder = get_cache_key_builder(
self.arg_names, self.arg_defaults
)
class CacheDescriptor(_CacheDescriptorBase):
class _LruCachedFunction(Generic[F]):
cache = None # type: LruCache[CacheKey, Any]
__call__ = None # type: F
def lru_cache(
max_entries: int = 1000, cache_context: bool = False,
) -> Callable[[F], _LruCachedFunction[F]]:
"""A method decorator that applies a memoizing cache around the function.
This is more-or-less a drop-in equivalent to functools.lru_cache, although note
that the signature is slightly different.
The main differences with functools.lru_cache are:
(a) the size of the cache can be controlled via the cache_factor mechanism
(b) the wrapped function can request a "cache_context" which provides a
callback mechanism to indicate that the result is no longer valid
(c) prometheus metrics are exposed automatically.
The function should take zero or more arguments, which are used as the key for the
cache. Single-argument functions use that argument as the cache key; otherwise the
arguments are built into a tuple.
Cached functions can be "chained" (i.e. a cached function can call other cached
functions and get appropriately invalidated when they called caches are
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example:
@lru_cache(cache_context=True)
def foo(self, key, cache_context):
r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
return r1 + r2
The wrapped function also has a 'cache' property which offers direct access to the
underlying LruCache.
"""
def func(orig: F) -> _LruCachedFunction[F]:
desc = LruCacheDescriptor(
orig, max_entries=max_entries, cache_context=cache_context,
)
return cast(_LruCachedFunction[F], desc)
return func
class LruCacheDescriptor(_CacheDescriptorBase):
"""Helper for @lru_cache"""
class _Sentinel(enum.Enum):
sentinel = object()
def __init__(
self, orig, max_entries: int = 1000, cache_context: bool = False,
):
super().__init__(orig, num_args=None, cache_context=cache_context)
self.max_entries = max_entries
def __get__(self, obj, owner):
cache = LruCache(
cache_name=self.orig.__name__, max_size=self.max_entries,
) # type: LruCache[CacheKey, Any]
get_cache_key = self.cache_key_builder
sentinel = LruCacheDescriptor._Sentinel.sentinel
@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
invalidate_callback = kwargs.pop("on_invalidate", None)
callbacks = (invalidate_callback,) if invalidate_callback else ()
cache_key = get_cache_key(args, kwargs)
ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
if ret != sentinel:
return ret
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
ret2 = self.orig(obj, *args, **kwargs)
cache.set(cache_key, ret2, callbacks=callbacks)
return ret2
wrapped = cast(_CachedFunction, _wrapped)
wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
class DeferredCacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that This caches deferreds, rather than the results themselves. Deferreds that
@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_context=False, cache_context=False,
iterable=False, iterable=False,
): ):
super().__init__(orig, num_args=num_args, cache_context=cache_context) super().__init__(orig, num_args=num_args, cache_context=cache_context)
self.max_entries = max_entries self.max_entries = max_entries
@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable, iterable=self.iterable,
) # type: DeferredCache[CacheKey, Any] ) # type: DeferredCache[CacheKey, Any]
def get_cache_key_gen(args, kwargs): get_cache_key = self.cache_key_builder
"""Given some args/kwargs return a generator that resolves into
the cache_key.
We loop through each arg name, looking up if its in the `kwargs`,
otherwise using the next argument in `args`. If there are no more
args then we try looking the arg name up in the defaults
"""
pos = 0
for nm in self.arg_names:
if nm in kwargs:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
pos += 1
else:
yield self.arg_defaults[nm]
# By default our cache key is a tuple, but if there is only one item
# then don't bother wrapping in a tuple. This is to save memory.
if self.num_args == 1:
nm = self.arg_names[0]
def get_cache_key(args, kwargs):
if nm in kwargs:
return kwargs[nm]
elif len(args):
return args[0]
else:
return self.arg_defaults[nm]
else:
def get_cache_key(args, kwargs):
return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig) @functools.wraps(self.orig)
def _wrapped(*args, **kwargs): def _wrapped(*args, **kwargs):
@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
wrapped.prefill = lambda key, val: cache.prefill(key[0], val) wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
else: else:
wrapped.invalidate = cache.invalidate wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill wrapped.prefill = cache.prefill
@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return wrapped return wrapped
class CacheListDescriptor(_CacheDescriptorBase): class DeferredCacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys. """Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes Given a list of keys it looks in the cache to find any hits, then passes
@ -382,11 +459,13 @@ class _CacheContext:
on a lower level. on a lower level.
""" """
Cache = Union[DeferredCache, LruCache]
_cache_context_objects = ( _cache_context_objects = (
WeakValueDictionary() WeakValueDictionary()
) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext] ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
self._cache = cache self._cache = cache
self._cache_key = cache_key self._cache_key = cache_key
@ -396,8 +475,8 @@ class _CacheContext:
@classmethod @classmethod
def get_instance( def get_instance(
cls, cache, cache_key cls, cache: "_CacheContext.Cache", cache_key: CacheKey
): # type: (DeferredCache, CacheKey) -> _CacheContext ) -> "_CacheContext":
"""Returns an instance constructed with the given arguments. """Returns an instance constructed with the given arguments.
A new instance is only created if none already exists. A new instance is only created if none already exists.
@ -418,7 +497,7 @@ def cached(
cache_context: bool = False, cache_context: bool = False,
iterable: bool = False, iterable: bool = False,
) -> Callable[[F], _CachedFunction[F]]: ) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: CacheDescriptor( func = lambda orig: DeferredCacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
num_args=num_args, num_args=num_args,
@ -460,7 +539,7 @@ def cachedList(
def batch_do_something(self, first_arg, second_args): def batch_do_something(self, first_arg, second_args):
... ...
""" """
func = lambda orig: CacheListDescriptor( func = lambda orig: DeferredCacheListDescriptor(
orig, orig,
cached_method_name=cached_method_name, cached_method_name=cached_method_name,
list_name=list_name, list_name=list_name,
@ -468,3 +547,65 @@ def cachedList(
) )
return cast(Callable[[F], _CachedFunction[F]], func) return cast(Callable[[F], _CachedFunction[F]], func)
def get_cache_key_builder(
param_names: Sequence[str], param_defaults: Mapping[str, Any]
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
"""Construct a function which will build cache keys suitable for a cached function
Args:
param_names: list of formal parameter names for the cached function
param_defaults: a mapping from parameter name to default value for that param
Returns:
A function which will take an (args, kwargs) pair and return a cache key
"""
# By default our cache key is a tuple, but if there is only one item
# then don't bother wrapping in a tuple. This is to save memory.
if len(param_names) == 1:
nm = param_names[0]
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
if nm in kwargs:
return kwargs[nm]
elif len(args):
return args[0]
else:
return param_defaults[nm]
else:
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
return get_cache_key
def _get_cache_key_gen(
param_names: Iterable[str],
param_defaults: Mapping[str, Any],
args: Sequence[Any],
kwargs: Mapping[str, Any],
) -> Iterable[Any]:
"""Given some args/kwargs return a generator that resolves into
the cache_key.
This is essentially the same operation as `inspect.getcallargs`, but optimised so
that we don't need to inspect the target function for each call.
"""
# We loop through each arg name, looking up if its in the `kwargs`,
# otherwise using the next argument in `args`. If there are no more
# args then we try looking the arg name up in the defaults.
pos = 0
for nm in param_names:
if nm in kwargs:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
pos += 1
else:
yield param_defaults[nm]

View File

@ -29,13 +29,46 @@ from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
) )
from synapse.util.caches import descriptors from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached, lru_cache
from tests import unittest from tests import unittest
from tests.test_utils import get_awaitable_result
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LruCacheDecoratorTestCase(unittest.TestCase):
def test_base(self):
class Cls:
def __init__(self):
self.mock = mock.Mock()
@lru_cache()
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)
obj = Cls()
obj.mock.return_value = "fish"
r = obj.fn(1, 2)
self.assertEqual(r, "fish")
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = "chips"
r = obj.fn(1, 3)
self.assertEqual(r, "chips")
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
# the two values should now be cached
r = obj.fn(1, 2)
self.assertEqual(r, "fish")
r = obj.fn(1, 3)
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
def run_on_reactor(): def run_on_reactor():
d = defer.Deferred() d = defer.Deferred()
reactor.callLater(0, d.callback, 0) reactor.callLater(0, d.callback, 0)
@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1) d = obj.fn(1)
self.failureResultOf(d, SynapseError) self.failureResultOf(d, SynapseError)
def test_invalidate_cascade(self):
"""Invalidations should cascade up through cache contexts"""
class Cls:
@cached(cache_context=True)
async def func1(self, key, cache_context):
return await self.func2(key, on_invalidate=cache_context.invalidate)
@cached(cache_context=True)
async def func2(self, key, cache_context):
return self.func3(key, on_invalidate=cache_context.invalidate)
@lru_cache(cache_context=True)
def func3(self, key, cache_context):
self.invalidate = cache_context.invalidate
return 42
obj = Cls()
top_invalidate = mock.Mock()
r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
self.assertEqual(r, 42)
obj.invalidate()
top_invalidate.assert_called_once()
class CacheDecoratorTestCase(unittest.HomeserverTestCase): class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached """More tests for @cached