Only store data in caches, not "smart" objects (#9845)
							parent
							
								
									51a20914a8
								
							
						
					
					
						commit
						3853a7edfc
					
				|  | @ -0,0 +1 @@ | |||
| Only store the raw data in the in-memory caches, rather than objects that include references to e.g. the data stores. | ||||
|  | @ -106,6 +106,10 @@ class BulkPushRuleEvaluator: | |||
|         self.store = hs.get_datastore() | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|         # Used by `RulesForRoom` to ensure only one thing mutates the cache at a | ||||
|         # time. Keyed off room_id. | ||||
|         self._rules_linearizer = Linearizer(name="rules_for_room") | ||||
| 
 | ||||
|         self.room_push_rule_cache_metrics = register_cache( | ||||
|             "cache", | ||||
|             "room_push_rule_cache", | ||||
|  | @ -123,7 +127,16 @@ class BulkPushRuleEvaluator: | |||
|             dict of user_id -> push_rules | ||||
|         """ | ||||
|         room_id = event.room_id | ||||
|         rules_for_room = self._get_rules_for_room(room_id) | ||||
| 
 | ||||
|         rules_for_room_data = self._get_rules_for_room(room_id) | ||||
|         rules_for_room = RulesForRoom( | ||||
|             hs=self.hs, | ||||
|             room_id=room_id, | ||||
|             rules_for_room_cache=self._get_rules_for_room.cache, | ||||
|             room_push_rule_cache_metrics=self.room_push_rule_cache_metrics, | ||||
|             linearizer=self._rules_linearizer, | ||||
|             cached_data=rules_for_room_data, | ||||
|         ) | ||||
| 
 | ||||
|         rules_by_user = await rules_for_room.get_rules(event, context) | ||||
| 
 | ||||
|  | @ -142,17 +155,12 @@ class BulkPushRuleEvaluator: | |||
|         return rules_by_user | ||||
| 
 | ||||
|     @lru_cache() | ||||
|     def _get_rules_for_room(self, room_id: str) -> "RulesForRoom": | ||||
|         """Get the current RulesForRoom object for the given room id""" | ||||
|         # It's important that RulesForRoom gets added to self._get_rules_for_room.cache | ||||
|     def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData": | ||||
|         """Get the current RulesForRoomData object for the given room id""" | ||||
|         # It's important that the RulesForRoomData object gets added to self._get_rules_for_room.cache | ||||
|         # before any lookup methods get called on it as otherwise there may be | ||||
|         # a race if invalidate_all gets called (which assumes it's in the cache) | ||||
|         return RulesForRoom( | ||||
|             self.hs, | ||||
|             room_id, | ||||
|             self._get_rules_for_room.cache, | ||||
|             self.room_push_rule_cache_metrics, | ||||
|         ) | ||||
|         return RulesForRoomData() | ||||
| 
 | ||||
|     async def _get_power_levels_and_sender_level( | ||||
|         self, event: EventBase, context: EventContext | ||||
|  | @ -282,11 +290,49 @@ def _condition_checker( | |||
|     return True | ||||
| 
 | ||||
| 
 | ||||
| @attr.s(slots=True) | ||||
| class RulesForRoomData: | ||||
|     """The data stored in the cache by `RulesForRoom`. | ||||
| 
 | ||||
|     We don't store `RulesForRoom` directly in the cache as we want our caches to | ||||
|     *only* include data, and not references to e.g. the data stores. | ||||
|     """ | ||||
| 
 | ||||
|     # event_id -> (user_id, state) | ||||
|     member_map = attr.ib(type=Dict[str, Tuple[str, str]], factory=dict) | ||||
|     # user_id -> rules | ||||
|     rules_by_user = attr.ib(type=Dict[str, List[Dict[str, dict]]], factory=dict) | ||||
| 
 | ||||
|     # The last state group we updated the caches for. If the state_group of | ||||
|     # a new event comes along, we know that we can just return the cached | ||||
|     # result. | ||||
|     # On invalidation of the rules themselves (if the user changes them), | ||||
|     # we invalidate everything and set state_group to `object()` | ||||
|     state_group = attr.ib(type=Union[object, int], factory=object) | ||||
| 
 | ||||
|     # A sequence number to keep track of when we're allowed to update the | ||||
|     # cache. We bump the sequence number when we invalidate the cache. If | ||||
|     # the sequence number changes while we're calculating stuff we should | ||||
|     # not update the cache with it. | ||||
|     sequence = attr.ib(type=int, default=0) | ||||
| 
 | ||||
|     # A cache of user_ids that we *know* aren't interesting, e.g. user_ids | ||||
|     # owned by AS's, or remote users, etc. (I.e. users we will never need to | ||||
|     # calculate push for) | ||||
|     # These never need to be invalidated as we will never set up push for | ||||
|     # them. | ||||
|     uninteresting_user_set = attr.ib(type=Set[str], factory=set) | ||||
| 
 | ||||
| 
 | ||||
| class RulesForRoom: | ||||
|     """Caches push rules for users in a room. | ||||
| 
 | ||||
|     This efficiently handles users joining/leaving the room by not invalidating | ||||
|     the entire cache for the room. | ||||
| 
 | ||||
|     A new instance is constructed for each call to | ||||
|     `BulkPushRuleEvaluator._get_rules_for_event`, with the cached data from | ||||
|     previous calls passed in. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|  | @ -295,6 +341,8 @@ class RulesForRoom: | |||
|         room_id: str, | ||||
|         rules_for_room_cache: LruCache, | ||||
|         room_push_rule_cache_metrics: CacheMetric, | ||||
|         linearizer: Linearizer, | ||||
|         cached_data: RulesForRoomData, | ||||
|     ): | ||||
|         """ | ||||
|         Args: | ||||
|  | @ -303,38 +351,21 @@ class RulesForRoom: | |||
|             rules_for_room_cache: The cache object that caches these | ||||
|                 `RulesForRoomData` objects. | ||||
|             room_push_rule_cache_metrics: The metrics object | ||||
|             linearizer: The linearizer used to ensure only one thing mutates | ||||
|                 the cache at a time. Keyed off room_id | ||||
|             cached_data: Cached data from previous calls to `self.get_rules`, | ||||
|                 can be mutated. | ||||
|         """ | ||||
|         self.room_id = room_id | ||||
|         self.is_mine_id = hs.is_mine_id | ||||
|         self.store = hs.get_datastore() | ||||
|         self.room_push_rule_cache_metrics = room_push_rule_cache_metrics | ||||
| 
 | ||||
|         self.linearizer = Linearizer(name="rules_for_room") | ||||
|         # Used to ensure only one thing mutates the cache at a time. Keyed off | ||||
|         # room_id. | ||||
|         self.linearizer = linearizer | ||||
| 
 | ||||
|         # event_id -> (user_id, state) | ||||
|         self.member_map = {}  # type: Dict[str, Tuple[str, str]] | ||||
|         # user_id -> rules | ||||
|         self.rules_by_user = {}  # type: Dict[str, List[Dict[str, dict]]] | ||||
| 
 | ||||
|         # The last state group we updated the caches for. If the state_group of | ||||
|         # a new event comes along, we know that we can just return the cached | ||||
|         # result. | ||||
|         # On invalidation of the rules themselves (if the user changes them), | ||||
|         # we invalidate everything and set state_group to `object()` | ||||
|         self.state_group = object() | ||||
| 
 | ||||
|         # A sequence number to keep track of when we're allowed to update the | ||||
|         # cache. We bump the sequence number when we invalidate the cache. If | ||||
|         # the sequence number changes while we're calculating stuff we should | ||||
|         # not update the cache with it. | ||||
|         self.sequence = 0 | ||||
| 
 | ||||
|         # A cache of user_ids that we *know* aren't interesting, e.g. user_ids | ||||
|         # owned by AS's, or remote users, etc. (I.e. users we will never need to | ||||
|         # calculate push for) | ||||
|         # These never need to be invalidated as we will never set up push for | ||||
|         # them. | ||||
|         self.uninteresting_user_set = set()  # type: Set[str] | ||||
|         self.data = cached_data | ||||
| 
 | ||||
|         # We need to be clever on the invalidating caches callbacks, as | ||||
|         # otherwise the invalidation callback holds a reference to the object, | ||||
|  | @ -352,25 +383,25 @@ class RulesForRoom: | |||
|         """ | ||||
|         state_group = context.state_group | ||||
| 
 | ||||
|         if state_group and self.state_group == state_group: | ||||
|         if state_group and self.data.state_group == state_group: | ||||
|             logger.debug("Using cached rules for %r", self.room_id) | ||||
|             self.room_push_rule_cache_metrics.inc_hits() | ||||
|             return self.rules_by_user | ||||
|             return self.data.rules_by_user | ||||
| 
 | ||||
|         with (await self.linearizer.queue(())): | ||||
|             if state_group and self.state_group == state_group: | ||||
|         with (await self.linearizer.queue(self.room_id)): | ||||
|             if state_group and self.data.state_group == state_group: | ||||
|                 logger.debug("Using cached rules for %r", self.room_id) | ||||
|                 self.room_push_rule_cache_metrics.inc_hits() | ||||
|                 return self.rules_by_user | ||||
|                 return self.data.rules_by_user | ||||
| 
 | ||||
|             self.room_push_rule_cache_metrics.inc_misses() | ||||
| 
 | ||||
|             ret_rules_by_user = {} | ||||
|             missing_member_event_ids = {} | ||||
|             if state_group and self.state_group == context.prev_group: | ||||
|             if state_group and self.data.state_group == context.prev_group: | ||||
|                 # If we have a simple delta then we can reuse most of the previous | ||||
|                 # results. | ||||
|                 ret_rules_by_user = self.rules_by_user | ||||
|                 ret_rules_by_user = self.data.rules_by_user | ||||
|                 current_state_ids = context.delta_ids | ||||
| 
 | ||||
|                 push_rules_delta_state_cache_metric.inc_hits() | ||||
|  | @ -393,24 +424,24 @@ class RulesForRoom: | |||
|                 if typ != EventTypes.Member: | ||||
|                     continue | ||||
| 
 | ||||
|                 if user_id in self.uninteresting_user_set: | ||||
|                 if user_id in self.data.uninteresting_user_set: | ||||
|                     continue | ||||
| 
 | ||||
|                 if not self.is_mine_id(user_id): | ||||
|                     self.uninteresting_user_set.add(user_id) | ||||
|                     self.data.uninteresting_user_set.add(user_id) | ||||
|                     continue | ||||
| 
 | ||||
|                 if self.store.get_if_app_services_interested_in_user(user_id): | ||||
|                     self.uninteresting_user_set.add(user_id) | ||||
|                     self.data.uninteresting_user_set.add(user_id) | ||||
|                     continue | ||||
| 
 | ||||
|                 event_id = current_state_ids[key] | ||||
| 
 | ||||
|                 res = self.member_map.get(event_id, None) | ||||
|                 res = self.data.member_map.get(event_id, None) | ||||
|                 if res: | ||||
|                     user_id, state = res | ||||
|                     if state == Membership.JOIN: | ||||
|                         rules = self.rules_by_user.get(user_id, None) | ||||
|                         rules = self.data.rules_by_user.get(user_id, None) | ||||
|                         if rules: | ||||
|                             ret_rules_by_user[user_id] = rules | ||||
|                     continue | ||||
|  | @ -430,7 +461,7 @@ class RulesForRoom: | |||
|             else: | ||||
|                 # The push rules didn't change but let's update the cache anyway | ||||
|                 self.update_cache( | ||||
|                     self.sequence, | ||||
|                     self.data.sequence, | ||||
|                     members={},  # There were no membership changes | ||||
|                     rules_by_user=ret_rules_by_user, | ||||
|                     state_group=state_group, | ||||
|  | @ -461,7 +492,7 @@ class RulesForRoom: | |||
|                 for. Used when updating the cache. | ||||
|             event: The event we are currently computing push rules for. | ||||
|         """ | ||||
|         sequence = self.sequence | ||||
|         sequence = self.data.sequence | ||||
| 
 | ||||
|         rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) | ||||
| 
 | ||||
