Improve exception handling for concurrent execution (#12109)

* fix incorrect unwrapFirstError import

this was being imported from the wrong place

* Refactor `concurrently_execute` to use `yieldable_gather_results`

* Improve exception handling in `yieldable_gather_results`

Try to avoid swallowing so many stack traces.

* mark unwrapFirstError deprecated

* changelog
pull/12111/head
Richard van der Hoff 2022-03-01 09:34:30 +00:00 committed by GitHub
parent 952efd0bca
commit 9d11fee8f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 151 additions and 27 deletions

1
changelog.d/12109.misc Normal file
View File

@ -0,0 +1 @@
Improve exception handling for concurrent execution.

View File

@ -55,8 +55,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder, json_encoder, log_failure
from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client

View File

@ -81,7 +81,9 @@ json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
def unwrapFirstError(failure: Failure) -> Failure:
# defer.gatherResults and DeferredLists wrap failures.
# Deprecated: you probably just want to catch defer.FirstError and reraise
# the subFailure's value, which will do a better job of preserving stacktraces.
# (actually, you probably want to use yieldable_gather_results anyway)
failure.trap(defer.FirstError)
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations

View File

@ -29,6 +29,7 @@ from typing import (
Hashable,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
@ -51,7 +52,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
from synapse.util import Clock, unwrapFirstError
from synapse.util import Clock
logger = logging.getLogger(__name__)
@ -193,9 +194,9 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
T = TypeVar("T")
def concurrently_execute(
async def concurrently_execute(
func: Callable[[T], Any], args: Iterable[T], limit: int
) -> defer.Deferred:
) -> None:
"""Executes the function with each argument concurrently while limiting
the number of concurrent executions.
@ -221,20 +222,14 @@ def concurrently_execute(
# We use `itertools.islice` to handle the case where the number of args is
# less than the limit, avoiding needlessly spawning unnecessary background
# tasks.
return make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(_concurrently_execute_inner, value)
for value in itertools.islice(it, limit)
],
consumeErrors=True,
await yieldable_gather_results(
_concurrently_execute_inner, (value for value in itertools.islice(it, limit))
)
).addErrback(unwrapFirstError)
def yieldable_gather_results(
func: Callable, iter: Iterable, *args: Any, **kwargs: Any
) -> defer.Deferred:
async def yieldable_gather_results(
func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
) -> List[T]:
"""Executes the function with each argument concurrently.
Args:
@ -245,15 +240,30 @@ def yieldable_gather_results(
**kwargs: Keyword arguments to be passed to each call to func
Returns
Deferred[list]: Resolved when all functions have been invoked, or errors if
one of the function calls fails.
A list containing the results of the function
"""
return make_deferred_yieldable(
try:
return await make_deferred_yieldable(
defer.gatherResults(
[run_in_background(func, item, *args, **kwargs) for item in iter],
consumeErrors=True,
)
).addErrback(unwrapFirstError)
)
except defer.FirstError as dfe:
# unwrap the error from defer.gatherResults.
# The raised exception's traceback only includes func() etc if
# the 'await' happens before the exception is thrown - ie if the failure
# happens *asynchronously* - otherwise Twisted throws away the traceback as it
# could be large.
#
# We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
# we could throw Twisted into the fires of Mordor.
# suppress exception chaining, because the FirstError doesn't tell us anything
# very interesting.
assert isinstance(dfe.subFailure.value, BaseException)
raise dfe.subFailure.value from None
T1 = TypeVar("T1")

View File

@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.internet.task import Clock
from twisted.python.failure import Failure
from synapse.logging.context import (
SENTINEL_CONTEXT,
@ -21,7 +24,11 @@ from synapse.logging.context import (
PreserveLoggingContext,
current_context,
)
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
timeout_deferred,
)
from tests.unittest import TestCase
@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase):
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(current_context(), context_one)
class _TestException(Exception):
pass
class ConcurrentlyExecuteTest(TestCase):
def test_limits_runners(self):
"""If we have more tasks than runners, we should get the limit of runners"""
started = 0
waiters = []
processed = []
async def callback(v):
# when we first enter, bump the start count
nonlocal started
started += 1
# record the fact we got an item
processed.append(v)
# wait for the goahead before returning
d2 = Deferred()
waiters.append(d2)
await d2
# set it going
d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
# check we got exactly 3 processes
self.assertEqual(started, 3)
self.assertEqual(len(waiters), 3)
# let one finish
waiters.pop().callback(0)
# ... which should start another
self.assertEqual(started, 4)
self.assertEqual(len(waiters), 3)
# we still shouldn't be done
self.assertNoResult(d2)
# finish the job
while waiters:
waiters.pop().callback(0)
# check everything got done
self.assertEqual(started, 5)
self.assertCountEqual(processed, [1, 2, 3, 4, 5])
self.successResultOf(d2)
def test_preserves_stacktraces(self):
"""Test that the stacktrace from an exception thrown in the callback is preserved"""
d1 = Deferred()
async def callback(v):
# alas, this doesn't work at all without an await here
await d1
raise _TestException("bah")
async def caller():
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
tb = traceback.extract_tb(e.__traceback__)
# we expect to see "caller", "concurrently_execute" and "callback".
self.assertEqual(tb[0].name, "caller")
self.assertEqual(tb[1].name, "concurrently_execute")
self.assertEqual(tb[-1].name, "callback")
else:
self.fail("No exception thrown")
d2 = ensureDeferred(caller())
d1.callback(0)
self.successResultOf(d2)
def test_preserves_stacktraces_on_preformed_failure(self):
"""Test that the stacktrace on a Failure returned by the callback is preserved"""
d1 = Deferred()
f = Failure(_TestException("bah"))
async def callback(v):
# alas, this doesn't work at all without an await here
await d1
await defer.fail(f)
async def caller():
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
tb = traceback.extract_tb(e.__traceback__)
# we expect to see "caller", "concurrently_execute", "callback",
# and some magic from inside ensureDeferred that happens when .fail
# is called.
self.assertEqual(tb[0].name, "caller")
self.assertEqual(tb[1].name, "concurrently_execute")
self.assertEqual(tb[-2].name, "callback")
else:
self.fail("No exception thrown")
d2 = ensureDeferred(caller())
d1.callback(0)
self.successResultOf(d2)