Rename to on_invalidate

pull/1030/head
Erik Johnston 2016-08-19 15:13:58 +01:00
parent dc76a3e909
commit c0d7d9d642
3 changed files with 15 additions and 21 deletions

View File

@ -156,14 +156,14 @@ class PushRuleStore(SQLBaseStore):
# users in the room who have pushers need to get push rules run because # users in the room who have pushers need to get push rules run because
# that's how their pushers work # that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers( if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, cache_context=cache_context, local_users_in_room, on_invalidate=cache_context.invalidate,
) )
user_ids = set( user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
) )
users_with_receipts = yield self.get_users_with_read_receipts_in_room( users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, cache_context=cache_context, room_id, on_invalidate=cache_context.invalidate,
) )
# any users with pushers must be ours: they have pushers # any users with pushers must be ours: they have pushers
@ -172,7 +172,7 @@ class PushRuleStore(SQLBaseStore):
user_ids.add(uid) user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules( rules_by_user = yield self.bulk_get_push_rules(
user_ids, cache_context=cache_context user_ids, on_invalidate=cache_context.invalidate,
) )
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}

View File

@ -148,8 +148,8 @@ class CacheDescriptor(object):
@cachedInlineCallbacks(cache_context=True) @cachedInlineCallbacks(cache_context=True)
def foo(self, key, cache_context): def foo(self, key, cache_context):
r1 = yield self.bar1(key, cache_context=cache_context) r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, cache_context=cache_context) r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2) defer.returnValue(r1 + r2)
""" """
@ -208,11 +208,7 @@ class CacheDescriptor(object):
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate() # If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated # whenever we are invalidated
cache_context = kwargs.pop("cache_context", None) invalidate_callback = kwargs.pop("on_invalidate", None)
if cache_context:
context_callback = cache_context.invalidate
else:
context_callback = None
# Add our own `cache_context` to argument list if the wrapped function # Add our own `cache_context` to argument list if the wrapped function
# has asked for one # has asked for one
@ -226,7 +222,7 @@ class CacheDescriptor(object):
self_context.key = cache_key self_context.key = cache_key
try: try:
cached_result_d = cache.get(cache_key, callback=context_callback) cached_result_d = cache.get(cache_key, callback=invalidate_callback)
observer = cached_result_d.observe() observer = cached_result_d.observe()
if DEBUG_CACHES: if DEBUG_CACHES:
@ -263,7 +259,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True) ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret, callback=context_callback) cache.update(sequence, cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe()) return preserve_context_over_deferred(ret.observe())
@ -332,11 +328,9 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
cache_context = kwargs.pop("cache_context", None) # If we're passed a cache_context then we'll want to call its invalidate()
if cache_context: # whenever we are invalidated
context_callback = cache_context.invalidate invalidate_callback = kwargs.pop("on_invalidate", None)
else:
context_callback = None
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
@ -352,7 +346,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg key[self.list_pos] = arg
try: try:
res = cache.get(tuple(key), callback=context_callback) res = cache.get(tuple(key), callback=invalidate_callback)
if not res.has_succeeded(): if not res.has_succeeded():
res = res.observe() res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg) res.addCallback(lambda r, arg: (arg, r), arg)
@ -388,7 +382,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg key[self.list_pos] = arg
cache.update( cache.update(
sequence, tuple(key), observer, sequence, tuple(key), observer,
callback=context_callback callback=invalidate_callback
) )
def invalidate(f, key): def invalidate(f, key):

View File

@ -214,7 +214,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
@cached(cache_context=True) @cached(cache_context=True)
def func2(self, key, cache_context): def func2(self, key, cache_context):
callcount2[0] += 1 callcount2[0] += 1
return self.func(key, cache_context=cache_context) return self.func(key, on_invalidate=cache_context.invalidate)
a = A() a = A()
yield a.func2("foo") yield a.func2("foo")
@ -247,7 +247,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
@cached(cache_context=True) @cached(cache_context=True)
def func2(self, key, cache_context): def func2(self, key, cache_context):
callcount2[0] += 1 callcount2[0] += 1
return self.func(key, cache_context=cache_context) return self.func(key, on_invalidate=cache_context.invalidate)
a = A() a = A()
yield a.func2("foo") yield a.func2("foo")