|  | @ -501,23 +532,11 @@ class RulesForRoom: | |||
| 
 | ||||
|         self.update_cache(sequence, members, ret_rules_by_user, state_group) | ||||
| 
 | ||||
|     def invalidate_all(self) -> None: | ||||
|         # Note: Don't hand this function directly to an invalidation callback | ||||
|         # as it keeps a reference to self and will stop this instance from being | ||||
|         # GC'd if it gets dropped from the rules_to_user cache. Instead use | ||||
|         # `self.invalidate_all_cb` | ||||
|         logger.debug("Invalidating RulesForRoom for %r", self.room_id) | ||||
|         self.sequence += 1 | ||||
|         self.state_group = object() | ||||
|         self.member_map = {} | ||||
|         self.rules_by_user = {} | ||||
|         push_rules_invalidation_counter.inc() | ||||
| 
 | ||||
|     def update_cache(self, sequence, members, rules_by_user, state_group) -> None: | ||||
|         if sequence == self.sequence: | ||||
|             self.member_map.update(members) | ||||
|             self.rules_by_user = rules_by_user | ||||
|             self.state_group = state_group | ||||
|         if sequence == self.data.sequence: | ||||
|             self.data.member_map.update(members) | ||||
|             self.data.rules_by_user = rules_by_user | ||||
|             self.data.state_group = state_group | ||||
| 
 | ||||
| 
 | ||||
| @attr.attrs(slots=True, frozen=True) | ||||
|  | @ -535,6 +554,10 @@ class _Invalidation: | |||
|     room_id = attr.ib(type=str) | ||||
| 
 | ||||
|     def __call__(self) -> None: | ||||
|         rules = self.cache.get(self.room_id, None, update_metrics=False) | ||||
|         if rules: | ||||
|             rules.invalidate_all() | ||||
|         rules_data = self.cache.get(self.room_id, None, update_metrics=False) | ||||
|         if rules_data: | ||||
|             rules_data.sequence += 1 | ||||
|             rules_data.state_group = object() | ||||
|             rules_data.member_map = {} | ||||
|             rules_data.rules_by_user = {} | ||||
|             push_rules_invalidation_counter.inc() | ||||
|  |  | |||
|  | @ -23,8 +23,11 @@ from typing import ( | |||
|     Optional, | ||||
|     Set, | ||||
|     Tuple, | ||||
|     Union, | ||||
| ) | ||||
| 
 | ||||
| import attr | ||||
| 
 | ||||
| from synapse.api.constants import EventTypes, Membership | ||||
| from synapse.events import EventBase | ||||
| from synapse.events.snapshot import EventContext | ||||
|  | @ -43,7 +46,7 @@ from synapse.storage.roommember import ( | |||
|     ProfileInfo, | ||||
|     RoomsForUser, | ||||
| ) | ||||
| from synapse.types import PersistedEventPosition, get_domain_from_id | ||||
| from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id | ||||
| from synapse.util.async_helpers import Linearizer | ||||
| from synapse.util.caches import intern_string | ||||
| from synapse.util.caches.descriptors import _CacheContext, cached, cachedList | ||||
|  | @ -63,6 +66,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
|     def __init__(self, database: DatabasePool, db_conn, hs): | ||||
|         super().__init__(database, db_conn, hs) | ||||
| 
 | ||||
