Kill off `_PushHTTPChannel`. (#9878)
First of all, a fixup to `FakeChannel` which is needed to make it work with the default HTTP channel implementation. Secondly, it looks like we no longer need `_PushHTTPChannel`, because as of #8013, the producer that gets attached to the `HTTPChannel` is now an `IPushProducer`. This is good, because it means we can remove a whole load of test-specific boilerplate which causes variation between tests and production.pull/9887/head
parent
695b73c861
commit
84936e2264
|
@ -0,0 +1 @@
|
|||
Remove redundant `_PushHTTPChannel` test class.
|
|
@ -12,14 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.task import LoopingCall
|
||||
from twisted.web.http import HTTPChannel
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import Request, Site
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.http.server import JsonResource
|
||||
|
@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import (
|
|||
ServerReplicationStreamProtocol,
|
||||
)
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport
|
||||
|
@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
client_protocol = client_factory.buildProtocol(None)
|
||||
|
||||
# Set up the server side protocol
|
||||
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
|
||||
channel = self.site.buildProtocol(None)
|
||||
|
||||
# hook into the channel's request factory so that we can keep a record
|
||||
# of the requests
|
||||
requests: List[SynapseRequest] = []
|
||||
real_request_factory = channel.requestFactory
|
||||
|
||||
def request_factory(*args, **kwargs):
|
||||
request = real_request_factory(*args, **kwargs)
|
||||
requests.append(request)
|
||||
return request
|
||||
|
||||
channel.requestFactory = request_factory
|
||||
|
||||
# Connect client to server and vice versa.
|
||||
client_to_server_transport = FakeTransport(
|
||||
|
@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
server_to_client_transport.loseConnection()
|
||||
client_to_server_transport.loseConnection()
|
||||
|
||||
return channel.request
|
||||
# there should have been exactly one request
|
||||
self.assertEqual(len(requests), 1)
|
||||
|
||||
return requests[0]
|
||||
|
||||
def assert_request_is_get_repl_stream_updates(
|
||||
self, request: SynapseRequest, stream_name: str
|
||||
|
@ -387,7 +397,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
client_protocol = client_factory.buildProtocol(None)
|
||||
|
||||
# Set up the server side protocol
|
||||
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
|
||||
channel = self._hs_to_site[hs].buildProtocol(None)
|
||||
|
||||
# Connect client to server and vice versa.
|
||||
client_to_server_transport = FakeTransport(
|
||||
|
@ -445,112 +455,6 @@ class TestReplicationDataHandler(ReplicationDataHandler):
|
|||
self.received_rdata_rows.append((stream_name, token, r))
|
||||
|
||||
|
||||
class _PushHTTPChannel(HTTPChannel):
|
||||
"""A HTTPChannel that wraps pull producers to push producers.
|
||||
|
||||
This is a hack to get around the fact that HTTPChannel transparently wraps a
|
||||
pull producer (which is what Synapse uses to reply to requests) with
|
||||
`_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
|
||||
uses the standard reactor rather than letting us use our test reactor, which
|
||||
makes it very hard to test.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, reactor: IReactorTime, request_factory: Type[Request], site: Site
|
||||
):
|
||||
super().__init__()
|
||||
self.reactor = reactor
|
||||
self.requestFactory = request_factory
|
||||
self.site = site
|
||||
|
||||
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
# Convert pull producers to push producer.
|
||||
if not streaming:
|
||||
self._pull_to_push_producer = _PullToPushProducer(
|
||||
self.reactor, producer, self
|
||||
)
|
||||
producer = self._pull_to_push_producer
|
||||
|
||||
super().registerProducer(producer, True)
|
||||
|
||||
def unregisterProducer(self):
|
||||
if self._pull_to_push_producer:
|
||||
# We need to manually stop the _PullToPushProducer.
|
||||
self._pull_to_push_producer.stop()
|
||||
|
||||
def checkPersistence(self, request, version):
|
||||
"""Check whether the connection can be re-used"""
|
||||
# We hijack this to always say no for ease of wiring stuff up in
|
||||
# `handle_http_replication_attempt`.
|
||||
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
|
||||
return False
|
||||
|
||||
def requestDone(self, request):
|
||||
# Store the request for inspection.
|
||||
self.request = request
|
||||
super().requestDone(request)
|
||||
|
||||
|
||||
class _PullToPushProducer:
|
||||
"""A push producer that wraps a pull producer."""
|
||||
|
||||
def __init__(
|
||||
self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
|
||||
):
|
||||
self._clock = Clock(reactor)
|
||||
self._producer = producer
|
||||
self._consumer = consumer
|
||||
|
||||
# While running we use a looping call with a zero delay to call
|
||||
# resumeProducing on given producer.
|
||||
self._looping_call = None # type: Optional[LoopingCall]
|
||||
|
||||
# We start writing next reactor tick.
|
||||
self._start_loop()
|
||||
|
||||
def _start_loop(self):
|
||||
"""Start the looping call to"""
|
||||
|
||||
if not self._looping_call:
|
||||
# Start a looping call which runs every tick.
|
||||
self._looping_call = self._clock.looping_call(self._run_once, 0)
|
||||
|
||||
def stop(self):
|
||||
"""Stops calling resumeProducing."""
|
||||
if self._looping_call:
|
||||
self._looping_call.stop()
|
||||
self._looping_call = None
|
||||
|
||||
def pauseProducing(self):
|
||||
"""Implements IPushProducer"""
|
||||
self.stop()
|
||||
|
||||
def resumeProducing(self):
|
||||
"""Implements IPushProducer"""
|
||||
self._start_loop()
|
||||
|
||||
def stopProducing(self):
|
||||
"""Implements IPushProducer"""
|
||||
self.stop()
|
||||
self._producer.stopProducing()
|
||||
|
||||
def _run_once(self):
|
||||
"""Calls resumeProducing on producer once."""
|
||||
|
||||
try:
|
||||
self._producer.resumeProducing()
|
||||
except Exception:
|
||||
logger.exception("Failed to call resumeProducing")
|
||||
try:
|
||||
self._consumer.unregisterProducer()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.stopProducing()
|
||||
|
||||
|
||||
class FakeRedisPubSubServer:
|
||||
"""A fake Redis server for pub/sub."""
|
||||
|
||||
|
|
|
@ -603,12 +603,6 @@ class FakeTransport:
|
|||
if self.disconnected:
|
||||
return
|
||||
|
||||
if not hasattr(self.other, "transport"):
|
||||
# the other has no transport yet; reschedule
|
||||
if self.autoflush:
|
||||
self._reactor.callLater(0.0, self.flush)
|
||||
return
|
||||
|
||||
if maxbytes is not None:
|
||||
to_write = self.buffer[:maxbytes]
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue