Remove redundant WrappedConnection (#4409)
* Remove redundant WrappedConnection The matrix federation client uses an HTTP connection pool, which times out its idle HTTP connections, so there is no need for any of this business.pull/4421/head
parent
676cf2ee26
commit
de6888e7ce
|
@ -0,0 +1 @@
|
||||||
|
Remove redundant federation connection wrapping code
|
|
@ -140,82 +140,15 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=
|
||||||
default_port = 8448
|
default_port = 8448
|
||||||
|
|
||||||
if port is None:
|
if port is None:
|
||||||
return _WrappingEndpointFac(SRVClientEndpoint(
|
return SRVClientEndpoint(
|
||||||
reactor, "matrix", domain, protocol="tcp",
|
reactor, "matrix", domain, protocol="tcp",
|
||||||
default_port=default_port, endpoint=transport_endpoint,
|
default_port=default_port, endpoint=transport_endpoint,
|
||||||
endpoint_kw_args=endpoint_kw_args
|
endpoint_kw_args=endpoint_kw_args
|
||||||
), reactor)
|
)
|
||||||
else:
|
else:
|
||||||
return _WrappingEndpointFac(transport_endpoint(
|
return transport_endpoint(
|
||||||
reactor, domain, port, **endpoint_kw_args
|
reactor, domain, port, **endpoint_kw_args
|
||||||
), reactor)
|
)
|
||||||
|
|
||||||
|
|
||||||
class _WrappingEndpointFac(object):
|
|
||||||
def __init__(self, endpoint_fac, reactor):
|
|
||||||
self.endpoint_fac = endpoint_fac
|
|
||||||
self.reactor = reactor
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def connect(self, protocolFactory):
|
|
||||||
conn = yield self.endpoint_fac.connect(protocolFactory)
|
|
||||||
conn = _WrappedConnection(conn, self.reactor)
|
|
||||||
defer.returnValue(conn)
|
|
||||||
|
|
||||||
|
|
||||||
class _WrappedConnection(object):
|
|
||||||
"""Wraps a connection and calls abort on it if it hasn't seen any action
|
|
||||||
for 2.5-3 minutes.
|
|
||||||
"""
|
|
||||||
__slots__ = ["conn", "last_request"]
|
|
||||||
|
|
||||||
def __init__(self, conn, reactor):
|
|
||||||
object.__setattr__(self, "conn", conn)
|
|
||||||
object.__setattr__(self, "last_request", time.time())
|
|
||||||
self._reactor = reactor
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self.conn, name)
|
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
|
||||||
setattr(self.conn, name, value)
|
|
||||||
|
|
||||||
def _time_things_out_maybe(self):
|
|
||||||
# We use a slightly shorter timeout here just in case the callLater is
|
|
||||||
# triggered early. Paranoia ftw.
|
|
||||||
# TODO: Cancel the previous callLater rather than comparing time.time()?
|
|
||||||
if time.time() - self.last_request >= 2.5 * 60:
|
|
||||||
self.abort()
|
|
||||||
# Abort the underlying TLS connection. The abort() method calls
|
|
||||||
# loseConnection() on the TLS connection which tries to
|
|
||||||
# shutdown the connection cleanly. We call abortConnection()
|
|
||||||
# since that will promptly close the TLS connection.
|
|
||||||
#
|
|
||||||
# In Twisted >18.4; the TLS connection will be None if it has closed
|
|
||||||
# which will make abortConnection() throw. Check that the TLS connection
|
|
||||||
# is not None before trying to close it.
|
|
||||||
if self.transport.getHandle() is not None:
|
|
||||||
self.transport.abortConnection()
|
|
||||||
|
|
||||||
def request(self, request):
|
|
||||||
self.last_request = time.time()
|
|
||||||
|
|
||||||
# Time this connection out if we haven't send a request in the last
|
|
||||||
# N minutes
|
|
||||||
# TODO: Cancel the previous callLater?
|
|
||||||
self._reactor.callLater(3 * 60, self._time_things_out_maybe)
|
|
||||||
|
|
||||||
d = self.conn.request(request)
|
|
||||||
|
|
||||||
def update_request_time(res):
|
|
||||||
self.last_request = time.time()
|
|
||||||
# TODO: Cancel the previous callLater?
|
|
||||||
self._reactor.callLater(3 * 60, self._time_things_out_maybe)
|
|
||||||
return res
|
|
||||||
|
|
||||||
d.addCallback(update_request_time)
|
|
||||||
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
class SRVClientEndpoint(object):
|
class SRVClientEndpoint(object):
|
||||||
|
|
|
@ -321,23 +321,23 @@ class MatrixFederationHttpClient(object):
|
||||||
url_str,
|
url_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
# we don't want all the fancy cookie and redirect handling that
|
|
||||||
# treq.request gives: just use the raw Agent.
|
|
||||||
request_deferred = self.agent.request(
|
|
||||||
method_bytes,
|
|
||||||
url_bytes,
|
|
||||||
headers=Headers(headers_dict),
|
|
||||||
bodyProducer=producer,
|
|
||||||
)
|
|
||||||
|
|
||||||
request_deferred = timeout_deferred(
|
|
||||||
request_deferred,
|
|
||||||
timeout=_sec_timeout,
|
|
||||||
reactor=self.hs.get_reactor(),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with Measure(self.clock, "outbound_request"):
|
with Measure(self.clock, "outbound_request"):
|
||||||
|
# we don't want all the fancy cookie and redirect handling
|
||||||
|
# that treq.request gives: just use the raw Agent.
|
||||||
|
request_deferred = self.agent.request(
|
||||||
|
method_bytes,
|
||||||
|
url_bytes,
|
||||||
|
headers=Headers(headers_dict),
|
||||||
|
bodyProducer=producer,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_deferred = timeout_deferred(
|
||||||
|
request_deferred,
|
||||||
|
timeout=_sec_timeout,
|
||||||
|
reactor=self.hs.get_reactor(),
|
||||||
|
)
|
||||||
|
|
||||||
response = yield make_deferred_yieldable(
|
response = yield make_deferred_yieldable(
|
||||||
request_deferred,
|
request_deferred,
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,6 +17,7 @@ from mock import Mock
|
||||||
|
|
||||||
from twisted.internet.defer import TimeoutError
|
from twisted.internet.defer import TimeoutError
|
||||||
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
|
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
|
||||||
|
from twisted.test.proto_helpers import StringTransport
|
||||||
from twisted.web.client import ResponseNeverReceived
|
from twisted.web.client import ResponseNeverReceived
|
||||||
from twisted.web.http import HTTPChannel
|
from twisted.web.http import HTTPChannel
|
||||||
|
|
||||||
|
@ -44,7 +45,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
|
|
||||||
def test_dns_error(self):
|
def test_dns_error(self):
|
||||||
"""
|
"""
|
||||||
If the DNS raising returns an error, it will bubble up.
|
If the DNS lookup returns an error, it will bubble up.
|
||||||
"""
|
"""
|
||||||
d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
|
d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
|
||||||
self.pump()
|
self.pump()
|
||||||
|
@ -63,7 +64,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
# Nothing happened yet
|
# Nothing happened yet
|
||||||
self.assertFalse(d.called)
|
self.assertNoResult(d)
|
||||||
|
|
||||||
# Make sure treq is trying to connect
|
# Make sure treq is trying to connect
|
||||||
clients = self.reactor.tcpClients
|
clients = self.reactor.tcpClients
|
||||||
|
@ -72,7 +73,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
self.assertEqual(clients[0][1], 8008)
|
self.assertEqual(clients[0][1], 8008)
|
||||||
|
|
||||||
# Deferred is still without a result
|
# Deferred is still without a result
|
||||||
self.assertFalse(d.called)
|
self.assertNoResult(d)
|
||||||
|
|
||||||
# Push by enough to time it out
|
# Push by enough to time it out
|
||||||
self.reactor.advance(10.5)
|
self.reactor.advance(10.5)
|
||||||
|
@ -94,7 +95,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
# Nothing happened yet
|
# Nothing happened yet
|
||||||
self.assertFalse(d.called)
|
self.assertNoResult(d)
|
||||||
|
|
||||||
# Make sure treq is trying to connect
|
# Make sure treq is trying to connect
|
||||||
clients = self.reactor.tcpClients
|
clients = self.reactor.tcpClients
|
||||||
|
@ -107,7 +108,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
client.makeConnection(conn)
|
client.makeConnection(conn)
|
||||||
|
|
||||||
# Deferred is still without a result
|
# Deferred is still without a result
|
||||||
self.assertFalse(d.called)
|
self.assertNoResult(d)
|
||||||
|
|
||||||
# Push by enough to time it out
|
# Push by enough to time it out
|
||||||
self.reactor.advance(10.5)
|
self.reactor.advance(10.5)
|
||||||
|
@ -135,7 +136,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
client.makeConnection(conn)
|
client.makeConnection(conn)
|
||||||
|
|
||||||
# Deferred does not have a result
|
# Deferred does not have a result
|
||||||
self.assertFalse(d.called)
|
self.assertNoResult(d)
|
||||||
|
|
||||||
# Send it the HTTP response
|
# Send it the HTTP response
|
||||||
client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")
|
client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")
|
||||||
|
@ -159,7 +160,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
client.makeConnection(conn)
|
client.makeConnection(conn)
|
||||||
|
|
||||||
# Deferred does not have a result
|
# Deferred does not have a result
|
||||||
self.assertFalse(d.called)
|
self.assertNoResult(d)
|
||||||
|
|
||||||
# Send it the HTTP response
|
# Send it the HTTP response
|
||||||
client.dataReceived(
|
client.dataReceived(
|
||||||
|
@ -195,3 +196,42 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
request = server.requests[0]
|
request = server.requests[0]
|
||||||
content = request.content.read()
|
content = request.content.read()
|
||||||
self.assertEqual(content, b'{"a":"b"}')
|
self.assertEqual(content, b'{"a":"b"}')
|
||||||
|
|
||||||
|
def test_closes_connection(self):
|
||||||
|
"""Check that the client closes unused HTTP connections"""
|
||||||
|
d = self.cl.get_json("testserv:8008", "foo/bar")
|
||||||
|
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# there should have been a call to connectTCP
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(_host, _port, factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
|
||||||
|
# complete the connection and wire it up to a fake transport
|
||||||
|
client = factory.buildProtocol(None)
|
||||||
|
conn = StringTransport()
|
||||||
|
client.makeConnection(conn)
|
||||||
|
|
||||||
|
# that should have made it send the request to the connection
|
||||||
|
self.assertRegex(conn.value(), b"^GET /foo/bar")
|
||||||
|
|
||||||
|
# Send the HTTP response
|
||||||
|
client.dataReceived(
|
||||||
|
b"HTTP/1.1 200 OK\r\n"
|
||||||
|
b"Content-Type: application/json\r\n"
|
||||||
|
b"Content-Length: 2\r\n"
|
||||||
|
b"\r\n"
|
||||||
|
b"{}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# We should get a successful response
|
||||||
|
r = self.successResultOf(d)
|
||||||
|
self.assertEqual(r, {})
|
||||||
|
|
||||||
|
self.assertFalse(conn.disconnecting)
|
||||||
|
|
||||||
|
# wait for a while
|
||||||
|
self.pump(120)
|
||||||
|
|
||||||
|
self.assertTrue(conn.disconnecting)
|
||||||
|
|
Loading…
Reference in New Issue