|         # Used by `_get_joined_hosts` to ensure only one thing mutates the cache | ||||
|         # at a time. Keyed by room_id. | ||||
|         self._joined_host_linearizer = Linearizer("_JoinedHostsCache") | ||||
| 
 | ||||
|         # Is the current_state_events.membership up to date? Or is the | ||||
|         # background update still running? | ||||
|         self._current_state_events_membership_up_to_date = False | ||||
|  | @ -740,19 +747,82 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
| 
 | ||||
|     @cached(num_args=2, max_entries=10000, iterable=True) | ||||
|     async def _get_joined_hosts( | ||||
|         self, room_id, state_group, current_state_ids, state_entry | ||||
|     ): | ||||
|         # We don't use `state_group`, its there so that we can cache based | ||||
|         # on it. However, its important that its never None, since two current_state's | ||||
|         # with a state_group of None are likely to be different. | ||||
|         self, | ||||
|         room_id: str, | ||||
|         state_group: int, | ||||
|         current_state_ids: StateMap[str], | ||||
|         state_entry: "_StateCacheEntry", | ||||
|     ) -> FrozenSet[str]: | ||||
|         # We don't use `state_group`, it's there so that we can cache based on | ||||
|         # it. However, it's important that it's never None, since two | ||||
|         # current_state's with a state_group of None are likely to be different. | ||||
|         # | ||||
|         # The `state_group` must match the `state_entry.state_group` (if not None). | ||||
|         assert state_group is not None | ||||
|         assert state_entry.state_group is None or state_entry.state_group == state_group | ||||
| 
 | ||||
