Handle cancellation in `DatabasePool.runInteraction()` (#12199)

To handle cancellation, we ensure that `after_callback`s and
`exception_callback`s are always run, since the transaction will
complete on another thread regardless of cancellation.

We also wait until everything is done before releasing the
`CancelledError`, so that logging contexts won't get used after they
have been finished.

Signed-off-by: Sean Quah <seanq@element.io>
dmr/debug-check-deps
Sean Quah 2022-03-16 15:07:41 +00:00 committed by GitHub
parent fc9bd620ce
commit 6121056740
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 96 additions and 24 deletions

1
changelog.d/12199.misc Normal file
View File

@ -0,0 +1 @@
Handle cancellation in `DatabasePool.runInteraction()`.

View File

@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal from typing_extensions import Literal
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig from synapse.config.database import DatabaseConnectionConfig
@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor from synapse.storage.types import Connection, Cursor
from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
if TYPE_CHECKING: if TYPE_CHECKING:
@ -732,34 +734,45 @@ class DatabasePool:
Returns: Returns:
The result of func The result of func
""" """
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
if not current_context(): async def _runInteraction() -> R:
logger.warning("Starting db txn '%s' from sentinel context", desc) after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
try: if not current_context():
with opentracing.start_active_span(f"db.{desc}"): logger.warning("Starting db txn '%s' from sentinel context", desc)
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)
for after_callback, after_args, after_kwargs in after_callbacks: try:
after_callback(*after_args, **after_kwargs) with opentracing.start_active_span(f"db.{desc}"):
except Exception: result = await self.runWithConnection(
for after_callback, after_args, after_kwargs in exception_callbacks: self.new_transaction,
after_callback(*after_args, **after_kwargs) desc,
raise after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)
return cast(R, result) for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
return cast(R, result)
except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise
# To handle cancellation, we ensure that `after_callback`s and
# `exception_callback`s are always run, since the transaction will complete
# on another thread regardless of cancellation.
#
# We also wait until everything above is done before releasing the
# `CancelledError`, so that logging contexts won't get used after they have been
# finished.
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
async def runWithConnection( async def runWithConnection(
self, self,

View File

@ -15,6 +15,8 @@
from typing import Callable, Tuple from typing import Callable, Tuple
from unittest.mock import Mock, call from unittest.mock import Mock, call
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer from synapse.server import HomeServer
@ -124,3 +126,59 @@ class CallbacksTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(after_callback.call_count, 2) # no additional calls self.assertEqual(after_callback.call_count, 2) # no additional calls
exception_callback.assert_not_called() exception_callback.assert_not_called()
class CancellationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
def test_after_callback(self) -> None:
"""Test that the after callback is called when a transaction succeeds."""
d: "Deferred[None]"
after_callback = Mock()
exception_callback = Mock()
def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
d.cancel()
d = defer.ensureDeferred(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
self.get_failure(d, CancelledError)
after_callback.assert_called_once_with(123, 456, extra=789)
exception_callback.assert_not_called()
def test_exception_callback(self) -> None:
"""Test that the exception callback is called when a transaction fails."""
d: "Deferred[None]"
after_callback = Mock()
exception_callback = Mock()
def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
d.cancel()
# Simulate a retryable failure on every attempt.
raise self.db_pool.engine.module.OperationalError()
d = defer.ensureDeferred(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
self.get_failure(d, CancelledError)
after_callback.assert_not_called()
exception_callback.assert_has_calls(
[
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
]
)
self.assertEqual(exception_callback.call_count, 6) # no additional calls