Track and deduplicate in-flight requests to `_get_state_for_groups`. (#10870)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
pull/12039/head
reivilibre 2022-02-18 17:23:31 +00:00 committed by GitHub
parent e6acd3cf4f
commit 284ea2025a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 312 additions and 25 deletions

1
changelog.d/10870.misc Normal file
View File

@ -0,0 +1 @@
Deduplicate in-flight requests in `_get_state_for_groups`.

View File

@ -13,11 +13,23 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
Optional,
Sequence,
Set,
Tuple,
)
import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@ -29,6 +41,12 @@ from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import (
AbstractObservableDeferred,
ObservableDeferred,
yieldable_gather_results,
)
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@ -37,7 +55,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
@ -106,6 +123,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
500000,
)
# Current ongoing get_state_for_groups in-flight requests
# {group ID -> {StateFilter -> ObservableDeferred}}
self._state_group_inflight_requests: Dict[
int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
] = {}
def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
return txn.fetchone()[0] # type: ignore
@ -157,7 +180,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
self, groups: Sequence[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
@ -228,6 +251,150 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
def _get_state_for_group_gather_inflight_requests(
self, group: int, state_filter_left_over: StateFilter
) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
"""
Attempts to gather in-flight requests and re-use them to retrieve state
for the given state group, filtered with the given state filter.
Used as part of _get_state_for_group_using_inflight_cache.
Returns:
Tuple of two values:
A sequence of ObservableDeferreds to observe
A StateFilter representing what else needs to be requested to fulfill the request
"""
inflight_requests = self._state_group_inflight_requests.get(group)
if inflight_requests is None:
# no requests for this group, need to retrieve it all ourselves
return (), state_filter_left_over
# The list of ongoing requests which will help narrow the current request.
reusable_requests = []
for (request_state_filter, request_deferred) in inflight_requests.items():
new_state_filter_left_over = state_filter_left_over.approx_difference(
request_state_filter
)
if new_state_filter_left_over == state_filter_left_over:
# Reusing this request would not gain us anything, so don't bother.
continue
reusable_requests.append(request_deferred)
state_filter_left_over = new_state_filter_left_over
if state_filter_left_over == StateFilter.none():
# we have managed to collect enough of the in-flight requests
# to cover our StateFilter and give us the state we need.
break
return reusable_requests, state_filter_left_over
async def _get_state_for_group_fire_request(
self, group: int, state_filter: StateFilter
) -> StateMap[str]:
"""
Fires off a request to get the state at a state group,
potentially filtering by type and/or state key.
This request will be tracked in the in-flight request cache and automatically
removed when it is finished.
Used as part of _get_state_for_group_using_inflight_cache.
Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
"""
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
async def _the_request() -> StateMap[str]:
group_to_state_dict = await self._get_state_groups_from_groups(
(group,), state_filter=db_state_filter
)
# Now let's update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)
# Remove ourselves from the in-flight cache
group_request_dict = self._state_group_inflight_requests[group]
del group_request_dict[db_state_filter]
if not group_request_dict:
# If there are no more requests in-flight for this group,
# clean up the cache by removing the empty dictionary
del self._state_group_inflight_requests[group]
return group_to_state_dict[group]
# We don't immediately await the result, so must use run_in_background
# But we DO await the result before the current log context (request)
# finishes, so don't need to run it as a background process.
request_deferred = run_in_background(_the_request)
observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)
# Insert the ObservableDeferred into the cache
group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
group_request_dict[db_state_filter] = observable_deferred
return await make_deferred_yieldable(observable_deferred.observe())
async def _get_state_for_group_using_inflight_cache(
self, group: int, state_filter: StateFilter
) -> MutableStateMap[str]:
"""
Gets the state at a state group, potentially filtering by type and/or
state key.
1. Calls _get_state_for_group_gather_inflight_requests to gather any
ongoing requests which might overlap with the current request.
2. Fires a new request, using _get_state_for_group_fire_request,
for any state which cannot be gathered from ongoing requests.
Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
Returns:
state map
"""
# first, figure out whether we can re-use any in-flight requests
# (and if so, what would be left over)
(
reusable_requests,
state_filter_left_over,
) = self._get_state_for_group_gather_inflight_requests(group, state_filter)
if state_filter_left_over != StateFilter.none():
# Fetch remaining state
remaining = await self._get_state_for_group_fire_request(
group, state_filter_left_over
)
assembled_state: MutableStateMap[str] = dict(remaining)
else:
assembled_state = {}
gathered = await make_deferred_yieldable(
defer.gatherResults(
(r.observe() for r in reusable_requests), consumeErrors=True
)
).addErrback(unwrapFirstError)
# assemble our result.
for result_piece in gathered:
assembled_state.update(result_piece)
# Filter out any state that may be more than what we asked for.
return state_filter.filter_state(assembled_state)
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
@ -269,31 +436,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not incomplete_groups:
return state
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
async def get_from_cache(group: int, state_filter: StateFilter) -> None:
state[group] = await self._get_state_for_group_using_inflight_cache(
group, state_filter
)
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
group_to_state_dict = await self._get_state_groups_from_groups(
list(incomplete_groups), state_filter=db_state_filter
await yieldable_gather_results(
get_from_cache,
incomplete_groups,
state_filter,
)
# Now lets update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)
# And finally update the result dict, by filtering out any extra
# stuff we pulled out of the database.
for group, group_state_dict in group_to_state_dict.items():
# We just replace any existing entries, as we will have loaded
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)
return state
def _get_state_for_groups_using_cache(

View File

@ -0,0 +1,133 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from typing import Dict, List, Sequence, Tuple
from unittest.mock import patch
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.storage.state import StateFilter
from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
if typing.TYPE_CHECKING:
from synapse.server import HomeServer
class StateGroupInflightCachingTestCase(HomeserverTestCase):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer"
) -> None:
self.state_storage = homeserver.get_storage().state
self.state_datastore = homeserver.get_datastores().state
# Patch out the `_get_state_groups_from_groups`.
# This is useful because it lets us pretend we have a slow database.
get_state_groups_patch = patch.object(
self.state_datastore,
"_get_state_groups_from_groups",
self._fake_get_state_groups_from_groups,
)
get_state_groups_patch.start()
self.addCleanup(get_state_groups_patch.stop)
self.get_state_group_calls: List[
Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]]
] = []
def _fake_get_state_groups_from_groups(
self, groups: Sequence[int], state_filter: StateFilter
) -> "Deferred[Dict[int, StateMap[str]]]":
d: Deferred[Dict[int, StateMap[str]]] = Deferred()
self.get_state_group_calls.append((tuple(groups), state_filter, d))
return d
def _complete_request_fake(
self,
groups: Tuple[int, ...],
state_filter: StateFilter,
d: "Deferred[Dict[int, StateMap[str]]]",
) -> None:
"""
Assemble a fake database response and complete the database request.
"""
result: Dict[int, StateMap[str]] = {}
for group in groups:
group_result: MutableStateMap[str] = {}
result[group] = group_result
for state_type, state_keys in state_filter.types.items():
if state_keys is None:
group_result[(state_type, "a")] = "xyz"
group_result[(state_type, "b")] = "xyz"
else:
for state_key in state_keys:
group_result[(state_type, state_key)] = "abc"
if state_filter.include_others:
group_result[("other.event.type", "state.key")] = "123"
d.callback(result)
def test_duplicate_requests_deduplicated(self) -> None:
"""
Tests that duplicate requests for state are deduplicated.
This test:
- requests some state (state group 42, 'all' state filter)
- requests it again, before the first request finishes
- checks to see that only one database query was made
- completes the database query
- checks that both requests see the same retrieved state
"""
req1 = ensureDeferred(
self.state_datastore._get_state_for_group_using_inflight_cache(
42, StateFilter.all()
)
)
self.pump(by=0.1)
# This should have gone to the database
self.assertEqual(len(self.get_state_group_calls), 1)
self.assertFalse(req1.called)
req2 = ensureDeferred(
self.state_datastore._get_state_for_group_using_inflight_cache(
42, StateFilter.all()
)
)
self.pump(by=0.1)
# No more calls should have gone to the database
self.assertEqual(len(self.get_state_group_calls), 1)
self.assertFalse(req1.called)
self.assertFalse(req2.called)
groups, sf, d = self.get_state_group_calls[0]
self.assertEqual(groups, (42,))
self.assertEqual(sf, StateFilter.all())
# Now we can complete the request
self._complete_request_fake(groups, sf, d)
self.assertEqual(
self.get_success(req1), {("other.event.type", "state.key"): "123"}
)
self.assertEqual(
self.get_success(req2), {("other.event.type", "state.key"): "123"}
)