|         # We use a secondary cache of previous work to allow us to build up the | ||||
|         # joined hosts for the given state group based on previous state groups. | ||||
|         # | ||||
|         # We cache one object per room containing the results of the last state | ||||
|         # group we got joined hosts for. The idea is that generally | ||||
|         # `get_joined_hosts` is called with the "current" state group for the | ||||
|         # room, and so consecutive calls will be for consecutive state groups | ||||
|         # which point to the previous state group. | ||||
|         cache = await self._get_joined_hosts_cache(room_id) | ||||
|         return await cache.get_destinations(state_entry) | ||||
| 
 | ||||
|         # If the state group in the cache matches, we already have the data we need. | ||||
|         if state_entry.state_group == cache.state_group: | ||||
|             return frozenset(cache.hosts_to_joined_users) | ||||
| 
 | ||||
|         # Since we'll mutate the cache we need to lock. | ||||
|         with (await self._joined_host_linearizer.queue(room_id)): | ||||
|             if state_entry.state_group == cache.state_group: | ||||
|                 # Same state group, so nothing to do. We've already checked for | ||||
|                 # this above, but the cache may have changed while waiting on | ||||
|                 # the lock. | ||||
|                 pass | ||||
|             elif state_entry.prev_group == cache.state_group: | ||||
|                 # The cached work is for the previous state group, so we work out | ||||
|                 # the delta. | ||||
|                 for (typ, state_key), event_id in state_entry.delta_ids.items(): | ||||
|                     if typ != EventTypes.Member: | ||||
|                         continue | ||||
| 
 | ||||
|                     host = intern_string(get_domain_from_id(state_key)) | ||||
|                     user_id = state_key | ||||
|                     known_joins = cache.hosts_to_joined_users.setdefault(host, set()) | ||||
| 
 | ||||
|                     event = await self.get_event(event_id) | ||||
|                     if event.membership == Membership.JOIN: | ||||
|                         known_joins.add(user_id) | ||||
|                     else: | ||||
|                         known_joins.discard(user_id) | ||||
| 
 | ||||
|                         if not known_joins: | ||||
|                             cache.hosts_to_joined_users.pop(host, None) | ||||
|             else: | ||||
|                 # The cache doesn't match the state group or prev state group, | ||||
|                 # so we calculate the result from first principles. | ||||
|                 joined_users = await self.get_joined_users_from_state( | ||||
|                     room_id, state_entry | ||||
|                 ) | ||||
| 
 | ||||
|                 cache.hosts_to_joined_users = {} | ||||
|                 for user_id in joined_users: | ||||
|                     host = intern_string(get_domain_from_id(user_id)) | ||||
|                     cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) | ||||
| 
 | ||||
|             if state_entry.state_group: | ||||
|                 cache.state_group = state_entry.state_group | ||||
|             else: | ||||
|                 cache.state_group = object() | ||||
| 
 | ||||
|         return frozenset(cache.hosts_to_joined_users) | ||||
| 
 | ||||
