Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

pull/2587/head 2017-05-23_1048
Erik Johnston 2017-05-22 16:56:24 +01:00
commit 59159ed81d
2 changed files with 95 additions and 29 deletions

View File

@ -53,7 +53,7 @@ class BulkPushRuleEvaluator(object):
room_id = event.room_id
rules_for_room = self._get_rules_for_room(room_id)
rules_by_user = yield rules_for_room.get_rules(context)
rules_by_user = yield rules_for_room.get_rules(event, context)
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
@ -200,6 +200,13 @@ class RulesForRoom(object):
# not update the cache with it.
self.sequence = 0
# A cache of user_ids that we *know* aren't interesting, e.g. user_ids
# owned by AS's, or remote users, etc. (I.e. users we will never need to
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
self.uninteresting_user_set = set()
# We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object,
# potentially causing it to leak.
@ -209,7 +216,7 @@ class RulesForRoom(object):
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
@defer.inlineCallbacks
def get_rules(self, context):
def get_rules(self, event, context):
"""Given an event context return the rules for all users who are
currently in the room.
"""
@ -217,6 +224,7 @@ class RulesForRoom(object):
with (yield self.linearizer.queue(())):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
defer.returnValue(self.rules_by_user)
ret_rules_by_user = {}
@ -229,12 +237,30 @@ class RulesForRoom(object):
else:
current_state_ids = context.current_state_ids
logger.debug(
"Looking for member changes in %r %r", state_group, current_state_ids
)
# Loop through to see which member events we've seen and have rules
# for and which we need to fetch
for key, event_id in current_state_ids.iteritems():
if key[0] != EventTypes.Member:
for key in current_state_ids:
typ, user_id = key
if typ != EventTypes.Member:
continue
if user_id in self.uninteresting_user_set:
continue
if not self.is_mine_id(user_id):
self.uninteresting_user_set.add(user_id)
continue
if self.store.get_if_app_services_interested_in_user(user_id):
self.uninteresting_user_set.add(user_id)
continue
event_id = current_state_ids[key]
res = self.member_map.get(event_id, None)
if res:
user_id, state = res
@ -244,13 +270,6 @@ class RulesForRoom(object):
ret_rules_by_user[user_id] = rules
continue
user_id = key[1]
if not self.is_mine_id(user_id):
continue
if self.store.get_if_app_services_interested_in_user(user_id):
continue
# If a user has left a room we remove their push rule. If they
# joined then we readd it later in _update_rules_with_member_event_ids
ret_rules_by_user.pop(user_id, None)
@ -259,15 +278,21 @@ class RulesForRoom(object):
if missing_member_event_ids:
# If we have some memebr events we haven't seen, look them up
# and fetch push rules for them if appropriate.
logger.debug("Found new member events %r", missing_member_event_ids)
yield self._update_rules_with_member_event_ids(
ret_rules_by_user, missing_member_event_ids, state_group
ret_rules_by_user, missing_member_event_ids, state_group, event
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Returning push rules for %r %r",
self.room_id, ret_rules_by_user.keys(),
)
defer.returnValue(ret_rules_by_user)
@defer.inlineCallbacks
def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
state_group):
state_group, event):
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
@ -296,11 +321,23 @@ class RulesForRoom(object):
for row in rows
}
# If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
for event_id in member_event_ids.itervalues():
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
interested_in_user_ids = set(
user_id for user_id, membership in members.itervalues()
if membership == Membership.JOIN
)
logger.debug("Joined: %r", interested_in_user_ids)
if_users_with_pushers = yield self.store.get_if_users_have_pushers(
interested_in_user_ids,
on_invalidate=self.invalidate_all_cb,
@ -310,10 +347,14 @@ class RulesForRoom(object):
uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
)
logger.debug("With pushers: %r", user_ids)
users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb,
)
logger.debug("With receipts: %r", users_with_receipts)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in interested_in_user_ids:
@ -334,6 +375,7 @@ class RulesForRoom(object):
# as it keeps a reference to self and will stop this instance from being
# GC'd if it gets dropped from the rules_to_user cache. Instead use
# `self.invalidate_all_cb`
logger.debug("Invalidating RulesForRoom for %r", self.room_id)
self.sequence += 1
self.state_group = object()
self.member_map = {}

View File

@ -404,6 +404,7 @@ class CacheDescriptor(_CacheDescriptorBase):
wrapped.invalidate_all = cache.invalidate_all
wrapped.cache = cache
wrapped.num_args = self.num_args
obj.__dict__[self.orig.__name__] = wrapped
@ -451,8 +452,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
)
def __get__(self, obj, objtype=None):
cache = getattr(obj, self.cached_method_name).cache
cached_method = getattr(obj, self.cached_method_name)
cache = cached_method.cache
num_args = cached_method.num_args
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
@ -469,12 +471,23 @@ class CacheListDescriptor(_CacheDescriptorBase):
results = {}
cached_defers = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
def cache_get(arg):
return cache.get(arg, callback=invalidate_callback)
else:
key = list(keyargs)
def cache_get(arg):
key[self.list_pos] = arg
return cache.get(tuple(key), callback=invalidate_callback)
for arg in list_args:
try:
res = cache.get(tuple(key), callback=invalidate_callback)
res = cache_get(arg)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
@ -505,17 +518,28 @@ class CacheListDescriptor(_CacheDescriptorBase):
observer = ObservableDeferred(observer)
key = list(keyargs)
key[self.list_pos] = arg
cache.set(
tuple(key), observer,
callback=invalidate_callback
)
if num_args == 1:
cache.set(
arg, observer,
callback=invalidate_callback
)
def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, arg)
else:
key = list(keyargs)
key[self.list_pos] = arg
cache.set(
tuple(key), observer,
callback=invalidate_callback
)
def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)