Clean up the test code for client disconnections (#12929)

* Reword failure message about `await_result=False`
* Use `reactor.advance()` instead of `reactor.pump()`
* Raise `AssertionError`s ourselves
* Un-instance method `_test_disconnect`
* Replace `ThreadedMemoryReactorClock` with `MemoryReactorClock`
pull/12984/head
Sean Quah 2022-06-07 18:17:32 +01:00 committed by GitHub
parent 586bfc6dc0
commit 3c1c40d843
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 88 additions and 87 deletions

1
changelog.d/12929.misc Normal file
View File

@ -0,0 +1 @@
Clean up the test code for client disconnection.

View File

@ -24,7 +24,7 @@ from synapse.types import JsonDict
from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect
class CancellableFederationServlet(BaseFederationServlet):
@ -54,9 +54,7 @@ class CancellableFederationServlet(BaseFederationServlet):
return HTTPStatus.OK, {"result": True}
class BaseFederationServletCancellationTests(
unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
):
class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCase):
"""Tests for `BaseFederationServlet` cancellation."""
skip = "`BaseFederationServlet` does not support cancellation yet."
@ -86,7 +84,7 @@ class BaseFederationServletCancellationTests(
# request won't be processed.
self.pump()
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@ -106,7 +104,7 @@ class BaseFederationServletCancellationTests(
# request won't be processed.
self.pump()
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,

View File

@ -46,8 +46,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.types import JsonDict
from tests import unittest
from tests.server import FakeChannel, ThreadedMemoryReactorClock, make_request
from tests.server import FakeChannel, make_request
from tests.unittest import logcontext_clean
logger = logging.getLogger(__name__)
@ -56,75 +55,82 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
class EndpointCancellationTestHelperMixin(unittest.TestCase):
"""Provides helper methods for testing cancellation of endpoints."""
def test_disconnect(
reactor: MemoryReactorClock,
channel: FakeChannel,
expect_cancellation: bool,
expected_body: Union[bytes, JsonDict],
expected_code: Optional[int] = None,
) -> None:
"""Disconnects an in-flight request and checks the response.
def _test_disconnect(
self,
reactor: ThreadedMemoryReactorClock,
channel: FakeChannel,
expect_cancellation: bool,
expected_body: Union[bytes, JsonDict],
expected_code: Optional[int] = None,
) -> None:
"""Disconnects an in-flight request and checks the response.
Args:
reactor: The twisted reactor running the request handler.
channel: The `FakeChannel` for the request.
expect_cancellation: `True` if request processing is expected to be cancelled,
`False` if the request should run to completion.
expected_body: The expected response for the request.
expected_code: The expected status code for the request. Defaults to `200` or
`499` depending on `expect_cancellation`.
"""
# Determine the expected status code.
if expected_code is None:
if expect_cancellation:
expected_code = HTTP_STATUS_REQUEST_CANCELLED
else:
expected_code = HTTPStatus.OK
Args:
reactor: The twisted reactor running the request handler.
channel: The `FakeChannel` for the request.
expect_cancellation: `True` if request processing is expected to be
cancelled, `False` if the request should run to completion.
expected_body: The expected response for the request.
expected_code: The expected status code for the request. Defaults to `200`
or `499` depending on `expect_cancellation`.
"""
# Determine the expected status code.
if expected_code is None:
if expect_cancellation:
expected_code = HTTP_STATUS_REQUEST_CANCELLED
else:
expected_code = HTTPStatus.OK
request = channel.request
self.assertFalse(
channel.is_finished(),
request = channel.request
if channel.is_finished():
raise AssertionError(
"Request finished before we could disconnect - "
"was `await_result=False` passed to `make_request`?",
"ensure `await_result=False` is passed to `make_request`.",
)
# We're about to disconnect the request. This also disconnects the channel, so
# we have to rely on mocks to extract the response.
respond_method: Callable[..., Any]
if isinstance(expected_body, bytes):
respond_method = respond_with_html_bytes
# We're about to disconnect the request. This also disconnects the channel, so we
# have to rely on mocks to extract the response.
respond_method: Callable[..., Any]
if isinstance(expected_body, bytes):
respond_method = respond_with_html_bytes
else:
respond_method = respond_with_json
with mock.patch(
f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
) as respond_mock:
# Disconnect the request.
request.connectionLost(reason=ConnectionDone())
if expect_cancellation:
# An immediate cancellation is expected.
respond_mock.assert_called_once()
else:
respond_method = respond_with_json
respond_mock.assert_not_called()
with mock.patch(
f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
) as respond_mock:
# Disconnect the request.
request.connectionLost(reason=ConnectionDone())
# The handler is expected to run to completion.
reactor.advance(1.0)
respond_mock.assert_called_once()
if expect_cancellation:
# An immediate cancellation is expected.
respond_mock.assert_called_once()
args, _kwargs = respond_mock.call_args
code, body = args[1], args[2]
self.assertEqual(code, expected_code)
self.assertEqual(request.code, expected_code)
self.assertEqual(body, expected_body)
else:
respond_mock.assert_not_called()
args, _kwargs = respond_mock.call_args
code, body = args[1], args[2]
# The handler is expected to run to completion.
reactor.pump([1.0])
respond_mock.assert_called_once()
args, _kwargs = respond_mock.call_args
code, body = args[1], args[2]
self.assertEqual(code, expected_code)
self.assertEqual(request.code, expected_code)
self.assertEqual(body, expected_body)
if code != expected_code:
raise AssertionError(
f"{code} != {expected_code} : "
"Request did not finish with the expected status code."
)
if request.code != expected_code:
raise AssertionError(
f"{request.code} != {expected_code} : "
"Request did not finish with the expected status code."
)
if body != expected_body:
raise AssertionError(
f"{body!r} != {expected_body!r} : "
"Request did not finish with the expected status code."
)
@logcontext_clean

View File

@ -30,7 +30,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect
def make_request(content):
@ -108,9 +108,7 @@ class CancellableRestServlet(RestServlet):
return HTTPStatus.OK, {"result": True}
class TestRestServletCancellation(
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
):
class TestRestServletCancellation(unittest.HomeserverTestCase):
"""Tests for `RestServlet` cancellation."""
servlets = [
@ -120,7 +118,7 @@ class TestRestServletCancellation(
def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
channel = self.make_request("GET", "/sleep", await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@ -130,7 +128,7 @@ class TestRestServletCancellation(
def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
channel = self.make_request("POST", "/sleep", await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,

View File

@ -25,7 +25,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect
class CancellableReplicationEndpoint(ReplicationEndpoint):
@ -69,9 +69,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
return HTTPStatus.OK, {"result": True}
class ReplicationEndpointCancellationTestCase(
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
):
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""
def create_test_resource(self):
@ -87,7 +85,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers with the `@cancellable` flag can be cancelled."""
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@ -98,7 +96,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,

View File

@ -34,7 +34,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect
from tests.server import (
FakeSite,
ThreadedMemoryReactorClock,
@ -407,7 +407,7 @@ class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
return HTTPStatus.OK, b"ok"
class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
class DirectServeJsonResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeJsonResource` cancellation."""
def setUp(self):
@ -421,7 +421,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@ -433,7 +433,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
@ -441,7 +441,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
)
class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeHtmlResource` cancellation."""
def setUp(self):
@ -455,7 +455,7 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@ -467,6 +467,6 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
)