|     @cached(max_entries=10000) | ||||
|     def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": | ||||
|         return _JoinedHostsCache(self, room_id) | ||||
|         return _JoinedHostsCache() | ||||
| 
 | ||||
|     @cached(num_args=2) | ||||
|     async def did_forget(self, user_id: str, room_id: str) -> bool: | ||||
|  | @ -1062,71 +1132,18 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): | |||
|         await self.db_pool.runInteraction("forget_membership", f) | ||||
| 
 | ||||
| 
 | ||||
| @attr.s(slots=True) | ||||
| class _JoinedHostsCache: | ||||
|     """Cache for joined hosts in a room that is optimised to handle updates | ||||
|     via state deltas. | ||||
|     """ | ||||
|     """The cached data used by the `_get_joined_hosts_cache`.""" | ||||
| 
 | ||||
|     def __init__(self, store, room_id): | ||||
|         self.store = store | ||||
|         self.room_id = room_id | ||||
|     # Dict of host to the set of their users in the room at the state group. | ||||
|     hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict) | ||||
| 
 | ||||
|         self.hosts_to_joined_users = {} | ||||
| 
 | ||||
|         self.state_group = object() | ||||
| 
 | ||||
|         self.linearizer = Linearizer("_JoinedHostsCache") | ||||
| 
 | ||||
|         self._len = 0 | ||||
| 
 | ||||
|     async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]: | ||||
|         """Get set of destinations for a state entry | ||||
| 
 | ||||
|         Args: | ||||
|             state_entry | ||||
| 
 | ||||
|         Returns: | ||||
|             The destinations as a set. | ||||
|         """ | ||||
|         if state_entry.state_group == self.state_group: | ||||
|             return frozenset(self.hosts_to_joined_users) | ||||
| 
 | ||||
|         with (await self.linearizer.queue(())): | ||||
|             if state_entry.state_group == self.state_group: | ||||
|                 pass | ||||
|             elif state_entry.prev_group == self.state_group: | ||||
|                 for (typ, state_key), event_id in state_entry.delta_ids.items(): | ||||
|                     if typ != EventTypes.Member: | ||||
|                         continue | ||||
| 
 | ||||
|                     host = intern_string(get_domain_from_id(state_key)) | ||||
|                     user_id = state_key | ||||
|                     known_joins = self.hosts_to_joined_users.setdefault(host, set()) | ||||
| 
 | ||||
|                     event = await self.store.get_event(event_id) | ||||
|                     if event.membership == Membership.JOIN: | ||||
|                         known_joins.add(user_id) | ||||
|                     else: | ||||
|                         known_joins.discard(user_id) | ||||
| 
 | ||||
|                         if not known_joins: | ||||
|                             self.hosts_to_joined_users.pop(host, None) | ||||
|             else: | ||||
|                 joined_users = await self.store.get_joined_users_from_state( | ||||
|                     self.room_id, state_entry | ||||
|                 ) | ||||
| 
 | ||||
|                 self.hosts_to_joined_users = {} | ||||
|                 for user_id in joined_users: | ||||
|                     host = intern_string(get_domain_from_id(user_id)) | ||||
|                     self.hosts_to_joined_users.setdefault(host, set()).add(user_id) | ||||
| 
 | ||||
|             if state_entry.state_group: | ||||
|                 self.state_group = state_entry.state_group | ||||
|             else: | ||||
|                 self.state_group = object() | ||||
|             self._len = sum(len(v) for v in self.hosts_to_joined_users.values()) | ||||
|         return frozenset(self.hosts_to_joined_users) | ||||
|     # The state group `hosts_to_joined_users` is derived from. Will be an object | ||||
|     # if the instance is newly created or if the state is not based on a state | ||||
|     # group. (An object is used as a sentinel value to ensure that it never is | ||||
|     # equal to anything else). | ||||
|     state_group = attr.ib(type=Union[object, int], factory=object) | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         return self._len | ||||
|         return sum(len(v) for v in self.hosts_to_joined_users.values()) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Erik Johnston
						Erik Johnston