Improve exception handling for concurrent execution (#12109)
* fix incorrect unwrapFirstError import this was being imported from the wrong place * Refactor `concurrently_execute` to use `yieldable_gather_results` * Improve exception handling in `yieldable_gather_results` Try to avoid swallowing so many stack traces. * mark unwrapFirstError deprecated * changelogpull/12111/head
parent
952efd0bca
commit
9d11fee8f2
|
@ -0,0 +1 @@
|
|||
Improve exception handling for concurrent execution.
|
|
@ -55,8 +55,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
|||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
||||
from synapse.util import json_decoder, json_encoder, log_failure
|
||||
from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
|
||||
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer, gather_results
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
|
|
@ -81,7 +81,9 @@ json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
|
|||
|
||||
|
||||
def unwrapFirstError(failure: Failure) -> Failure:
|
||||
# defer.gatherResults and DeferredLists wrap failures.
|
||||
# Deprecated: you probably just want to catch defer.FirstError and reraise
|
||||
# the subFailure's value, which will do a better job of preserving stacktraces.
|
||||
# (actually, you probably want to use yieldable_gather_results anyway)
|
||||
failure.trap(defer.FirstError)
|
||||
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ from typing import (
|
|||
Hashable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
|
@ -51,7 +52,7 @@ from synapse.logging.context import (
|
|||
make_deferred_yieldable,
|
||||
run_in_background,
|
||||
)
|
||||
from synapse.util import Clock, unwrapFirstError
|
||||
from synapse.util import Clock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -193,9 +194,9 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
|
|||
T = TypeVar("T")
|
||||
|
||||
|
||||
def concurrently_execute(
|
||||
async def concurrently_execute(
|
||||
func: Callable[[T], Any], args: Iterable[T], limit: int
|
||||
) -> defer.Deferred:
|
||||
) -> None:
|
||||
"""Executes the function with each argument concurrently while limiting
|
||||
the number of concurrent executions.
|
||||
|
||||
|
@ -221,20 +222,14 @@ def concurrently_execute(
|
|||
# We use `itertools.islice` to handle the case where the number of args is
|
||||
# less than the limit, avoiding needlessly spawning unnecessary background
|
||||
# tasks.
|
||||
return make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(_concurrently_execute_inner, value)
|
||||
for value in itertools.islice(it, limit)
|
||||
],
|
||||
consumeErrors=True,
|
||||
await yieldable_gather_results(
|
||||
_concurrently_execute_inner, (value for value in itertools.islice(it, limit))
|
||||
)
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
|
||||
def yieldable_gather_results(
|
||||
func: Callable, iter: Iterable, *args: Any, **kwargs: Any
|
||||
) -> defer.Deferred:
|
||||
async def yieldable_gather_results(
|
||||
func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
|
||||
) -> List[T]:
|
||||
"""Executes the function with each argument concurrently.
|
||||
|
||||
Args:
|
||||
|
@ -245,15 +240,30 @@ def yieldable_gather_results(
|
|||
**kwargs: Keyword arguments to be passed to each call to func
|
||||
|
||||
Returns
|
||||
Deferred[list]: Resolved when all functions have been invoked, or errors if
|
||||
one of the function calls fails.
|
||||
A list containing the results of the function
|
||||
"""
|
||||
return make_deferred_yieldable(
|
||||
try:
|
||||
return await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[run_in_background(func, item, *args, **kwargs) for item in iter],
|
||||
consumeErrors=True,
|
||||
)
|
||||
).addErrback(unwrapFirstError)
|
||||
)
|
||||
except defer.FirstError as dfe:
|
||||
# unwrap the error from defer.gatherResults.
|
||||
|
||||
# The raised exception's traceback only includes func() etc if
|
||||
# the 'await' happens before the exception is thrown - ie if the failure
|
||||
# happens *asynchronously* - otherwise Twisted throws away the traceback as it
|
||||
# could be large.
|
||||
#
|
||||
# We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
|
||||
# we could throw Twisted into the fires of Mordor.
|
||||
|
||||
# suppress exception chaining, because the FirstError doesn't tell us anything
|
||||
# very interesting.
|
||||
assert isinstance(dfe.subFailure.value, BaseException)
|
||||
raise dfe.subFailure.value from None
|
||||
|
||||
|
||||
T1 = TypeVar("T1")
|
||||
|
|
|
@ -11,9 +11,12 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import traceback
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import CancelledError, Deferred
|
||||
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.logging.context import (
|
||||
SENTINEL_CONTEXT,
|
||||
|
@ -21,7 +24,11 @@ from synapse.logging.context import (
|
|||
PreserveLoggingContext,
|
||||
current_context,
|
||||
)
|
||||
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
|
||||
from synapse.util.async_helpers import (
|
||||
ObservableDeferred,
|
||||
concurrently_execute,
|
||||
timeout_deferred,
|
||||
)
|
||||
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase):
|
|||
)
|
||||
self.failureResultOf(timing_out_d, defer.TimeoutError)
|
||||
self.assertIs(current_context(), context_one)
|
||||
|
||||
|
||||
class _TestException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ConcurrentlyExecuteTest(TestCase):
|
||||
def test_limits_runners(self):
|
||||
"""If we have more tasks than runners, we should get the limit of runners"""
|
||||
started = 0
|
||||
waiters = []
|
||||
processed = []
|
||||
|
||||
async def callback(v):
|
||||
# when we first enter, bump the start count
|
||||
nonlocal started
|
||||
started += 1
|
||||
|
||||
# record the fact we got an item
|
||||
processed.append(v)
|
||||
|
||||
# wait for the goahead before returning
|
||||
d2 = Deferred()
|
||||
waiters.append(d2)
|
||||
await d2
|
||||
|
||||
# set it going
|
||||
d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
|
||||
|
||||
# check we got exactly 3 processes
|
||||
self.assertEqual(started, 3)
|
||||
self.assertEqual(len(waiters), 3)
|
||||
|
||||
# let one finish
|
||||
waiters.pop().callback(0)
|
||||
|
||||
# ... which should start another
|
||||
self.assertEqual(started, 4)
|
||||
self.assertEqual(len(waiters), 3)
|
||||
|
||||
# we still shouldn't be done
|
||||
self.assertNoResult(d2)
|
||||
|
||||
# finish the job
|
||||
while waiters:
|
||||
waiters.pop().callback(0)
|
||||
|
||||
# check everything got done
|
||||
self.assertEqual(started, 5)
|
||||
self.assertCountEqual(processed, [1, 2, 3, 4, 5])
|
||||
self.successResultOf(d2)
|
||||
|
||||
def test_preserves_stacktraces(self):
|
||||
"""Test that the stacktrace from an exception thrown in the callback is preserved"""
|
||||
d1 = Deferred()
|
||||
|
||||
async def callback(v):
|
||||
# alas, this doesn't work at all without an await here
|
||||
await d1
|
||||
raise _TestException("bah")
|
||||
|
||||
async def caller():
|
||||
try:
|
||||
await concurrently_execute(callback, [1], 2)
|
||||
except _TestException as e:
|
||||
tb = traceback.extract_tb(e.__traceback__)
|
||||
# we expect to see "caller", "concurrently_execute" and "callback".
|
||||
self.assertEqual(tb[0].name, "caller")
|
||||
self.assertEqual(tb[1].name, "concurrently_execute")
|
||||
self.assertEqual(tb[-1].name, "callback")
|
||||
else:
|
||||
self.fail("No exception thrown")
|
||||
|
||||
d2 = ensureDeferred(caller())
|
||||
d1.callback(0)
|
||||
self.successResultOf(d2)
|
||||
|
||||
def test_preserves_stacktraces_on_preformed_failure(self):
|
||||
"""Test that the stacktrace on a Failure returned by the callback is preserved"""
|
||||
d1 = Deferred()
|
||||
f = Failure(_TestException("bah"))
|
||||
|
||||
async def callback(v):
|
||||
# alas, this doesn't work at all without an await here
|
||||
await d1
|
||||
await defer.fail(f)
|
||||
|
||||
async def caller():
|
||||
try:
|
||||
await concurrently_execute(callback, [1], 2)
|
||||
except _TestException as e:
|
||||
tb = traceback.extract_tb(e.__traceback__)
|
||||
# we expect to see "caller", "concurrently_execute", "callback",
|
||||
# and some magic from inside ensureDeferred that happens when .fail
|
||||
# is called.
|
||||
self.assertEqual(tb[0].name, "caller")
|
||||
self.assertEqual(tb[1].name, "concurrently_execute")
|
||||
self.assertEqual(tb[-2].name, "callback")
|
||||
else:
|
||||
self.fail("No exception thrown")
|
||||
|
||||
d2 = ensureDeferred(caller())
|
||||
d1.callback(0)
|
||||
self.successResultOf(d2)
|
||||
|
|
Loading…
Reference in New Issue