Kill off `_PushHTTPChannel`. (#9878)
First of all, a fixup to `FakeChannel` is needed to make it work with the default HTTP channel implementation. Secondly, it looks like we no longer need `_PushHTTPChannel`, because as of #8013 the producer that gets attached to the `HTTPChannel` is now an `IPushProducer`. This is good, because it means we can remove a whole load of test-specific boilerplate which causes variation between tests and production.
parent 695b73c861
commit 84936e2264
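The diff below swaps `_PushHTTPChannel` for the stock channel returned by `Site.buildProtocol`, plus a small hook on the channel's `requestFactory` so the test can record the requests it handles. As a standalone illustration of that pattern, here is a minimal sketch; the helper name `build_recording_channel` is hypothetical and not part of the change, and only `Site.buildProtocol` and the channel's `requestFactory` attribute come from Twisted itself:

```python
from typing import Any, List, Tuple

from twisted.web.server import Request, Site


def build_recording_channel(site: Site) -> Tuple[Any, List[Request]]:
    """Build the server-side HTTP protocol from `site` and record its requests.

    A sketch (not Synapse's actual helper) of the pattern the updated tests use
    instead of `_PushHTTPChannel`: the default channel is good enough now that an
    `IPushProducer` is attached to it, so the test only needs a hook to capture
    requests for later inspection.
    """
    # The stock HTTP channel, exactly as production would get it.
    channel = site.buildProtocol(None)

    requests: List[Request] = []
    real_request_factory = channel.requestFactory

    def request_factory(*args: Any, **kwargs: Any) -> Request:
        # Delegate to the real factory, but keep a record of every request made.
        request = real_request_factory(*args, **kwargs)
        requests.append(request)
        return request

    channel.requestFactory = request_factory
    return channel, requests
```

A test can then feed bytes into `channel` over a fake transport and, once the connection is torn down, assert on the recorded `requests`, which is essentially what the updated `handle_http_replication_attempt` in the diff does.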
@@ -0,0 +1 @@
+Remove redundant `_PushHTTPChannel` test class.
@@ -12,14 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
-from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
-from twisted.internet.task import LoopingCall
-from twisted.web.http import HTTPChannel
 from twisted.web.resource import Resource
-from twisted.web.server import Request, Site
 
 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.http.server import JsonResource
@@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import (
     ServerReplicationStreamProtocol,
 )
 from synapse.server import HomeServer
-from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         client_protocol = client_factory.buildProtocol(None)
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
+        channel = self.site.buildProtocol(None)
 
+        # hook into the channel's request factory so that we can keep a record
+        # of the requests
+        requests: List[SynapseRequest] = []
+        real_request_factory = channel.requestFactory
+
+        def request_factory(*args, **kwargs):
+            request = real_request_factory(*args, **kwargs)
+            requests.append(request)
+            return request
+
+        channel.requestFactory = request_factory
+
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         server_to_client_transport.loseConnection()
         client_to_server_transport.loseConnection()
 
-        return channel.request
+        # there should have been exactly one request
+        self.assertEqual(len(requests), 1)
+
+        return requests[0]
 
     def assert_request_is_get_repl_stream_updates(
         self, request: SynapseRequest, stream_name: str
@@ -387,7 +397,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         client_protocol = client_factory.buildProtocol(None)
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
+        channel = self._hs_to_site[hs].buildProtocol(None)
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -445,112 +455,6 @@ class TestReplicationDataHandler(ReplicationDataHandler):
         self.received_rdata_rows.append((stream_name, token, r))
 
 
-class _PushHTTPChannel(HTTPChannel):
-    """A HTTPChannel that wraps pull producers to push producers.
-
-    This is a hack to get around the fact that HTTPChannel transparently wraps a
-    pull producer (which is what Synapse uses to reply to requests) with
-    `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
-    uses the standard reactor rather than letting us use our test reactor, which
-    makes it very hard to test.
-    """
-
-    def __init__(
-        self, reactor: IReactorTime, request_factory: Type[Request], site: Site
-    ):
-        super().__init__()
-        self.reactor = reactor
-        self.requestFactory = request_factory
-        self.site = site
-
-        self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]
-
-    def registerProducer(self, producer, streaming):
-        # Convert pull producers to push producer.
-        if not streaming:
-            self._pull_to_push_producer = _PullToPushProducer(
-                self.reactor, producer, self
-            )
-            producer = self._pull_to_push_producer
-
-        super().registerProducer(producer, True)
-
-    def unregisterProducer(self):
-        if self._pull_to_push_producer:
-            # We need to manually stop the _PullToPushProducer.
-            self._pull_to_push_producer.stop()
-
-    def checkPersistence(self, request, version):
-        """Check whether the connection can be re-used"""
-        # We hijack this to always say no for ease of wiring stuff up in
-        # `handle_http_replication_attempt`.
-        request.responseHeaders.setRawHeaders(b"connection", [b"close"])
-        return False
-
-    def requestDone(self, request):
-        # Store the request for inspection.
-        self.request = request
-        super().requestDone(request)
-
-
-class _PullToPushProducer:
-    """A push producer that wraps a pull producer."""
-
-    def __init__(
-        self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
-    ):
-        self._clock = Clock(reactor)
-        self._producer = producer
-        self._consumer = consumer
-
-        # While running we use a looping call with a zero delay to call
-        # resumeProducing on given producer.
-        self._looping_call = None  # type: Optional[LoopingCall]
-
-        # We start writing next reactor tick.
-        self._start_loop()
-
-    def _start_loop(self):
-        """Start the looping call to"""
-
-        if not self._looping_call:
-            # Start a looping call which runs every tick.
-            self._looping_call = self._clock.looping_call(self._run_once, 0)
-
-    def stop(self):
-        """Stops calling resumeProducing."""
-        if self._looping_call:
-            self._looping_call.stop()
-            self._looping_call = None
-
-    def pauseProducing(self):
-        """Implements IPushProducer"""
-        self.stop()
-
-    def resumeProducing(self):
-        """Implements IPushProducer"""
-        self._start_loop()
-
-    def stopProducing(self):
-        """Implements IPushProducer"""
-        self.stop()
-        self._producer.stopProducing()
-
-    def _run_once(self):
-        """Calls resumeProducing on producer once."""
-
-        try:
-            self._producer.resumeProducing()
-        except Exception:
-            logger.exception("Failed to call resumeProducing")
-            try:
-                self._consumer.unregisterProducer()
-            except Exception:
-                pass
-
-            self.stopProducing()
-
-
 class FakeRedisPubSubServer:
     """A fake Redis server for pub/sub."""
 
@@ -603,12 +603,6 @@ class FakeTransport:
         if self.disconnected:
             return
 
-        if not hasattr(self.other, "transport"):
-            # the other has no transport yet; reschedule
-            if self.autoflush:
-                self._reactor.callLater(0.0, self.flush)
-            return
-
         if maxbytes is not None:
             to_write = self.buffer[:maxbytes]
         else: