Improve exception handling for concurrent execution (#12109)
* fix incorrect unwrapFirstError import this was being imported from the wrong place * Refactor `concurrently_execute` to use `yieldable_gather_results` * Improve exception handling in `yieldable_gather_results` Try to avoid swallowing so many stack traces. * mark unwrapFirstError deprecated * changelogpull/12111/head
parent
952efd0bca
commit
9d11fee8f2
|
@ -0,0 +1 @@
|
||||||
|
Improve exception handling for concurrent execution.
|
|
@ -55,8 +55,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
||||||
from synapse.util import json_decoder, json_encoder, log_failure
|
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
|
||||||
from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
|
from synapse.util.async_helpers import Linearizer, gather_results
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
|
@ -81,7 +81,9 @@ json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
|
||||||
|
|
||||||
|
|
||||||
def unwrapFirstError(failure: Failure) -> Failure:
|
def unwrapFirstError(failure: Failure) -> Failure:
|
||||||
# defer.gatherResults and DeferredLists wrap failures.
|
# Deprecated: you probably just want to catch defer.FirstError and reraise
|
||||||
|
# the subFailure's value, which will do a better job of preserving stacktraces.
|
||||||
|
# (actually, you probably want to use yieldable_gather_results anyway)
|
||||||
failure.trap(defer.FirstError)
|
failure.trap(defer.FirstError)
|
||||||
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
|
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,7 @@ from typing import (
|
||||||
Hashable,
|
Hashable,
|
||||||
Iterable,
|
Iterable,
|
||||||
Iterator,
|
Iterator,
|
||||||
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
@ -51,7 +52,7 @@ from synapse.logging.context import (
|
||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
run_in_background,
|
run_in_background,
|
||||||
)
|
)
|
||||||
from synapse.util import Clock, unwrapFirstError
|
from synapse.util import Clock
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -193,9 +194,9 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def concurrently_execute(
|
async def concurrently_execute(
|
||||||
func: Callable[[T], Any], args: Iterable[T], limit: int
|
func: Callable[[T], Any], args: Iterable[T], limit: int
|
||||||
) -> defer.Deferred:
|
) -> None:
|
||||||
"""Executes the function with each argument concurrently while limiting
|
"""Executes the function with each argument concurrently while limiting
|
||||||
the number of concurrent executions.
|
the number of concurrent executions.
|
||||||
|
|
||||||
|
@ -221,20 +222,14 @@ def concurrently_execute(
|
||||||
# We use `itertools.islice` to handle the case where the number of args is
|
# We use `itertools.islice` to handle the case where the number of args is
|
||||||
# less than the limit, avoiding needlessly spawning unnecessary background
|
# less than the limit, avoiding needlessly spawning unnecessary background
|
||||||
# tasks.
|
# tasks.
|
||||||
return make_deferred_yieldable(
|
await yieldable_gather_results(
|
||||||
defer.gatherResults(
|
_concurrently_execute_inner, (value for value in itertools.islice(it, limit))
|
||||||
[
|
)
|
||||||
run_in_background(_concurrently_execute_inner, value)
|
|
||||||
for value in itertools.islice(it, limit)
|
|
||||||
],
|
|
||||||
consumeErrors=True,
|
|
||||||
)
|
|
||||||
).addErrback(unwrapFirstError)
|
|
||||||
|
|
||||||
|
|
||||||
def yieldable_gather_results(
|
async def yieldable_gather_results(
|
||||||
func: Callable, iter: Iterable, *args: Any, **kwargs: Any
|
func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
|
||||||
) -> defer.Deferred:
|
) -> List[T]:
|
||||||
"""Executes the function with each argument concurrently.
|
"""Executes the function with each argument concurrently.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -245,15 +240,30 @@ def yieldable_gather_results(
|
||||||
**kwargs: Keyword arguments to be passed to each call to func
|
**kwargs: Keyword arguments to be passed to each call to func
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
Deferred[list]: Resolved when all functions have been invoked, or errors if
|
A list containing the results of the function
|
||||||
one of the function calls fails.
|
|
||||||
"""
|
"""
|
||||||
return make_deferred_yieldable(
|
try:
|
||||||
defer.gatherResults(
|
return await make_deferred_yieldable(
|
||||||
[run_in_background(func, item, *args, **kwargs) for item in iter],
|
defer.gatherResults(
|
||||||
consumeErrors=True,
|
[run_in_background(func, item, *args, **kwargs) for item in iter],
|
||||||
|
consumeErrors=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
).addErrback(unwrapFirstError)
|
except defer.FirstError as dfe:
|
||||||
|
# unwrap the error from defer.gatherResults.
|
||||||
|
|
||||||
|
# The raised exception's traceback only includes func() etc if
|
||||||
|
# the 'await' happens before the exception is thrown - ie if the failure
|
||||||
|
# happens *asynchronously* - otherwise Twisted throws away the traceback as it
|
||||||
|
# could be large.
|
||||||
|
#
|
||||||
|
# We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
|
||||||
|
# we could throw Twisted into the fires of Mordor.
|
||||||
|
|
||||||
|
# suppress exception chaining, because the FirstError doesn't tell us anything
|
||||||
|
# very interesting.
|
||||||
|
assert isinstance(dfe.subFailure.value, BaseException)
|
||||||
|
raise dfe.subFailure.value from None
|
||||||
|
|
||||||
|
|
||||||
T1 = TypeVar("T1")
|
T1 = TypeVar("T1")
|
||||||
|
|
|
@ -11,9 +11,12 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import traceback
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import CancelledError, Deferred
|
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
|
||||||
from twisted.internet.task import Clock
|
from twisted.internet.task import Clock
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
SENTINEL_CONTEXT,
|
SENTINEL_CONTEXT,
|
||||||
|
@ -21,7 +24,11 @@ from synapse.logging.context import (
|
||||||
PreserveLoggingContext,
|
PreserveLoggingContext,
|
||||||
current_context,
|
current_context,
|
||||||
)
|
)
|
||||||
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
|
from synapse.util.async_helpers import (
|
||||||
|
ObservableDeferred,
|
||||||
|
concurrently_execute,
|
||||||
|
timeout_deferred,
|
||||||
|
)
|
||||||
|
|
||||||
from tests.unittest import TestCase
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
@ -171,3 +178,107 @@ class TimeoutDeferredTest(TestCase):
|
||||||
)
|
)
|
||||||
self.failureResultOf(timing_out_d, defer.TimeoutError)
|
self.failureResultOf(timing_out_d, defer.TimeoutError)
|
||||||
self.assertIs(current_context(), context_one)
|
self.assertIs(current_context(), context_one)
|
||||||
|
|
||||||
|
|
||||||
|
class _TestException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ConcurrentlyExecuteTest(TestCase):
|
||||||
|
def test_limits_runners(self):
|
||||||
|
"""If we have more tasks than runners, we should get the limit of runners"""
|
||||||
|
started = 0
|
||||||
|
waiters = []
|
||||||
|
processed = []
|
||||||
|
|
||||||
|
async def callback(v):
|
||||||
|
# when we first enter, bump the start count
|
||||||
|
nonlocal started
|
||||||
|
started += 1
|
||||||
|
|
||||||
|
# record the fact we got an item
|
||||||
|
processed.append(v)
|
||||||
|
|
||||||
|
# wait for the goahead before returning
|
||||||
|
d2 = Deferred()
|
||||||
|
waiters.append(d2)
|
||||||
|
await d2
|
||||||
|
|
||||||
|
# set it going
|
||||||
|
d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
|
||||||
|
|
||||||
|
# check we got exactly 3 processes
|
||||||
|
self.assertEqual(started, 3)
|
||||||
|
self.assertEqual(len(waiters), 3)
|
||||||
|
|
||||||
|
# let one finish
|
||||||
|
waiters.pop().callback(0)
|
||||||
|
|
||||||
|
# ... which should start another
|
||||||
|
self.assertEqual(started, 4)
|
||||||
|
self.assertEqual(len(waiters), 3)
|
||||||
|
|
||||||
|
# we still shouldn't be done
|
||||||
|
self.assertNoResult(d2)
|
||||||
|
|
||||||
|
# finish the job
|
||||||
|
while waiters:
|
||||||
|
waiters.pop().callback(0)
|
||||||
|
|
||||||
|
# check everything got done
|
||||||
|
self.assertEqual(started, 5)
|
||||||
|
self.assertCountEqual(processed, [1, 2, 3, 4, 5])
|
||||||
|
self.successResultOf(d2)
|
||||||
|
|
||||||
|
def test_preserves_stacktraces(self):
|
||||||
|
"""Test that the stacktrace from an exception thrown in the callback is preserved"""
|
||||||
|
d1 = Deferred()
|
||||||
|
|
||||||
|
async def callback(v):
|
||||||
|
# alas, this doesn't work at all without an await here
|
||||||
|
await d1
|
||||||
|
raise _TestException("bah")
|
||||||
|
|
||||||
|
async def caller():
|
||||||
|
try:
|
||||||
|
await concurrently_execute(callback, [1], 2)
|
||||||
|
except _TestException as e:
|
||||||
|
tb = traceback.extract_tb(e.__traceback__)
|
||||||
|
# we expect to see "caller", "concurrently_execute" and "callback".
|
||||||
|
self.assertEqual(tb[0].name, "caller")
|
||||||
|
self.assertEqual(tb[1].name, "concurrently_execute")
|
||||||
|
self.assertEqual(tb[-1].name, "callback")
|
||||||
|
else:
|
||||||
|
self.fail("No exception thrown")
|
||||||
|
|
||||||
|
d2 = ensureDeferred(caller())
|
||||||
|
d1.callback(0)
|
||||||
|
self.successResultOf(d2)
|
||||||
|
|
||||||
|
def test_preserves_stacktraces_on_preformed_failure(self):
|
||||||
|
"""Test that the stacktrace on a Failure returned by the callback is preserved"""
|
||||||
|
d1 = Deferred()
|
||||||
|
f = Failure(_TestException("bah"))
|
||||||
|
|
||||||
|
async def callback(v):
|
||||||
|
# alas, this doesn't work at all without an await here
|
||||||
|
await d1
|
||||||
|
await defer.fail(f)
|
||||||
|
|
||||||
|
async def caller():
|
||||||
|
try:
|
||||||
|
await concurrently_execute(callback, [1], 2)
|
||||||
|
except _TestException as e:
|
||||||
|
tb = traceback.extract_tb(e.__traceback__)
|
||||||
|
# we expect to see "caller", "concurrently_execute", "callback",
|
||||||
|
# and some magic from inside ensureDeferred that happens when .fail
|
||||||
|
# is called.
|
||||||
|
self.assertEqual(tb[0].name, "caller")
|
||||||
|
self.assertEqual(tb[1].name, "concurrently_execute")
|
||||||
|
self.assertEqual(tb[-2].name, "callback")
|
||||||
|
else:
|
||||||
|
self.fail("No exception thrown")
|
||||||
|
|
||||||
|
d2 = ensureDeferred(caller())
|
||||||
|
d1.callback(0)
|
||||||
|
self.successResultOf(d2)
|
||||||
|
|
Loading…
Reference in New Issue