Fail test cases if they fail to await all awaitables (#8690)
parent
46f4be94b4
commit
fd7c743445
|
@ -0,0 +1 @@
|
|||
Fail tests if they do not await coroutines.
|
|
@ -17,8 +17,10 @@
|
|||
"""
|
||||
Utilities for running the unit tests
|
||||
"""
|
||||
import sys
|
||||
import warnings
|
||||
from asyncio import Future
|
||||
from typing import Any, Awaitable, TypeVar
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
TV = TypeVar("TV")
|
||||
|
||||
|
@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]:
|
|||
future = Future() # type: ignore
|
||||
future.set_result(result)
|
||||
return future
|
||||
|
||||
|
||||
def setup_awaitable_errors() -> Callable[[], None]:
|
||||
"""
|
||||
Convert warnings from a non-awaited coroutines into errors.
|
||||
"""
|
||||
warnings.simplefilter("error", RuntimeWarning)
|
||||
|
||||
# unraisablehook was added in Python 3.8.
|
||||
if not hasattr(sys, "unraisablehook"):
|
||||
return lambda: None
|
||||
|
||||
# State shared between unraisablehook and check_for_unraisable_exceptions.
|
||||
unraisable_exceptions = []
|
||||
orig_unraisablehook = sys.unraisablehook # type: ignore
|
||||
|
||||
def unraisablehook(unraisable):
|
||||
unraisable_exceptions.append(unraisable.exc_value)
|
||||
|
||||
def cleanup():
|
||||
"""
|
||||
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
|
||||
"""
|
||||
sys.unraisablehook = orig_unraisablehook # type: ignore
|
||||
if unraisable_exceptions:
|
||||
raise unraisable_exceptions.pop()
|
||||
|
||||
sys.unraisablehook = unraisablehook # type: ignore
|
||||
|
||||
return cleanup
|
||||
|
|
|
@ -54,7 +54,7 @@ from tests.server import (
|
|||
render,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
from tests.test_utils import event_injection
|
||||
from tests.test_utils import event_injection, setup_awaitable_errors
|
||||
from tests.test_utils.logging_setup import setup_logging
|
||||
from tests.utils import default_config, setupdb
|
||||
|
||||
|
@ -119,6 +119,10 @@ class TestCase(unittest.TestCase):
|
|||
|
||||
logging.getLogger().setLevel(level)
|
||||
|
||||
# Trial messes with the warnings configuration, thus this has to be
|
||||
# done in the context of an individual TestCase.
|
||||
self.addCleanup(setup_awaitable_errors())
|
||||
|
||||
return orig()
|
||||
|
||||
@around(self)
|
||||
|
|
Loading…
Reference in New Issue