Opt out of cache expiry for `get_users_who_share_room_with_user` (#10826)

* Allow LruCaches to opt out of time-based expiry
* Don't expire `get_users_who_share_room` & friends
David Robertson 2021-09-22 14:21:58 +01:00 committed by GitHub
parent 80828eda06
commit 724aef9a87
5 changed files with 30 additions and 6 deletions

changelog.d/10826.misc (new file)

@@ -0,0 +1,2 @@
+Opt out of cache expiry for `get_users_who_share_room_with_user`, to hopefully improve `/sync` performance when you
+haven't synced recently.
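The whole change is driven by a single keyword argument, `prune_unread_entries` (default `True`), threaded from the `@cached` decorator down to the individual LRU cache nodes. A minimal usage sketch of the opt-out, with a hypothetical store and method name that are not part of this commit:

```python
from typing import List

from synapse.util.caches.descriptors import cached


class ExampleStore:
    # Hypothetical cached method. prune_unread_entries=False keeps its entries
    # off the global "expire if not read recently" list, so they are only
    # evicted by the per-cache LRU size limit or by explicit invalidation.
    @cached(max_entries=10000, iterable=True, prune_unread_entries=False)
    async def get_example_values(self, key: str) -> List[str]:
        ...
```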

synapse/storage/databases/main/roommember.py

@@ -162,7 +162,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             self._check_safe_current_state_events_membership_updated_txn,
         )
 
-    @cached(max_entries=100000, iterable=True)
+    @cached(max_entries=100000, iterable=True, prune_unread_entries=False)
     async def get_users_in_room(self, room_id: str) -> List[str]:
         return await self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
@@ -439,7 +439,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return results_dict.get("membership"), results_dict.get("event_id")
 
-    @cached(max_entries=500000, iterable=True)
+    @cached(max_entries=500000, iterable=True, prune_unread_entries=False)
     async def get_rooms_for_user_with_stream_ordering(
         self, user_id: str
     ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
@@ -544,7 +544,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
         return frozenset(r.room_id for r in rooms)
 
-    @cached(max_entries=500000, cache_context=True, iterable=True)
+    @cached(
+        max_entries=500000,
+        cache_context=True,
+        iterable=True,
+        prune_unread_entries=False,
+    )
     async def get_users_who_share_room_with_user(
         self, user_id: str, cache_context: _CacheContext
     ) -> Set[str]:
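`get_users_who_share_room_with_user` also takes `cache_context=True`, which injects a `_CacheContext` so nested cache lookups can chain invalidation. A rough sketch of that pattern, assuming Synapse's `on_invalidate` hook and using hypothetical method names:

```python
from typing import FrozenSet, Set

from synapse.util.caches.descriptors import _CacheContext, cached


class ExampleStore:
    # Hypothetical inner cache, also opted out of unread-entry pruning.
    @cached(max_entries=1000, prune_unread_entries=False)
    async def get_room_ids_for_user(self, user_id: str) -> FrozenSet[str]:
        ...

    # Hypothetical outer cache. cache_context=True passes a _CacheContext in,
    # and forwarding cache_context.invalidate as on_invalidate means that
    # invalidating the inner entry also drops this outer entry.
    @cached(
        max_entries=1000,
        cache_context=True,
        iterable=True,
        prune_unread_entries=False,
    )
    async def get_sharing_users(
        self, user_id: str, cache_context: _CacheContext
    ) -> Set[str]:
        room_ids = await self.get_room_ids_for_user(
            user_id, on_invalidate=cache_context.invalidate
        )
        ...
```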

synapse/util/caches/deferred_cache.py

@@ -73,6 +73,7 @@ class DeferredCache(Generic[KT, VT]):
         tree: bool = False,
         iterable: bool = False,
         apply_cache_factor_from_config: bool = True,
+        prune_unread_entries: bool = True,
     ):
         """
         Args:
@@ -105,6 +106,7 @@ class DeferredCache(Generic[KT, VT]):
             size_callback=(lambda d: len(d) or 1) if iterable else None,
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
+            prune_unread_entries=prune_unread_entries,
         )
 
         self.thread: Optional[threading.Thread] = None
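`DeferredCache` just forwards the new flag to its underlying `LruCache`. For code that builds a `DeferredCache` directly rather than via `@cached`, the opt-out would look roughly like this (a sketch; the cache name is made up):

```python
from synapse.util.caches.deferred_cache import DeferredCache

# Hypothetical standalone cache whose entries are never expired for being unread.
example_cache: DeferredCache[str, int] = DeferredCache(
    "example_cache",
    max_entries=1000,
    prune_unread_entries=False,
)
```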

synapse/util/caches/descriptors.py

@@ -258,6 +258,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         tree=False,
         cache_context=False,
         iterable=False,
+        prune_unread_entries: bool = True,
     ):
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
@@ -269,6 +270,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.max_entries = max_entries
         self.tree = tree
         self.iterable = iterable
+        self.prune_unread_entries = prune_unread_entries
 
     def __get__(self, obj, owner):
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
@@ -276,6 +278,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
             max_entries=self.max_entries,
             tree=self.tree,
             iterable=self.iterable,
+            prune_unread_entries=self.prune_unread_entries,
         )
 
         get_cache_key = self.cache_key_builder
@@ -507,6 +510,7 @@ def cached(
     tree: bool = False,
     cache_context: bool = False,
     iterable: bool = False,
+    prune_unread_entries: bool = True,
 ) -> Callable[[F], _CachedFunction[F]]:
     func = lambda orig: DeferredCacheDescriptor(
         orig,
@@ -515,6 +519,7 @@ def cached(
         tree=tree,
         cache_context=cache_context,
         iterable=iterable,
+        prune_unread_entries=prune_unread_entries,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
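Taken together, the descriptor changes are pure plumbing: hit/miss and invalidation behaviour are untouched, and the flag is simply carried down the stack. Roughly:

```python
# @cached(prune_unread_entries=False)                      # decorator kwarg
#   -> DeferredCacheDescriptor(prune_unread_entries=False)
#     -> DeferredCache(..., prune_unread_entries=False)    # built in __get__
#       -> LruCache(..., prune_unread_entries=False)
#         -> _Node(..., prune_unread_entries)              # per cache entry
```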

synapse/util/caches/lrucache.py

@@ -202,10 +202,11 @@ class _Node:
         cache: "weakref.ReferenceType[LruCache]",
         clock: Clock,
         callbacks: Collection[Callable[[], None]] = (),
+        prune_unread_entries: bool = True,
     ):
         self._list_node = ListNode.insert_after(self, root)
-        self._global_list_node = None
-        if USE_GLOBAL_LIST:
+        self._global_list_node: Optional[_TimedListNode] = None
+        if USE_GLOBAL_LIST and prune_unread_entries:
             self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT)
             self._global_list_node.update_last_access(clock)
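This `_Node` hunk is where the opt-out takes effect: a node only joins the global, time-ordered list that the expiry job walks when `prune_unread_entries` is true, so an opted-out entry can never be dropped just for being unread (it still participates in normal per-cache LRU eviction). A simplified standalone sketch of the idea, with hypothetical names standing in for Synapse's `ListNode`/`_TimedListNode` machinery:

```python
import time
from typing import Any, List, Optional


class ExpiryRecord:
    """Stand-in for a timed list node: remembers when the entry was last read."""

    def __init__(self) -> None:
        self.last_access_ts = time.monotonic()


class CacheEntry:
    """Simplified cache node: it always lives in its own cache's LRU ordering,
    but only registers with the global expiry list when prune_unread_entries
    is True, so opted-out entries are never expired for being unread."""

    def __init__(
        self,
        key: Any,
        value: Any,
        global_expiry_list: List[ExpiryRecord],
        prune_unread_entries: bool = True,
    ) -> None:
        self.key = key
        self.value = value
        self.expiry_record: Optional[ExpiryRecord] = None
        if prune_unread_entries:
            self.expiry_record = ExpiryRecord()
            global_expiry_list.append(self.expiry_record)
```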
@@ -314,6 +315,7 @@ class LruCache(Generic[KT, VT]):
         metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
         clock: Optional[Clock] = None,
+        prune_unread_entries: bool = True,
     ):
         """
         Args:
@@ -427,7 +429,15 @@ class LruCache(Generic[KT, VT]):
         self.len = synchronized(cache_len)
 
         def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
-            node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks)
+            node = _Node(
+                list_root,
+                key,
+                value,
+                weak_ref_to_self,
+                real_clock,
+                callbacks,
+                prune_unread_entries,
+            )
             cache[key] = node
 
             if size_callback:
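For completeness, the flag can also be passed when constructing an `LruCache` directly; a minimal sketch assuming only the `max_size` and `prune_unread_entries` arguments shown in the constructor above:

```python
from synapse.util.caches.lrucache import LruCache

# Hypothetical cache whose entries are evicted by size, not by idle time.
cache: LruCache[str, str] = LruCache(
    max_size=100,
    prune_unread_entries=False,
)
cache["key"] = "value"
assert cache.get("key") == "value"
```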