Merge pull request #219 from matrix-org/erikj/dictionary_cache

Dictionary and list caches

commit d6bcc68ea7
@@ -23,7 +23,7 @@ from synapse.api.errors import (
     CodeMessageException, HttpResponseException, SynapseError,
 )
 from synapse.util import unwrapFirstError
-from synapse.util.expiringcache import ExpiringCache
+from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.events import FrozenEvent
 import synapse.metrics

@@ -229,15 +229,15 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _filter_events_for_server(self, server_name, room_id, events):
-        states = yield self.store.get_state_for_events(
-            room_id, [e.event_id for e in events],
+        event_to_state = yield self.store.get_state_for_events(
+            room_id, frozenset(e.event_id for e in events),
             types=(
                 (EventTypes.RoomHistoryVisibility, ""),
                 (EventTypes.Member, None),
             )
         )
 
-        events_and_states = zip(events, states)
-
-        def redact_disallowed(event_and_state):
-            event, state = event_and_state
-
+        def redact_disallowed(event, state):
             if not state:
                 return event
 
@@ -271,11 +271,10 @@ class FederationHandler(BaseHandler):
 
             return event
 
-        res = map(redact_disallowed, events_and_states)
-
-        logger.info("_filter_events_for_server %r", res)
-
-        defer.returnValue(res)
+        defer.returnValue([
+            redact_disallowed(e, event_to_state[e.event_id])
+            for e in events
+        ])
 
     @log_function
     @defer.inlineCallbacks
@@ -503,7 +502,7 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())
 
         states = yield defer.gatherResults([
-            self.state_handler.resolve_state_groups([e])
+            self.state_handler.resolve_state_groups(room_id, [e])
             for e in event_ids
         ])
         states = dict(zip(event_ids, [s[1] for s in states]))

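All three event-filtering helpers touched by this changeset (federation, message and sync handlers) move to the same pattern: `get_state_for_events` now returns a mapping from event id to a state dict rather than a list parallel to `events`, so the `zip`/unpack plumbing disappears and each event is checked against exactly its own state. A minimal sketch of the new shape, with `store`, `allowed` and `EventTypes` standing in for the real objects used above:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def filter_events(store, room_id, events, user_id):
        # One bulk lookup; frozenset deduplicates the event ids first.
        event_to_state = yield store.get_state_for_events(
            room_id, frozenset(e.event_id for e in events),
            types=(
                (EventTypes.RoomHistoryVisibility, ""),
                (EventTypes.Member, user_id),
            ),
        )
        # Each event is checked against its own state dict.
        defer.returnValue([
            e for e in events
            if allowed(e, event_to_state[e.event_id])
        ])
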
@@ -137,15 +137,15 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _filter_events_for_client(self, user_id, room_id, events):
-        states = yield self.store.get_state_for_events(
-            room_id, [e.event_id for e in events],
+        event_id_to_state = yield self.store.get_state_for_events(
+            room_id, frozenset(e.event_id for e in events),
             types=(
                 (EventTypes.RoomHistoryVisibility, ""),
                 (EventTypes.Member, user_id),
             )
         )
 
-        events_and_states = zip(events, states)
-
-        def allowed(event_and_state):
-            event, state = event_and_state
-
+        def allowed(event, state):
             if event.type == EventTypes.RoomHistoryVisibility:
                 return True
 
@@ -175,10 +175,10 @@ class MessageHandler(BaseHandler):
 
             return True
 
-        events_and_states = filter(allowed, events_and_states)
         defer.returnValue([
-            ev
-            for ev, _ in events_and_states
+            event
+            for event in events
+            if allowed(event, event_id_to_state[event.event_id])
         ])
 
     @defer.inlineCallbacks
@@ -401,10 +401,14 @@ class MessageHandler(BaseHandler):
             except:
                 logger.exception("Failed to get snapshot")
 
-        yield defer.gatherResults(
-            [handle_room(e) for e in room_list],
-            consumeErrors=True
-        ).addErrback(unwrapFirstError)
+        # Only do N rooms at once
+        n = 5
+        d_list = [handle_room(e) for e in room_list]
+        for i in range(0, len(d_list), n):
+            yield defer.gatherResults(
+                d_list[i:i + n],
+                consumeErrors=True
+            ).addErrback(unwrapFirstError)
 
         ret = {
             "rooms": rooms_ret,

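The snapshot hunk above bounds how many `handle_room` results are awaited at once by slicing the Deferred list into groups of five. A generic version of that batching, sketched with zero-argument callables so that no work starts before its batch is reached (in the hunk itself the Deferreds are all created up front, so the slicing limits only how many are waited on at a time):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def gather_in_batches(factories, n=5):
        # Run the callables at most `n` at a time, collecting results
        # in order; a failure in any batch aborts the rest.
        results = []
        for i in range(0, len(factories), n):
            batch = yield defer.gatherResults(
                [defer.maybeDeferred(f) for f in factories[i:i + n]],
                consumeErrors=True,
            )
            results.extend(batch)
        defer.returnValue(results)
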
@@ -294,15 +294,15 @@ class SyncHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _filter_events_for_client(self, user_id, room_id, events):
-        states = yield self.store.get_state_for_events(
-            room_id, [e.event_id for e in events],
+        event_id_to_state = yield self.store.get_state_for_events(
+            room_id, frozenset(e.event_id for e in events),
             types=(
                 (EventTypes.RoomHistoryVisibility, ""),
                 (EventTypes.Member, user_id),
             )
         )
 
-        events_and_states = zip(events, states)
-
-        def allowed(event_and_state):
-            event, state = event_and_state
-
+        def allowed(event, state):
             if event.type == EventTypes.RoomHistoryVisibility:
                 return True
 
@@ -331,10 +331,11 @@ class SyncHandler(BaseHandler):
                 return membership == Membership.INVITE
 
             return True
-        events_and_states = filter(allowed, events_and_states)
 
         defer.returnValue([
-            ev
-            for ev, _ in events_and_states
+            event
+            for event in events
+            if allowed(event, event_id_to_state[event.event_id])
         ])
 
     @defer.inlineCallbacks

@@ -18,7 +18,7 @@ from twisted.internet import defer
 
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
-from synapse.util.expiringcache import ExpiringCache
+from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.auth import AuthEventTypes
@@ -96,7 +96,7 @@ class StateHandler(object):
             cache.ts = self.clock.time_msec()
             state = cache.state
         else:
-            res = yield self.resolve_state_groups(event_ids)
+            res = yield self.resolve_state_groups(room_id, event_ids)
             state = res[1]
 
         if event_type:
@@ -155,13 +155,13 @@ class StateHandler(object):
 
         if event.is_state():
             ret = yield self.resolve_state_groups(
-                [e for e, _ in event.prev_events],
+                event.room_id, [e for e, _ in event.prev_events],
                 event_type=event.type,
                 state_key=event.state_key,
             )
         else:
             ret = yield self.resolve_state_groups(
-                [e for e, _ in event.prev_events],
+                event.room_id, [e for e, _ in event.prev_events],
             )
 
         group, curr_state, prev_state = ret
@@ -180,7 +180,7 @@ class StateHandler(object):
 
     @defer.inlineCallbacks
     @log_function
-    def resolve_state_groups(self, event_ids, event_type=None, state_key=""):
+    def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
@@ -205,7 +205,7 @@ class StateHandler(object):
                 )
 
         state_groups = yield self.store.get_state_groups(
-            event_ids
+            room_id, event_ids
        )
 
         logger.debug(

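Threading `room_id` through `resolve_state_groups` and on into `get_state_groups` presumably lets the storage layer key its new per-event state-group cache on the pair `(room_id, event_id)`: `_get_state_group_for_event` further down is declared with `num_args=2`, so the room becomes part of the cache key. A sketch of the resulting call chain, with `state_handler` and `store` as the usual handler attributes:

    # Handler level: the room is now explicit...
    res = yield self.state_handler.resolve_state_groups(room_id, event_ids)
    state = res[1]

    # ...and flows down to storage, where per-event lookups become
    # cacheable on (room_id, event_id).
    state_groups = yield self.store.get_state_groups(room_id, event_ids)
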
@@ -15,25 +15,22 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.util.async import ObservableDeferred
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
-from synapse.util.lrucache import LruCache
+from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.caches.descriptors import Cache
 import synapse.metrics
 
 from util.id_generators import IdGenerator, StreamIdGenerator
 
 from twisted.internet import defer
 
-from collections import namedtuple, OrderedDict
+from collections import namedtuple
 
-import functools
-import inspect
 import sys
 import time
 import threading
 
-DEBUG_CACHES = False
-
 logger = logging.getLogger(__name__)
 
@@ -49,208 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
 sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
 sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
 
-caches_by_name = {}
-cache_counter = metrics.register_cache(
-    "cache",
-    lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
-    labels=["name"],
-)
-
-
-_CacheSentinel = object()
-
-
-class Cache(object):
-
-    def __init__(self, name, max_entries=1000, keylen=1, lru=True):
-        if lru:
-            self.cache = LruCache(max_size=max_entries)
-            self.max_entries = None
-        else:
-            self.cache = OrderedDict()
-            self.max_entries = max_entries
-
-        self.name = name
-        self.keylen = keylen
-        self.sequence = 0
-        self.thread = None
-        caches_by_name[name] = self.cache
-
-    def check_thread(self):
-        expected_thread = self.thread
-        if expected_thread is None:
-            self.thread = threading.current_thread()
-        else:
-            if expected_thread is not threading.current_thread():
-                raise ValueError(
-                    "Cache objects can only be accessed from the main thread"
-                )
-
-    def get(self, key, default=_CacheSentinel):
-        val = self.cache.get(key, _CacheSentinel)
-        if val is not _CacheSentinel:
-            cache_counter.inc_hits(self.name)
-            return val
-
-        cache_counter.inc_misses(self.name)
-
-        if default is _CacheSentinel:
-            raise KeyError()
-        else:
-            return default
-
-    def update(self, sequence, key, value):
-        self.check_thread()
-        if self.sequence == sequence:
-            # Only update the cache if the caches sequence number matches the
-            # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(key, value)
-
-    def prefill(self, key, value):
-        if self.max_entries is not None:
-            while len(self.cache) >= self.max_entries:
-                self.cache.popitem(last=False)
-
-        self.cache[key] = value
-
-    def invalidate(self, key):
-        self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError(
-                "The cache key must be a tuple not %r" % (type(key),)
-            )
-
-        # Increment the sequence number so that any SELECT statements that
-        # raced with the INSERT don't update the cache (SYN-369)
-        self.sequence += 1
-        self.cache.pop(key, None)
-
-    def invalidate_all(self):
-        self.check_thread()
-        self.sequence += 1
-        self.cache.clear()
-
-
-class CacheDescriptor(object):
-    """ A method decorator that applies a memoizing cache around the function.
-
-    This caches deferreds, rather than the results themselves. Deferreds that
-    fail are removed from the cache.
-
-    The function is presumed to take zero or more arguments, which are used in
-    a tuple as the key for the cache. Hits are served directly from the cache;
-    misses use the function body to generate the value.
-
-    The wrapped function has an additional member, a callable called
-    "invalidate". This can be used to remove individual entries from the cache.
-
-    The wrapped function has another additional callable, called "prefill",
-    which can be used to insert values into the cache specifically, without
-    calling the calculation function.
-    """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
-                 inlineCallbacks=False):
-        self.orig = orig
-
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
-        self.max_entries = max_entries
-        self.num_args = num_args
-        self.lru = lru
-
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
-
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
-                % (orig.__name__,)
-            )
-
-        self.cache = Cache(
-            name=self.orig.__name__,
-            max_entries=self.max_entries,
-            keylen=self.num_args,
-            lru=self.lru,
-        )
-
-    def __get__(self, obj, objtype=None):
-
-        @functools.wraps(self.orig)
-        def wrapped(*args, **kwargs):
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
-            try:
-                cached_result_d = self.cache.get(cache_key)
-
-                observer = cached_result_d.observe()
-                if DEBUG_CACHES:
-                    @defer.inlineCallbacks
-                    def check_result(cached_result):
-                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
-                        if actual_result != cached_result:
-                            logger.error(
-                                "Stale cache entry %s%r: cached: %r, actual %r",
-                                self.orig.__name__, cache_key,
-                                cached_result, actual_result,
-                            )
-                            raise ValueError("Stale cache entry")
-                        defer.returnValue(cached_result)
-                    observer.addCallback(check_result)
-
-                return observer
-            except KeyError:
-                # Get the sequence number of the cache before reading from the
-                # database so that we can tell if the cache is invalidated
-                # while the SELECT is executing (SYN-369)
-                sequence = self.cache.sequence
-
-                ret = defer.maybeDeferred(
-                    self.function_to_call,
-                    obj, *args, **kwargs
-                )
-
-                def onErr(f):
-                    self.cache.invalidate(cache_key)
-                    return f
-
-                ret.addErrback(onErr)
-
-                ret = ObservableDeferred(ret, consumeErrors=True)
-                self.cache.update(sequence, cache_key, ret)
-
-                return ret.observe()
-
-        wrapped.invalidate = self.cache.invalidate
-        wrapped.invalidate_all = self.cache.invalidate_all
-        wrapped.prefill = self.cache.prefill
-
-        obj.__dict__[self.orig.__name__] = wrapped
-
-        return wrapped
-
-
-def cached(max_entries=1000, num_args=1, lru=True):
-    return lambda orig: CacheDescriptor(
-        orig,
-        max_entries=max_entries,
-        num_args=num_args,
-        lru=lru
-    )
-
-
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
-    return lambda orig: CacheDescriptor(
-        orig,
-        max_entries=max_entries,
-        num_args=num_args,
-        lru=lru,
-        inlineCallbacks=True,
-    )
-
 
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
@@ -372,6 +167,8 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
+        self._state_group_cache = DictionaryCache("*stateGroupCache*", 100000)
+
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0

@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 
 from synapse.api.errors import SynapseError
 
@@ -15,7 +15,8 @@
 
 from twisted.internet import defer
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 from syutil.base64util import encode_base64
 
 import logging
 
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from _base import SQLBaseStore, cachedInlineCallbacks
+from _base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 
 from twisted.internet import defer
 
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 from twisted.internet import defer
 
 import logging
 
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -17,7 +17,8 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError, Codes
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 
 
 class RegistrationStore(SQLBaseStore):
 
@@ -17,7 +17,8 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 import collections
 import logging
 
@@ -17,7 +17,8 @@ from twisted.internet import defer
 
 from collections import namedtuple
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 
 from synapse.api.constants import Membership
 from synapse.types import UserID
 
@@ -13,7 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cachedInlineCallbacks
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import (
+    cached, cachedInlineCallbacks, cachedList
+)
 
 from twisted.internet import defer
@@ -44,59 +47,25 @@ class StateStore(SQLBaseStore):
     """
 
     @defer.inlineCallbacks
-    def get_state_groups(self, event_ids):
+    def get_state_groups(self, room_id, event_ids):
         """ Get the state groups for the given list of event_ids
 
         The return value is a dict mapping group names to lists of events.
         """
         if not event_ids:
             defer.returnValue({})
 
-        def f(txn):
-            groups = set()
-            for event_id in event_ids:
-                group = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="event_to_state_groups",
-                    keyvalues={"event_id": event_id},
-                    retcol="state_group",
-                    allow_none=True,
-                )
-                if group:
-                    groups.add(group)
-
-            res = {}
-            for group in groups:
-                state_ids = self._simple_select_onecol_txn(
-                    txn,
-                    table="state_groups_state",
-                    keyvalues={"state_group": group},
-                    retcol="event_id",
-                )
-
-                res[group] = state_ids
-
-            return res
-
-        states = yield self.runInteraction(
-            "get_state_groups",
-            f,
+        event_to_groups = yield self._get_state_group_for_events(
+            room_id, event_ids,
         )
 
-        state_list = yield defer.gatherResults(
-            [
-                self._fetch_events_for_group(group, vals)
-                for group, vals in states.items()
-            ],
-            consumeErrors=True,
-        )
+        groups = set(event_to_groups.values())
+        group_to_state = yield self._get_state_for_groups(groups)
 
-        defer.returnValue(dict(state_list))
-
-    def _fetch_events_for_group(self, key, events):
-        return self._get_events(
-            events, get_prev_content=False
-        ).addCallback(
-            lambda evs: (key, evs)
-        )
+        defer.returnValue({
+            group: state_map.values()
+            for group, state_map in group_to_state.items()
+        })
 
     def _store_state_groups_txn(self, txn, event, context):
         return self._store_mult_state_groups_txn(txn, [(event, context)])
@@ -204,64 +173,250 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
-    @defer.inlineCallbacks
-    def get_state_for_events(self, room_id, event_ids):
+    def _get_state_groups_from_groups(self, groups_and_types):
+        """Returns dictionary state_group -> state event ids
+
+        Args:
+            groups_and_types (list): list of 2-tuple (`group`, `types`)
+        """
         def f(txn):
-            groups = set()
-            event_to_group = {}
-            for event_id in event_ids:
-                # TODO: Remove this loop.
-                group = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="event_to_state_groups",
-                    keyvalues={"event_id": event_id},
-                    retcol="state_group",
-                    allow_none=True,
-                )
-                if group:
-                    event_to_group[event_id] = group
-                    groups.add(group)
+            results = {}
+            for group, types in groups_and_types:
+                if types is not None:
+                    where_clause = "AND (%s)" % (
+                        " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
+                    )
+                else:
+                    where_clause = ""
 
-            group_to_state_ids = {}
-            for group in groups:
-                state_ids = self._simple_select_onecol_txn(
-                    txn,
-                    table="state_groups_state",
-                    keyvalues={"state_group": group},
-                    retcol="event_id",
-                )
+                sql = (
+                    "SELECT event_id FROM state_groups_state WHERE"
+                    " state_group = ? %s"
+                ) % (where_clause,)
 
-                group_to_state_ids[group] = state_ids
+                args = [group]
+                if types is not None:
+                    args.extend([i for typ in types for i in typ])
 
-            return event_to_group, group_to_state_ids
+                txn.execute(sql, args)
 
-        res = yield self.runInteraction(
-            "annotate_events_with_state_groups",
+                results[group] = [r[0] for r in txn.fetchall()]
+
+            return results
+
+        return self.runInteraction(
+            "_get_state_groups_from_groups",
             f,
         )
 
-        event_to_group, group_to_state_ids = res
+    @defer.inlineCallbacks
+    def get_state_for_events(self, room_id, event_ids, types):
+        """Given a list of event_ids and type tuples, return the state dicts
+        for those events. The state dicts will only have the type/state_keys
+        that are in the `types` list.
 
-        state_list = yield defer.gatherResults(
-            [
-                self._fetch_events_for_group(group, vals)
-                for group, vals in group_to_state_ids.items()
-            ],
-            consumeErrors=True,
+        Args:
+            room_id (str)
+            event_ids (list)
+            types (list): List of (type, state_key) tuples which are used to
+                filter the state fetched. `state_key` may be None, which matches
+                any `state_key`
+
+        Returns:
+            deferred: A dict from event_id to state dict, keyed on the given
+            event_ids. The state dicts are mappings from (type, state_key)
+            -> state_event.
+        """
+        event_to_groups = yield self._get_state_group_for_events(
+            room_id, event_ids,
         )
 
-        state_dict = {
-            group: {
-                (ev.type, ev.state_key): ev
-                for ev in state
-            }
-            for group, state in state_list
+        groups = set(event_to_groups.values())
+        group_to_state = yield self._get_state_for_groups(groups, types)
+
+        event_to_state = {
+            event_id: group_to_state[group]
+            for event_id, group in event_to_groups.items()
         }
 
-        defer.returnValue([
-            state_dict.get(event_to_group.get(event, None), None)
-            for event in event_ids
-        ])
+        defer.returnValue({event: event_to_state[event] for event in event_ids})
 
+    @cached(num_args=2, lru=True, max_entries=100000)
+    def _get_state_group_for_event(self, room_id, event_id):
+        return self._simple_select_one_onecol(
+            table="event_to_state_groups",
+            keyvalues={
+                "event_id": event_id,
+            },
+            retcol="state_group",
+            allow_none=True,
+            desc="_get_state_group_for_event",
+        )
+
+    @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
+                num_args=2)
+    def _get_state_group_for_events(self, room_id, event_ids):
+        """Returns mapping event_id -> state_group
+        """
+        def f(txn):
+            results = {}
+            for event_id in event_ids:
+                results[event_id] = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="event_to_state_groups",
+                    keyvalues={
+                        "event_id": event_id,
+                    },
+                    retcol="state_group",
+                    allow_none=True,
+                )
+
+            return results
+
+        return self.runInteraction("_get_state_group_for_events", f)
+
+    def _get_some_state_from_cache(self, group, types):
+        """Checks if group is in cache. See `_get_state_for_groups`
+
+        Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
+        `missing_types` is the list of types that aren't in the cache for that
+        group. `got_all` is a bool indicating if we successfully retrieved all
+        requested state from the cache, if False we need to query the DB for the
+        missing state.
+
+        Args:
+            group: The state group to lookup
+            types (list): List of 2-tuples of the form (`type`, `state_key`),
+                where a `state_key` of `None` matches all state_keys for the
+                `type`.
+        """
+        is_all, state_dict = self._state_group_cache.get(group)
+
+        type_to_key = {}
+        missing_types = set()
+        for typ, state_key in types:
+            if state_key is None:
+                type_to_key[typ] = None
+                missing_types.add((typ, state_key))
+            else:
+                if type_to_key.get(typ, object()) is not None:
+                    type_to_key.setdefault(typ, set()).add(state_key)
+
+                if (typ, state_key) not in state_dict:
+                    missing_types.add((typ, state_key))
+
+        sentinel = object()
+
+        def include(typ, state_key):
+            valid_state_keys = type_to_key.get(typ, sentinel)
+            if valid_state_keys is sentinel:
+                return False
+            if valid_state_keys is None:
+                return True
+            if state_key in valid_state_keys:
+                return True
+            return False
+
+        got_all = not (missing_types or types is None)
+
+        return {
+            k: v for k, v in state_dict.items()
+            if include(k[0], k[1])
+        }, missing_types, got_all
+
+    def _get_all_state_from_cache(self, group):
+        """Checks if group is in cache. See `_get_state_for_groups`
+
+        Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
+        indicating if we successfully retrieved all requested state from the
+        cache, if False we need to query the DB for the missing state.
+
+        Args:
+            group: The state group to lookup
+        """
+        is_all, state_dict = self._state_group_cache.get(group)
+        return state_dict, is_all
+
+    @defer.inlineCallbacks
+    def _get_state_for_groups(self, groups, types=None):
+        """Given list of groups returns dict of group -> list of state events
+        with matching types. `types` is a list of `(type, state_key)`, where
+        a `state_key` of None matches all state_keys. If `types` is None then
+        all events are returned.
+        """
+        results = {}
+        missing_groups_and_types = []
+        if types is not None:
+            for group in set(groups):
+                state_dict, missing_types, got_all = self._get_some_state_from_cache(
+                    group, types
+                )
+                results[group] = state_dict
+
+                if not got_all:
+                    missing_groups_and_types.append((group, missing_types))
+        else:
+            for group in set(groups):
+                state_dict, got_all = self._get_all_state_from_cache(
+                    group
+                )
+                results[group] = state_dict
+
+                if not got_all:
+                    missing_groups_and_types.append((group, None))
+
+        if not missing_groups_and_types:
+            defer.returnValue({
+                group: {
+                    type_tuple: event
+                    for type_tuple, event in state.items()
+                    if event
+                }
+                for group, state in results.items()
+            })
+
+        # Okay, so we have some missing_types, let's fetch them.
+        cache_seq_num = self._state_group_cache.sequence
+
+        group_state_dict = yield self._get_state_groups_from_groups(
+            missing_groups_and_types
+        )
+
+        state_events = yield self._get_events(
+            [e_id for l in group_state_dict.values() for e_id in l],
+            get_prev_content=False
+        )
+
+        state_events = {e.event_id: e for e in state_events}
+
+        # Now we want to update the cache with all the things we fetched
+        # from the database.
+        for group, state_ids in group_state_dict.items():
+            if types:
+                # We deliberately put key -> None mappings into the cache to
+                # cache absence of the key, on the assumption that if we've
+                # explicitly asked for some types then we will probably ask
+                # for them again.
+                state_dict = {key: None for key in types}
+                state_dict.update(results[group])
+            else:
+                state_dict = results[group]
+
+            for event_id in state_ids:
+                state_event = state_events[event_id]
+                state_dict[(state_event.type, state_event.state_key)] = state_event
+
+            self._state_group_cache.update(
+                cache_seq_num,
+                key=group,
+                value=state_dict,
+                full=(types is None),
+            )
+
+            results[group].update({
+                key: value for key, value in state_dict.items() if value
+            })
+
+        defer.returnValue(results)
+
 
 def _make_group_id(clock):

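A worked example of the partial-hit path in `_get_state_for_groups` (hypothetical values): suppose `types` is `[("m.room.member", "@alice:hs"), ("m.room.history_visibility", "")]` and the state-group cache already holds the member entry for group `g` but has never been asked for the history-visibility key. `_get_some_state_from_cache` then returns the member entry together with `missing_types = {("m.room.history_visibility", "")}` and `got_all=False`, so only the missing type is fetched via the per-type WHERE clause in `_get_state_groups_from_groups`, and the cache is refilled with a `key -> None` marker if the group simply has no such event, making the next request a pure hit. In code, roughly:

    state_dict, missing, got_all = store._get_some_state_from_cache(g, types)
    if not got_all:
        # Only (g, missing) goes to the database, not the full type list.
        fetched = yield store._get_state_groups_from_groups([(g, missing)])
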
@@ -36,6 +36,7 @@ what sort order was used:
 from twisted.internet import defer
 
 from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.api.constants import EventTypes
 from synapse.types import RoomStreamToken
 from synapse.util.logutils import log_function

@@ -299,9 +300,8 @@ class StreamStore(SQLBaseStore):
 
         defer.returnValue((events, token))
 
-    @defer.inlineCallbacks
-    def get_recent_events_for_room(self, room_id, limit, end_token,
-                                   with_feedback=False, from_token=None):
+    @cachedInlineCallbacks(num_args=4)
+    def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
         # TODO (erikj): Handle compressed feedback
 
         end_token = RoomStreamToken.parse_stream_token(end_token)

@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 
 from collections import namedtuple
 
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse.metrics
+
+DEBUG_CACHES = False
+
+metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
+
+caches_by_name = {}
+cache_counter = metrics.register_cache(
+    "cache",
+    lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
+    labels=["name"],
+)
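Every cache constructed by this package registers its backing store in `caches_by_name`, so the single "cache" metric above reports one size row per cache name, while hit/miss counts are fed through `cache_counter.inc_hits`/`inc_misses` from the `Cache.get` path in the next file. For example (assuming the `Cache` class defined below):

    from synapse.util.caches.descriptors import Cache

    c = Cache("example", max_entries=10)
    # The metrics callback now reports {("example",): <current size>}
    # alongside the hit/miss counters incremented by c.get().
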
@@ -0,0 +1,377 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError
+from synapse.util.caches.lrucache import LruCache
+
+from . import caches_by_name, DEBUG_CACHES, cache_counter
+
+from twisted.internet import defer
+
+from collections import OrderedDict
+
+import functools
+import inspect
+import threading
+
+logger = logging.getLogger(__name__)
+
+
+_CacheSentinel = object()
+
+
+class Cache(object):
+
+    def __init__(self, name, max_entries=1000, keylen=1, lru=True):
+        if lru:
+            self.cache = LruCache(max_size=max_entries)
+            self.max_entries = None
+        else:
+            self.cache = OrderedDict()
+            self.max_entries = max_entries
+
+        self.name = name
+        self.keylen = keylen
+        self.sequence = 0
+        self.thread = None
+        caches_by_name[name] = self.cache
+
+    def check_thread(self):
+        expected_thread = self.thread
+        if expected_thread is None:
+            self.thread = threading.current_thread()
+        else:
+            if expected_thread is not threading.current_thread():
+                raise ValueError(
+                    "Cache objects can only be accessed from the main thread"
+                )
+
+    def get(self, key, default=_CacheSentinel):
+        val = self.cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
+            cache_counter.inc_hits(self.name)
+            return val
+
+        cache_counter.inc_misses(self.name)
+
+        if default is _CacheSentinel:
+            raise KeyError()
+        else:
+            return default
+
+    def update(self, sequence, key, value):
+        self.check_thread()
+        if self.sequence == sequence:
+            # Only update the cache if the caches sequence number matches the
+            # number that the cache had before the SELECT was started (SYN-369)
+            self.prefill(key, value)
+
+    def prefill(self, key, value):
+        if self.max_entries is not None:
+            while len(self.cache) >= self.max_entries:
+                self.cache.popitem(last=False)
+
+        self.cache[key] = value
+
+    def invalidate(self, key):
+        self.check_thread()
+        if not isinstance(key, tuple):
+            raise TypeError(
+                "The cache key must be a tuple not %r" % (type(key),)
+            )
+
+        # Increment the sequence number so that any SELECT statements that
+        # raced with the INSERT don't update the cache (SYN-369)
+        self.sequence += 1
+        self.cache.pop(key, None)
+
+    def invalidate_all(self):
+        self.check_thread()
+        self.sequence += 1
+        self.cache.clear()
+
+
+class CacheDescriptor(object):
+    """ A method decorator that applies a memoizing cache around the function.
+
+    This caches deferreds, rather than the results themselves. Deferreds that
+    fail are removed from the cache.
+
+    The function is presumed to take zero or more arguments, which are used in
+    a tuple as the key for the cache. Hits are served directly from the cache;
+    misses use the function body to generate the value.
+
+    The wrapped function has an additional member, a callable called
+    "invalidate". This can be used to remove individual entries from the cache.
+
+    The wrapped function has another additional callable, called "prefill",
+    which can be used to insert values into the cache specifically, without
+    calling the calculation function.
+    """
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+                 inlineCallbacks=False):
+        self.orig = orig
+
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
+        self.max_entries = max_entries
+        self.num_args = num_args
+        self.lru = lru
+
+        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+        if len(self.arg_names) < self.num_args:
+            raise Exception(
+                "Not enough explicit positional arguments to key off of for %r."
+                " (@cached cannot key off of *args or **kwargs)"
+                % (orig.__name__,)
+            )
+
+        self.cache = Cache(
+            name=self.orig.__name__,
+            max_entries=self.max_entries,
+            keylen=self.num_args,
+            lru=self.lru,
+        )
+
+    def __get__(self, obj, objtype=None):
+
+        @functools.wraps(self.orig)
+        def wrapped(*args, **kwargs):
+            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            try:
+                cached_result_d = self.cache.get(cache_key)
+
+                observer = cached_result_d.observe()
+                if DEBUG_CACHES:
+                    @defer.inlineCallbacks
+                    def check_result(cached_result):
+                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
+                        if actual_result != cached_result:
+                            logger.error(
+                                "Stale cache entry %s%r: cached: %r, actual %r",
+                                self.orig.__name__, cache_key,
+                                cached_result, actual_result,
+                            )
+                            raise ValueError("Stale cache entry")
+                        defer.returnValue(cached_result)
+                    observer.addCallback(check_result)
+
+                return observer
+            except KeyError:
+                # Get the sequence number of the cache before reading from the
+                # database so that we can tell if the cache is invalidated
+                # while the SELECT is executing (SYN-369)
+                sequence = self.cache.sequence
+
+                ret = defer.maybeDeferred(
+                    self.function_to_call,
+                    obj, *args, **kwargs
+                )
+
+                def onErr(f):
+                    self.cache.invalidate(cache_key)
+                    return f
+
+                ret.addErrback(onErr)
+
+                ret = ObservableDeferred(ret, consumeErrors=True)
+                self.cache.update(sequence, cache_key, ret)
+
+                return ret.observe()
+
+        wrapped.invalidate = self.cache.invalidate
+        wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.prefill = self.cache.prefill
+
+        obj.__dict__[self.orig.__name__] = wrapped
+
+        return wrapped
+
+
+class CacheListDescriptor(object):
+    """Wraps an existing cache to support bulk fetching of keys.
+
+    Given a list of keys it looks in the cache to find any hits, then passes
+    the list of missing keys to the wrapped function.
+    """
+
+    def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
+        """
+        Args:
+            orig (function)
+            cache (Cache)
+            list_name (str): Name of the argument which is the bulk lookup list
+            num_args (int)
+            inlineCallbacks (bool): Whether orig is a generator that should
+                be wrapped by defer.inlineCallbacks
+        """
+        self.orig = orig
+
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
+        self.num_args = num_args
+        self.list_name = list_name
+
+        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+        self.list_pos = self.arg_names.index(self.list_name)
+
+        self.cache = cache
+
+        self.sentinel = object()
+
+        if len(self.arg_names) < self.num_args:
+            raise Exception(
+                "Not enough explicit positional arguments to key off of for %r."
+                " (@cached cannot key off of *args or **kwargs)"
+                % (orig.__name__,)
+            )
+
+        if self.list_name not in self.arg_names:
+            raise Exception(
+                "Couldn't see arguments %r for %r."
+                % (self.list_name, cache.name,)
+            )
+
+    def __get__(self, obj, objtype=None):
+
+        @functools.wraps(self.orig)
+        def wrapped(*args, **kwargs):
+            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
+            list_args = arg_dict[self.list_name]
+
+            # cached is a dict arg -> deferred, where deferred results in a
+            # 2-tuple (`arg`, `result`)
+            cached = {}
+            missing = []
+            for arg in list_args:
+                key = list(keyargs)
+                key[self.list_pos] = arg
+
+                try:
+                    res = self.cache.get(tuple(key)).observe()
+                    res.addCallback(lambda r, arg: (arg, r), arg)
+                    cached[arg] = res
+                except KeyError:
+                    missing.append(arg)
+
+            if missing:
+                sequence = self.cache.sequence
+                args_to_call = dict(arg_dict)
+                args_to_call[self.list_name] = missing
+
+                ret_d = defer.maybeDeferred(
+                    self.function_to_call,
+                    **args_to_call
+                )
+
+                ret_d = ObservableDeferred(ret_d)
+
+                # We need to create deferreds for each arg in the list so that
+                # we can insert the new deferred into the cache.
+                for arg in missing:
+                    observer = ret_d.observe()
+                    observer.addCallback(lambda r, arg: r[arg], arg)
+
+                    observer = ObservableDeferred(observer)
+
+                    key = list(keyargs)
+                    key[self.list_pos] = arg
+                    self.cache.update(sequence, tuple(key), observer)
+
+                    def invalidate(f, key):
+                        self.cache.invalidate(key)
+                        return f
+                    observer.addErrback(invalidate, tuple(key))
+
+                    res = observer.observe()
+                    res.addCallback(lambda r, arg: (arg, r), arg)
+
+                    cached[arg] = res
+
+            return defer.gatherResults(
+                cached.values(),
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+
+        obj.__dict__[self.orig.__name__] = wrapped
+
+        return wrapped
+
| 
 | ||||
| def cached(max_entries=1000, num_args=1, lru=True): | ||||
|     return lambda orig: CacheDescriptor( | ||||
|         orig, | ||||
|         max_entries=max_entries, | ||||
|         num_args=num_args, | ||||
|         lru=lru | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): | ||||
|     return lambda orig: CacheDescriptor( | ||||
|         orig, | ||||
|         max_entries=max_entries, | ||||
|         num_args=num_args, | ||||
|         lru=lru, | ||||
|         inlineCallbacks=True, | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): | ||||
|     """Creates a descriptor that wraps a function in a `CacheListDescriptor`. | ||||
| 
 | ||||
|     Used to do batch lookups for an already created cache. A single argument | ||||
|     is specified as a list that is iterated through to look up keys in the | ||||
|     original cache. The keys that weren't in the cache are passed as a new | ||||
|     list to the original function, whose result is then stored in the cache. | ||||
| 
 | ||||
|     Args: | ||||
|         cache (Cache): The underlying cache to use. | ||||
|         list_name (str): The name of the argument that is the list to use to | ||||
|             do batch lookups in the cache. | ||||
|         num_args (int): Number of arguments to use as the key in the cache. | ||||
|         inlineCallbacks (bool): Should the function be wrapped in an | ||||
|             `defer.inlineCallbacks`? | ||||
| 
 | ||||
|     Example: | ||||
| 
 | ||||
|         class Example(object): | ||||
|             @cached(num_args=2) | ||||
|             def do_something(self, first_arg, second_arg): | ||||
|                 ... | ||||
| 
 | ||||
|             @cachedList(do_something.cache, list_name="second_args", num_args=2) | ||||
|             def batch_do_something(self, first_arg, second_args): | ||||
|                 ... | ||||
|     """ | ||||
|     return lambda orig: CacheListDescriptor( | ||||
|         orig, | ||||
|         cache=cache, | ||||
|         list_name=list_name, | ||||
|         num_args=num_args, | ||||
|         inlineCallbacks=inlineCallbacks, | ||||
|     ) | ||||
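Stripped of the Twisted plumbing, `CacheListDescriptor` implements a simple pattern: serve each key from the per-key cache where possible, call the wrapped function once for all the misses, and reassemble a single dict. A minimal synchronous sketch of that pattern, with hypothetical names (the real version above additionally wraps everything in `ObservableDeferred`s and uses the cache's sequence number for race protection):

```python
def batch_lookup(cache, keys, fetch_many):
    """Return {key: value}, serving what we can from `cache` and
    calling `fetch_many` exactly once for the remaining keys."""
    results = {}
    missing = []
    for key in keys:
        try:
            results[key] = cache[key]
        except KeyError:
            missing.append(key)

    if missing:
        # like the wrapped function, fetch_many must return a dict
        # keyed by each requested key
        fetched = fetch_many(missing)
        for key in missing:
            cache[key] = fetched[key]  # populate the per-key cache
            results[key] = fetched[key]

    return results


# batch_lookup({}, ["a", "b"], lambda ks: {k: k.upper() for k in ks})
# => {'a': 'A', 'b': 'B'}
```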
|  | @ -0,0 +1,103 @@ synapse/util/caches/dictionary_cache.py | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from synapse.util.caches.lrucache import LruCache | ||||
| from collections import namedtuple | ||||
| from . import caches_by_name, cache_counter | ||||
| import threading | ||||
| import logging | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) | ||||
| 
 | ||||
| 
 | ||||
| class DictionaryCache(object): | ||||
|     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. | ||||
|     fetching a subset of dictionary keys for a particular key. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, name, max_entries=1000): | ||||
|         self.cache = LruCache(max_size=max_entries) | ||||
| 
 | ||||
|         self.name = name | ||||
|         self.sequence = 0 | ||||
|         self.thread = None | ||||
| 
 | ||||
|         class Sentinel(object): | ||||
|             __slots__ = [] | ||||
| 
 | ||||
|         self.sentinel = Sentinel() | ||||
|         caches_by_name[name] = self.cache | ||||
| 
 | ||||
|     def check_thread(self): | ||||
|         expected_thread = self.thread | ||||
|         if expected_thread is None: | ||||
|             self.thread = threading.current_thread() | ||||
|         elif expected_thread is not threading.current_thread(): | ||||
|             raise ValueError( | ||||
|                 "Cache objects can only be accessed from the thread " | ||||
|                 "that first used them" | ||||
|             ) | ||||
| 
 | ||||
|     def get(self, key, dict_keys=None): | ||||
|         entry = self.cache.get(key, self.sentinel) | ||||
|         if entry is not self.sentinel: | ||||
|             cache_counter.inc_hits(self.name) | ||||
| 
 | ||||
|             if dict_keys is None: | ||||
|                 return DictionaryEntry(entry.full, dict(entry.value)) | ||||
|             else: | ||||
|                 return DictionaryEntry(entry.full, { | ||||
|                     k: entry.value[k] | ||||
|                     for k in dict_keys | ||||
|                     if k in entry.value | ||||
|                 }) | ||||
| 
 | ||||
|         cache_counter.inc_misses(self.name) | ||||
|         return DictionaryEntry(False, {}) | ||||
| 
 | ||||
|     def invalidate(self, key): | ||||
|         self.check_thread() | ||||
| 
 | ||||
|         # Increment the sequence number so that any SELECT statements that | ||||
|         # raced with the INSERT don't update the cache (SYN-369) | ||||
|         self.sequence += 1 | ||||
|         self.cache.pop(key, None) | ||||
| 
 | ||||
|     def invalidate_all(self): | ||||
|         self.check_thread() | ||||
|         self.sequence += 1 | ||||
|         self.cache.clear() | ||||
| 
 | ||||
|     def update(self, sequence, key, value, full=False): | ||||
|         self.check_thread() | ||||
|         if self.sequence == sequence: | ||||
|             # Only update the cache if the cache's sequence number matches | ||||
|             # the number it had before the SELECT was started (SYN-369) | ||||
|             if full: | ||||
|                 self._insert(key, value) | ||||
|             else: | ||||
|                 self._update_or_insert(key, value) | ||||
| 
 | ||||
|     def _update_or_insert(self, key, value): | ||||
|         entry = self.cache.setdefault(key, DictionaryEntry(False, {})) | ||||
|         entry.value.update(value) | ||||
| 
 | ||||
|     def _insert(self, key, value): | ||||
|         self.cache[key] = DictionaryEntry(True, value) | ||||
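Callers of `DictionaryCache` are expected to follow the sequence-number protocol that `update()` enforces: snapshot `cache.sequence` before hitting the database, then hand that snapshot back when storing the result. A usage sketch, assuming synapse is importable; `load_from_db` is a hypothetical loader:

```python
from synapse.util.caches.dictionary_cache import DictionaryCache

cache = DictionaryCache("example")

def get_dict(key, load_from_db):
    entry = cache.get(key)
    if entry.full:
        return entry.value

    seq = cache.sequence        # snapshot *before* the read
    value = load_from_db(key)   # hypothetical; returns the full dict
    # if an invalidation raced with the read, the sequence numbers no
    # longer match and this update is silently dropped (SYN-369)
    cache.update(seq, key, value, full=True)
    return value
```

Partial reads follow the same shape: `cache.get(key, ["k1", "k2"])` returns only the requested keys, with `entry.full` telling the caller whether the cached dict is complete or just an accumulation of partial inserts.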
|  | @ -19,7 +19,7 @@ from twisted.internet import defer | |||
| 
 | ||||
| from synapse.util.async import ObservableDeferred | ||||
| 
 | ||||
| from synapse.storage._base import Cache, cached | ||||
| from synapse.util.caches.descriptors import Cache, cached | ||||
| 
 | ||||
| 
 | ||||
| class CacheTestCase(unittest.TestCase): | ||||
|  |  | |||
|  | @ -69,7 +69,7 @@ class StateGroupStore(object): | |||
| 
 | ||||
|         self._next_group = 1 | ||||
| 
 | ||||
|     def get_state_groups(self, event_ids): | ||||
|     def get_state_groups(self, room_id, event_ids): | ||||
|         groups = {} | ||||
|         for event_id in event_ids: | ||||
|             group = self._event_to_state_group.get(event_id) | ||||
|  |  | |||
|  | @ -0,0 +1,101 @@ tests/util/test_dict_cache.py | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| from tests import unittest | ||||
| 
 | ||||
| from synapse.util.caches.dictionary_cache import DictionaryCache | ||||
| 
 | ||||
| 
 | ||||
| class DictCacheTestCase(unittest.TestCase): | ||||
| 
 | ||||
|     def setUp(self): | ||||
|         self.cache = DictionaryCache("foobar") | ||||
| 
 | ||||
|     def test_simple_cache_hit_full(self): | ||||
|         key = "test_simple_cache_hit_full" | ||||
| 
 | ||||
|         v = self.cache.get(key) | ||||
|         self.assertEqual((False, {}), v) | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value = {"test": "test_simple_cache_hit_full"} | ||||
|         self.cache.update(seq, key, test_value, full=True) | ||||
| 
 | ||||
|         c = self.cache.get(key) | ||||
|         self.assertEqual(test_value, c.value) | ||||
| 
 | ||||
|     def test_simple_cache_hit_partial(self): | ||||
|         key = "test_simple_cache_hit_partial" | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value = { | ||||
|             "test": "test_simple_cache_hit_partial" | ||||
|         } | ||||
|         self.cache.update(seq, key, test_value, full=True) | ||||
| 
 | ||||
|         c = self.cache.get(key, ["test"]) | ||||
|         self.assertEqual(test_value, c.value) | ||||
| 
 | ||||
|     def test_simple_cache_miss_partial(self): | ||||
|         key = "test_simple_cache_miss_partial" | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value = { | ||||
|             "test": "test_simple_cache_miss_partial" | ||||
|         } | ||||
|         self.cache.update(seq, key, test_value, full=True) | ||||
| 
 | ||||
|         c = self.cache.get(key, ["test2"]) | ||||
|         self.assertEqual({}, c.value) | ||||
| 
 | ||||
|     def test_simple_cache_hit_miss_partial(self): | ||||
|         key = "test_simple_cache_hit_miss_partial" | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value = { | ||||
|             "test": "test_simple_cache_hit_miss_partial", | ||||
|             "test2": "test_simple_cache_hit_miss_partial2", | ||||
|             "test3": "test_simple_cache_hit_miss_partial3", | ||||
|         } | ||||
|         self.cache.update(seq, key, test_value, full=True) | ||||
| 
 | ||||
|         c = self.cache.get(key, ["test2"]) | ||||
|         self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) | ||||
| 
 | ||||
|     def test_multi_insert(self): | ||||
|         key = "test_simple_cache_hit_miss_partial" | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value_1 = { | ||||
|             "test": "test_simple_cache_hit_miss_partial", | ||||
|         } | ||||
|         self.cache.update(seq, key, test_value_1, full=False) | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value_2 = { | ||||
|             "test2": "test_simple_cache_hit_miss_partial2", | ||||
|         } | ||||
|         self.cache.update(seq, key, test_value_2, full=False) | ||||
| 
 | ||||
|         c = self.cache.get(key) | ||||
|         self.assertEqual( | ||||
|             { | ||||
|                 "test": "test_simple_cache_hit_miss_partial", | ||||
|                 "test2": "test_simple_cache_hit_miss_partial2", | ||||
|             }, | ||||
|             c.value | ||||
|         ) | ||||
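One path the suite leaves untested is the race protection itself; a sketch of such a test (hypothetical, not part of this change) in the same style:

```python
    def test_stale_update_is_ignored(self):
        key = "test_stale_update_is_ignored"

        seq = self.cache.sequence
        self.cache.invalidate(key)  # bumps the sequence number

        # the stale sequence number means this update must be dropped
        self.cache.update(seq, key, {"test": "stale"}, full=True)

        c = self.cache.get(key)
        self.assertEqual((False, {}), c)
```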
|  | @ -16,7 +16,7 @@ | |||
| 
 | ||||
| from .. import unittest | ||||
| 
 | ||||
| from synapse.util.lrucache import LruCache | ||||
| from synapse.util.caches.lrucache import LruCache | ||||
| 
 | ||||
| class LruCacheTestCase(unittest.TestCase): | ||||
| 
 | ||||
|  | @ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase): | |||
|         cache["key"] = 1 | ||||
|         self.assertEquals(cache.pop("key"), 1) | ||||
|         self.assertEquals(cache.pop("key"), None) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||