Cap the number of in-flight requests for state from a single group (#11608)

pull/12060/head
reivilibre 2022-02-22 14:24:31 +00:00 committed by GitHub
parent 7bcc28f82f
commit dcb6a37837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 0 deletions

1
changelog.d/11608.misc Normal file
View File

@ -0,0 +1 @@
Deduplicate in-flight requests in `_get_state_for_groups`.

View File

@ -56,6 +56,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100 MAX_STATE_DELTA_HOPS = 100
MAX_INFLIGHT_REQUESTS_PER_GROUP = 5
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -258,6 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Attempts to gather in-flight requests and re-use them to retrieve state Attempts to gather in-flight requests and re-use them to retrieve state
for the given state group, filtered with the given state filter. for the given state group, filtered with the given state filter.
If there are more than MAX_INFLIGHT_REQUESTS_PER_GROUP in-flight requests,
and there *still* isn't enough information to complete the request by solely
reusing others, a full state filter will be requested to ensure that subsequent
requests can reuse this request.
Used as part of _get_state_for_group_using_inflight_cache. Used as part of _get_state_for_group_using_inflight_cache.
Returns: Returns:
@ -288,6 +294,16 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# to cover our StateFilter and give us the state we need. # to cover our StateFilter and give us the state we need.
break break
if (
state_filter_left_over != StateFilter.none()
and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP
):
# There are too many requests for this group.
# To prevent even more from building up, we request the whole
# state filter to guarantee that we can be reused by any subsequent
# requests for this state group.
return (), StateFilter.all()
return reusable_requests, state_filter_left_over return reusable_requests, state_filter_left_over
async def _get_state_for_group_fire_request( async def _get_state_for_group_fire_request(

View File

@ -19,6 +19,7 @@ from twisted.internet.defer import Deferred, ensureDeferred
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.storage.databases.state.store import MAX_INFLIGHT_REQUESTS_PER_GROUP
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import StateMap from synapse.types import StateMap
from synapse.util import Clock from synapse.util import Clock
@ -281,3 +282,71 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase):
self.assertEqual(self.get_success(req1), FAKE_STATE) self.assertEqual(self.get_success(req1), FAKE_STATE)
self.assertEqual(self.get_success(req2), FAKE_STATE) self.assertEqual(self.get_success(req2), FAKE_STATE)
def test_inflight_requests_capped(self) -> None:
"""
Tests that the number of in-flight requests is capped to 5.
- requests several pieces of state separately
(5 to hit the limit, 1 to 'shunt out', another that comes after the
group has been 'shunted out')
- checks to see that the torrent of requests is shunted out by
rewriting one of the filters as the 'all' state filter
- requests after that one do not cause any additional queries
"""
# 5 at the time of writing.
CAP_COUNT = MAX_INFLIGHT_REQUESTS_PER_GROUP
reqs = []
# Request 7 different keys (1 to 7) of the `some.state` type.
for req_id in range(CAP_COUNT + 2):
reqs.append(
ensureDeferred(
self.state_datastore._get_state_for_group_using_inflight_cache(
42,
StateFilter.freeze(
{"some.state": {str(req_id + 1)}}, include_others=False
),
)
)
)
self.pump(by=0.1)
# There should only be 6 calls to the database, not 7.
self.assertEqual(len(self.get_state_group_calls), CAP_COUNT + 1)
# Assert that the first 5 are exact requests for the individual pieces
# wanted
for req_id in range(CAP_COUNT):
groups, sf, d = self.get_state_group_calls[req_id]
self.assertEqual(
sf,
StateFilter.freeze(
{"some.state": {str(req_id + 1)}}, include_others=False
),
)
# The 6th request should be the 'all' state filter
groups, sf, d = self.get_state_group_calls[CAP_COUNT]
self.assertEqual(sf, StateFilter.all())
# Complete the queries and check which requests complete as a result
for req_id in range(CAP_COUNT):
# This request should not have been completed yet
self.assertFalse(reqs[req_id].called)
groups, sf, d = self.get_state_group_calls[req_id]
self._complete_request_fake(groups, sf, d)
# This should have only completed this one request
self.assertTrue(reqs[req_id].called)
# Now complete the final query; the last 2 requests should complete
# as a result
self.assertFalse(reqs[CAP_COUNT].called)
self.assertFalse(reqs[CAP_COUNT + 1].called)
groups, sf, d = self.get_state_group_calls[CAP_COUNT]
self._complete_request_fake(groups, sf, d)
self.assertTrue(reqs[CAP_COUNT].called)
self.assertTrue(reqs[CAP_COUNT + 1].called)