Disable partial state group caching for wildcard lookups

When _get_state_for_groups is given a wildcard filter, just do a complete
lookup. Hopefully this will give us the best of both worlds by not filling up
the ram if we only need one or two keys, but also making the cache still work
for the federation reader usecase.
pull/3383/head
Richard van der Hoff 2018-06-11 23:13:06 +01:00
parent 240f192523
commit 43e02c409d
3 changed files with 61 additions and 32 deletions

View File

@ -526,10 +526,23 @@ class StateGroupWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
"""Given list of groups returns dict of group -> list of state events
with matching types. `types` is a list of `(type, state_key)`, where
a `state_key` of None matches all state_keys. If `types` is None then
all events are returned.
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
types (None|iterable[(str, None|str)]):
indicates the state type/keys required. If None, the whole
state is fetched and returned.
Otherwise, each entry should be a `(type, state_key)` tuple to
include in the response. A `state_key` of None is a wildcard
meaning that we require all state with that type.
Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]]
a dictionary mapping from state group to state dictionary.
"""
if types:
types = frozenset(types)
@ -538,7 +551,7 @@ class StateGroupWorkerStore(SQLBaseStore):
if types is not None:
for group in set(groups):
state_dict_ids, _, got_all = self._get_some_state_from_cache(
group, types
group, types,
)
results[group] = state_dict_ids
@ -559,22 +572,40 @@ class StateGroupWorkerStore(SQLBaseStore):
# Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence
# the DictionaryCache knows if it has *all* the state, but
# does not know if it has all of the keys of a particular type,
# which makes wildcard lookups expensive unless we have a complete
# cache. Hence, if we are doing a wildcard lookup, populate the
# cache fully so that we can do an efficient lookup next time.
if types and any(k is None for (t, k) in types):
types_to_fetch = None
else:
types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups(
missing_groups, types
missing_groups, types_to_fetch,
)
# Now we want to update the cache with all the things we fetched
# from the database.
for group, group_state_dict in iteritems(group_to_state_dict):
state_dict = results[group]
# update the result, filtering by `types`.
if types:
for k, v in iteritems(group_state_dict):
(typ, _) = k
if k in types or (typ, None) in types:
state_dict[k] = v
else:
state_dict.update(group_state_dict)
# update the cache with all the things we fetched from the
# database.
self._state_group_cache.update(
cache_seq_num,
key=group,
value=state_dict,
full=(types is None),
known_absent=types,
value=group_state_dict,
fetched_keys=types_to_fetch,
)
defer.returnValue(results)
@ -681,7 +712,6 @@ class StateGroupWorkerStore(SQLBaseStore):
self._state_group_cache.sequence,
key=state_group,
value=dict(current_state_ids),
full=True,
)
return state_group

View File

@ -107,29 +107,28 @@ class DictionaryCache(object):
self.sequence += 1
self.cache.clear()
def update(self, sequence, key, value, full=False, known_absent=None):
def update(self, sequence, key, value, fetched_keys=None):
"""Updates the entry in the cache
Args:
sequence
key
value (dict): The value to update the cache with.
full (bool): Whether the given value is the full dict, or just a
partial subset there of. If not full then any existing entries
for the key will be updated.
known_absent (set): Set of keys that we know don't exist in the full
dict.
key (K)
value (dict[X,Y]): The value to update the cache with.
fetched_keys (None|set[X]): All of the dictionary keys which were
fetched from the database.
If None, this is the complete value for key K. Otherwise, it
is used to infer a list of keys which we know don't exist in
the full dict.
"""
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
if known_absent is None:
known_absent = set()
if full:
self._insert(key, value, known_absent)
if fetched_keys is None:
self._insert(key, value, set())
else:
self._update_or_insert(key, value, known_absent)
self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(self, key, value, known_absent):
# We pop and reinsert as we need to tell the cache the size may have

View File

@ -32,7 +32,7 @@ class DictCacheTestCase(unittest.TestCase):
seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"}
self.cache.update(seq, key, test_value, full=True)
self.cache.update(seq, key, test_value)
c = self.cache.get(key)
self.assertEqual(test_value, c.value)
@ -44,7 +44,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = {
"test": "test_simple_cache_hit_partial"
}
self.cache.update(seq, key, test_value, full=True)
self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value)
@ -56,7 +56,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = {
"test": "test_simple_cache_miss_partial"
}
self.cache.update(seq, key, test_value, full=True)
self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value)
@ -70,7 +70,7 @@ class DictCacheTestCase(unittest.TestCase):
"test2": "test_simple_cache_hit_miss_partial2",
"test3": "test_simple_cache_hit_miss_partial3",
}
self.cache.update(seq, key, test_value, full=True)
self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
@ -82,13 +82,13 @@ class DictCacheTestCase(unittest.TestCase):
test_value_1 = {
"test": "test_simple_cache_hit_miss_partial",
}
self.cache.update(seq, key, test_value_1, full=False)
self.cache.update(seq, key, test_value_1, fetched_keys=set("test"))
seq = self.cache.sequence
test_value_2 = {
"test2": "test_simple_cache_hit_miss_partial2",
}
self.cache.update(seq, key, test_value_2, full=False)
self.cache.update(seq, key, test_value_2, fetched_keys=set("test2"))
c = self.cache.get(key)
self.assertEqual(