diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 760d567ca1..9a96e6fe8f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -53,7 +53,7 @@ class BulkPushRuleEvaluator(object): room_id = event.room_id rules_for_room = self._get_rules_for_room(room_id) - rules_by_user = yield rules_for_room.get_rules(context) + rules_by_user = yield rules_for_room.get_rules(event, context) # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited @@ -200,6 +200,13 @@ class RulesForRoom(object): # not update the cache with it. self.sequence = 0 + # A cache of user_ids that we *know* aren't interesting, e.g. user_ids + # owned by AS's, or remote users, etc. (I.e. users we will never need to + # calculate push for) + # These never need to be invalidated as we will never set up push for + # them. + self.uninteresting_user_set = set() + # We need to be clever on the invalidating caches callbacks, as # otherwise the invalidation callback holds a reference to the object, # potentially causing it to leak. @@ -209,7 +216,7 @@ class RulesForRoom(object): self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) @defer.inlineCallbacks - def get_rules(self, context): + def get_rules(self, event, context): """Given an event context return the rules for all users who are currently in the room. """ @@ -217,6 +224,7 @@ class RulesForRoom(object): with (yield self.linearizer.queue(())): if state_group and self.state_group == state_group: + logger.debug("Using cached rules for %r", self.room_id) defer.returnValue(self.rules_by_user) ret_rules_by_user = {} @@ -229,12 +237,30 @@ class RulesForRoom(object): else: current_state_ids = context.current_state_ids + logger.debug( + "Looking for member changes in %r %r", state_group, current_state_ids + ) + # Loop through to see which member events we've seen and have rules # for and which we need to fetch - for key, event_id in current_state_ids.iteritems(): - if key[0] != EventTypes.Member: + for key in current_state_ids: + typ, user_id = key + if typ != EventTypes.Member: continue + if user_id in self.uninteresting_user_set: + continue + + if not self.is_mine_id(user_id): + self.uninteresting_user_set.add(user_id) + continue + + if self.store.get_if_app_services_interested_in_user(user_id): + self.uninteresting_user_set.add(user_id) + continue + + event_id = current_state_ids[key] + res = self.member_map.get(event_id, None) if res: user_id, state = res @@ -244,13 +270,6 @@ class RulesForRoom(object): ret_rules_by_user[user_id] = rules continue - user_id = key[1] - if not self.is_mine_id(user_id): - continue - - if self.store.get_if_app_services_interested_in_user(user_id): - continue - # If a user has left a room we remove their push rule. If they # joined then we readd it later in _update_rules_with_member_event_ids ret_rules_by_user.pop(user_id, None) @@ -259,15 +278,21 @@ class RulesForRoom(object): if missing_member_event_ids: # If we have some memebr events we haven't seen, look them up # and fetch push rules for them if appropriate. + logger.debug("Found new member events %r", missing_member_event_ids) yield self._update_rules_with_member_event_ids( - ret_rules_by_user, missing_member_event_ids, state_group + ret_rules_by_user, missing_member_event_ids, state_group, event ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Returning push rules for %r %r", + self.room_id, ret_rules_by_user.keys(), + ) defer.returnValue(ret_rules_by_user) @defer.inlineCallbacks def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids, - state_group): + state_group, event): """Update the partially filled rules_by_user dict by fetching rules for any newly joined users in the `member_event_ids` list. @@ -296,11 +321,23 @@ class RulesForRoom(object): for row in rows } + # If the event is a join event then it will be in current state evnts + # map but not in the DB, so we have to explicitly insert it. + if event.type == EventTypes.Member: + for event_id in member_event_ids.itervalues(): + if event_id == event.event_id: + members[event_id] = (event.state_key, event.membership) + + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Found members %r: %r", self.room_id, members.values()) + interested_in_user_ids = set( user_id for user_id, membership in members.itervalues() if membership == Membership.JOIN ) + logger.debug("Joined: %r", interested_in_user_ids) + if_users_with_pushers = yield self.store.get_if_users_have_pushers( interested_in_user_ids, on_invalidate=self.invalidate_all_cb, @@ -310,10 +347,14 @@ class RulesForRoom(object): uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher ) + logger.debug("With pushers: %r", user_ids) + users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( self.room_id, on_invalidate=self.invalidate_all_cb, ) + logger.debug("With receipts: %r", users_with_receipts) + # any users with pushers must be ours: they have pushers for uid in users_with_receipts: if uid in interested_in_user_ids: @@ -334,6 +375,7 @@ class RulesForRoom(object): # as it keeps a reference to self and will stop this instance from being # GC'd if it gets dropped from the rules_to_user cache. Instead use # `self.invalidate_all_cb` + logger.debug("Invalidating RulesForRoom for %r", self.room_id) self.sequence += 1 self.state_group = object() self.member_map = {} diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 48dcbafeef..cbdff86596 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -404,6 +404,7 @@ class CacheDescriptor(_CacheDescriptorBase): wrapped.invalidate_all = cache.invalidate_all wrapped.cache = cache + wrapped.num_args = self.num_args obj.__dict__[self.orig.__name__] = wrapped @@ -451,8 +452,9 @@ class CacheListDescriptor(_CacheDescriptorBase): ) def __get__(self, obj, objtype=None): - - cache = getattr(obj, self.cached_method_name).cache + cached_method = getattr(obj, self.cached_method_name) + cache = cached_method.cache + num_args = cached_method.num_args @functools.wraps(self.orig) def wrapped(*args, **kwargs): @@ -469,12 +471,23 @@ class CacheListDescriptor(_CacheDescriptorBase): results = {} cached_defers = {} missing = [] - for arg in list_args: - key = list(keyargs) - key[self.list_pos] = arg + # If the cache takes a single arg then that is used as the key, + # otherwise a tuple is used. + if num_args == 1: + def cache_get(arg): + return cache.get(arg, callback=invalidate_callback) + else: + key = list(keyargs) + + def cache_get(arg): + key[self.list_pos] = arg + return cache.get(tuple(key), callback=invalidate_callback) + + for arg in list_args: try: - res = cache.get(tuple(key), callback=invalidate_callback) + res = cache_get(arg) + if not isinstance(res, ObservableDeferred): results[arg] = res elif not res.has_succeeded(): @@ -505,17 +518,28 @@ class CacheListDescriptor(_CacheDescriptorBase): observer = ObservableDeferred(observer) - key = list(keyargs) - key[self.list_pos] = arg - cache.set( - tuple(key), observer, - callback=invalidate_callback - ) + if num_args == 1: + cache.set( + arg, observer, + callback=invalidate_callback + ) - def invalidate(f, key): - cache.invalidate(key) - return f - observer.addErrback(invalidate, tuple(key)) + def invalidate(f, key): + cache.invalidate(key) + return f + observer.addErrback(invalidate, arg) + else: + key = list(keyargs) + key[self.list_pos] = arg + cache.set( + tuple(key), observer, + callback=invalidate_callback + ) + + def invalidate(f, key): + cache.invalidate(key) + return f + observer.addErrback(invalidate, tuple(key)) res = observer.observe() res.addCallback(lambda r, arg: (arg, r), arg)