Support for routing outbound HTTP requests via a proxy (#6239)
The `http_proxy` and `HTTPS_PROXY` env vars can be set to a `host[:port]` value which should point to a proxy. The address of the proxy should be excluded from IP blacklists such as the `url_preview_ip_range_blacklist`. The proxy will then be used for * push * url previews * phone-home stats * recaptcha validation * CAS auth validation It will *not* be used for: * Application Services * Identity servers * Outbound federation * In worker configurations, connections from workers to masters Fixes #4198.pull/6317/head
parent
fe1f2b4520
commit
1cb84c6486
|
@ -0,0 +1 @@
|
|||
Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars.
|
|
@ -565,7 +565,7 @@ def run(hs):
|
|||
"Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
|
||||
)
|
||||
try:
|
||||
yield hs.get_simple_http_client().put_json(
|
||||
yield hs.get_proxied_http_client().put_json(
|
||||
hs.config.report_stats_endpoint, stats
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
|
@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
|||
def __init__(self, hs):
|
||||
super().__init__(hs)
|
||||
self._enabled = bool(hs.config.recaptcha_private_key)
|
||||
self._http_client = hs.get_simple_http_client()
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
self._url = hs.config.recaptcha_siteverify_api
|
||||
self._secret = hs.config.recaptcha_private_key
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@ from synapse.http import (
|
|||
cancelled_to_request_timed_out_error,
|
||||
redact_uri,
|
||||
)
|
||||
from synapse.http.proxyagent import ProxyAgent
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
|
@ -183,7 +184,15 @@ class SimpleHttpClient(object):
|
|||
using HTTP in Matrix
|
||||
"""
|
||||
|
||||
def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
|
||||
def __init__(
|
||||
self,
|
||||
hs,
|
||||
treq_args={},
|
||||
ip_whitelist=None,
|
||||
ip_blacklist=None,
|
||||
http_proxy=None,
|
||||
https_proxy=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer)
|
||||
|
@ -192,6 +201,8 @@ class SimpleHttpClient(object):
|
|||
we may not request.
|
||||
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
|
||||
request if it were otherwise caught in a blacklist.
|
||||
http_proxy (bytes): proxy server to use for http connections. host[:port]
|
||||
https_proxy (bytes): proxy server to use for https connections. host[:port]
|
||||
"""
|
||||
self.hs = hs
|
||||
|
||||
|
@ -236,11 +247,13 @@ class SimpleHttpClient(object):
|
|||
# The default context factory in Twisted 14.0.0 (which we require) is
|
||||
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
||||
# 'like a browser'
|
||||
self.agent = Agent(
|
||||
self.agent = ProxyAgent(
|
||||
self.reactor,
|
||||
connectTimeout=15,
|
||||
contextFactory=self.hs.get_http_client_context_factory(),
|
||||
pool=pool,
|
||||
http_proxy=http_proxy,
|
||||
https_proxy=https_proxy,
|
||||
)
|
||||
|
||||
if self._ip_blacklist:
|
||||
|
|
|
@ -0,0 +1,195 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer, protocol
|
||||
from twisted.internet.error import ConnectError
|
||||
from twisted.internet.interfaces import IStreamClientEndpoint
|
||||
from twisted.internet.protocol import connectionDone
|
||||
from twisted.web import http
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProxyConnectError(ConnectError):
|
||||
pass
|
||||
|
||||
|
||||
@implementer(IStreamClientEndpoint)
|
||||
class HTTPConnectProxyEndpoint(object):
|
||||
"""An Endpoint implementation which will send a CONNECT request to an http proxy
|
||||
|
||||
Wraps an existing HostnameEndpoint for the proxy.
|
||||
|
||||
When we get the connect() request from the connection pool (via the TLS wrapper),
|
||||
we'll first connect to the proxy endpoint with a ProtocolFactory which will make the
|
||||
CONNECT request. Once that completes, we invoke the protocolFactory which was passed
|
||||
in.
|
||||
|
||||
Args:
|
||||
reactor: the Twisted reactor to use for the connection
|
||||
proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
|
||||
proxy
|
||||
host (bytes): hostname that we want to CONNECT to
|
||||
port (int): port that we want to connect to
|
||||
"""
|
||||
|
||||
def __init__(self, reactor, proxy_endpoint, host, port):
|
||||
self._reactor = reactor
|
||||
self._proxy_endpoint = proxy_endpoint
|
||||
self._host = host
|
||||
self._port = port
|
||||
|
||||
def __repr__(self):
|
||||
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
|
||||
|
||||
def connect(self, protocolFactory):
|
||||
f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
|
||||
d = self._proxy_endpoint.connect(f)
|
||||
# once the tcp socket connects successfully, we need to wait for the
|
||||
# CONNECT to complete.
|
||||
d.addCallback(lambda conn: f.on_connection)
|
||||
return d
|
||||
|
||||
|
||||
class HTTPProxiedClientFactory(protocol.ClientFactory):
|
||||
"""ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect.
|
||||
|
||||
Once the CONNECT completes, invokes the original ClientFactory to build the
|
||||
HTTP Protocol object and run the rest of the connection.
|
||||
|
||||
Args:
|
||||
dst_host (bytes): hostname that we want to CONNECT to
|
||||
dst_port (int): port that we want to connect to
|
||||
wrapped_factory (protocol.ClientFactory): The original Factory
|
||||
"""
|
||||
|
||||
def __init__(self, dst_host, dst_port, wrapped_factory):
|
||||
self.dst_host = dst_host
|
||||
self.dst_port = dst_port
|
||||
self.wrapped_factory = wrapped_factory
|
||||
self.on_connection = defer.Deferred()
|
||||
|
||||
def startedConnecting(self, connector):
|
||||
return self.wrapped_factory.startedConnecting(connector)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
|
||||
|
||||
return HTTPConnectProtocol(
|
||||
self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
|
||||
)
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
logger.debug("Connection to proxy failed: %s", reason)
|
||||
if not self.on_connection.called:
|
||||
self.on_connection.errback(reason)
|
||||
return self.wrapped_factory.clientConnectionFailed(connector, reason)
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
logger.debug("Connection to proxy lost: %s", reason)
|
||||
if not self.on_connection.called:
|
||||
self.on_connection.errback(reason)
|
||||
return self.wrapped_factory.clientConnectionLost(connector, reason)
|
||||
|
||||
|
||||
class HTTPConnectProtocol(protocol.Protocol):
|
||||
"""Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
|
||||
|
||||
Args:
|
||||
host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
|
||||
to put in the CONNECT request
|
||||
|
||||
port (int): The original HTTP(s) port to put in the CONNECT request
|
||||
|
||||
wrapped_protocol (interfaces.IProtocol): the original protocol (probably
|
||||
HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
|
||||
|
||||
connected_deferred (Deferred): a Deferred which will be callbacked with
|
||||
wrapped_protocol when the CONNECT completes
|
||||
"""
|
||||
|
||||
def __init__(self, host, port, wrapped_protocol, connected_deferred):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.wrapped_protocol = wrapped_protocol
|
||||
self.connected_deferred = connected_deferred
|
||||
self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
|
||||
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
|
||||
|
||||
def connectionMade(self):
|
||||
self.http_setup_client.makeConnection(self.transport)
|
||||
|
||||
def connectionLost(self, reason=connectionDone):
|
||||
if self.wrapped_protocol.connected:
|
||||
self.wrapped_protocol.connectionLost(reason)
|
||||
|
||||
self.http_setup_client.connectionLost(reason)
|
||||
|
||||
if not self.connected_deferred.called:
|
||||
self.connected_deferred.errback(reason)
|
||||
|
||||
def proxyConnected(self, _):
|
||||
self.wrapped_protocol.makeConnection(self.transport)
|
||||
|
||||
self.connected_deferred.callback(self.wrapped_protocol)
|
||||
|
||||
# Get any pending data from the http buf and forward it to the original protocol
|
||||
buf = self.http_setup_client.clearLineBuffer()
|
||||
if buf:
|
||||
self.wrapped_protocol.dataReceived(buf)
|
||||
|
||||
def dataReceived(self, data):
|
||||
# if we've set up the HTTP protocol, we can send the data there
|
||||
if self.wrapped_protocol.connected:
|
||||
return self.wrapped_protocol.dataReceived(data)
|
||||
|
||||
# otherwise, we must still be setting up the connection: send the data to the
|
||||
# setup client
|
||||
return self.http_setup_client.dataReceived(data)
|
||||
|
||||
|
||||
class HTTPConnectSetupClient(http.HTTPClient):
|
||||
"""HTTPClient protocol to send a CONNECT message for proxies and read the response.
|
||||
|
||||
Args:
|
||||
host (bytes): The hostname to send in the CONNECT message
|
||||
port (int): The port to send in the CONNECT message
|
||||
"""
|
||||
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.on_connected = defer.Deferred()
|
||||
|
||||
def connectionMade(self):
|
||||
logger.debug("Connected to proxy, sending CONNECT")
|
||||
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
|
||||
self.endHeaders()
|
||||
|
||||
def handleStatus(self, version, status, message):
|
||||
logger.debug("Got Status: %s %s %s", status, message, version)
|
||||
if status != b"200":
|
||||
raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
|
||||
|
||||
def handleEndHeaders(self):
|
||||
logger.debug("End Headers")
|
||||
self.on_connected.callback(None)
|
||||
|
||||
def handleResponse(self, body):
|
||||
pass
|
|
@ -0,0 +1,195 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
|
||||
from twisted.web.error import SchemeNotSupported
|
||||
from twisted.web.iweb import IAgent
|
||||
|
||||
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
|
||||
|
||||
|
||||
@implementer(IAgent)
|
||||
class ProxyAgent(_AgentBase):
|
||||
"""An Agent implementation which will use an HTTP proxy if one was requested
|
||||
|
||||
Args:
|
||||
reactor: twisted reactor to place outgoing
|
||||
connections.
|
||||
|
||||
contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
|
||||
verification parameters of OpenSSL. The default is to use a
|
||||
`BrowserLikePolicyForHTTPS`, so unless you have special
|
||||
requirements you can leave this as-is.
|
||||
|
||||
connectTimeout (float): The amount of time that this Agent will wait
|
||||
for the peer to accept a connection.
|
||||
|
||||
bindAddress (bytes): The local address for client sockets to bind to.
|
||||
|
||||
pool (HTTPConnectionPool|None): connection pool to be used. If None, a
|
||||
non-persistent pool instance will be created.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reactor,
|
||||
contextFactory=BrowserLikePolicyForHTTPS(),
|
||||
connectTimeout=None,
|
||||
bindAddress=None,
|
||||
pool=None,
|
||||
http_proxy=None,
|
||||
https_proxy=None,
|
||||
):
|
||||
_AgentBase.__init__(self, reactor, pool)
|
||||
|
||||
self._endpoint_kwargs = {}
|
||||
if connectTimeout is not None:
|
||||
self._endpoint_kwargs["timeout"] = connectTimeout
|
||||
if bindAddress is not None:
|
||||
self._endpoint_kwargs["bindAddress"] = bindAddress
|
||||
|
||||
self.http_proxy_endpoint = _http_proxy_endpoint(
|
||||
http_proxy, reactor, **self._endpoint_kwargs
|
||||
)
|
||||
|
||||
self.https_proxy_endpoint = _http_proxy_endpoint(
|
||||
https_proxy, reactor, **self._endpoint_kwargs
|
||||
)
|
||||
|
||||
self._policy_for_https = contextFactory
|
||||
self._reactor = reactor
|
||||
|
||||
def request(self, method, uri, headers=None, bodyProducer=None):
|
||||
"""
|
||||
Issue a request to the server indicated by the given uri.
|
||||
|
||||
Supports `http` and `https` schemes.
|
||||
|
||||
An existing connection from the connection pool may be used or a new one may be
|
||||
created.
|
||||
|
||||
See also: twisted.web.iweb.IAgent.request
|
||||
|
||||
Args:
|
||||
method (bytes): The request method to use, such as `GET`, `POST`, etc
|
||||
|
||||
uri (bytes): The location of the resource to request.
|
||||
|
||||
headers (Headers|None): Extra headers to send with the request
|
||||
|
||||
bodyProducer (IBodyProducer|None): An object which can generate bytes to
|
||||
make up the body of this request (for example, the properly encoded
|
||||
contents of a file for a file upload). Or, None if the request is to
|
||||
have no body.
|
||||
|
||||
Returns:
|
||||
Deferred[IResponse]: completes when the header of the response has
|
||||
been received (regardless of the response status code).
|
||||
"""
|
||||
uri = uri.strip()
|
||||
if not _VALID_URI.match(uri):
|
||||
raise ValueError("Invalid URI {!r}".format(uri))
|
||||
|
||||
parsed_uri = URI.fromBytes(uri)
|
||||
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
|
||||
request_path = parsed_uri.originForm
|
||||
|
||||
if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
|
||||
# Cache *all* connections under the same key, since we are only
|
||||
# connecting to a single destination, the proxy:
|
||||
pool_key = ("http-proxy", self.http_proxy_endpoint)
|
||||
endpoint = self.http_proxy_endpoint
|
||||
request_path = uri
|
||||
elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
|
||||
endpoint = HTTPConnectProxyEndpoint(
|
||||
self._reactor,
|
||||
self.https_proxy_endpoint,
|
||||
parsed_uri.host,
|
||||
parsed_uri.port,
|
||||
)
|
||||
else:
|
||||
# not using a proxy
|
||||
endpoint = HostnameEndpoint(
|
||||
self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
|
||||
)
|
||||
|
||||
logger.debug("Requesting %s via %s", uri, endpoint)
|
||||
|
||||
if parsed_uri.scheme == b"https":
|
||||
tls_connection_creator = self._policy_for_https.creatorForNetloc(
|
||||
parsed_uri.host, parsed_uri.port
|
||||
)
|
||||
endpoint = wrapClientTLS(tls_connection_creator, endpoint)
|
||||
elif parsed_uri.scheme == b"http":
|
||||
pass
|
||||
else:
|
||||
return defer.fail(
|
||||
Failure(
|
||||
SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,))
|
||||
)
|
||||
)
|
||||
|
||||
return self._requestWithEndpoint(
|
||||
pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path
|
||||
)
|
||||
|
||||
|
||||
def _http_proxy_endpoint(proxy, reactor, **kwargs):
|
||||
"""Parses an http proxy setting and returns an endpoint for the proxy
|
||||
|
||||
Args:
|
||||
proxy (bytes|None): the proxy setting
|
||||
reactor: reactor to be used to connect to the proxy
|
||||
kwargs: other args to be passed to HostnameEndpoint
|
||||
|
||||
Returns:
|
||||
interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
|
||||
or None
|
||||
"""
|
||||
if proxy is None:
|
||||
return None
|
||||
|
||||
# currently we only support hostname:port. Some apps also support
|
||||
# protocol://<host>[:port], which allows a way of requiring a TLS connection to the
|
||||
# proxy.
|
||||
|
||||
host, port = parse_host_port(proxy, default_port=1080)
|
||||
return HostnameEndpoint(reactor, host, port, **kwargs)
|
||||
|
||||
|
||||
def parse_host_port(hostport, default_port=None):
|
||||
# could have sworn we had one of these somewhere else...
|
||||
if b":" in hostport:
|
||||
host, port = hostport.rsplit(b":", 1)
|
||||
try:
|
||||
port = int(port)
|
||||
return host, port
|
||||
except ValueError:
|
||||
# the thing after the : wasn't a valid port; presumably this is an
|
||||
# IPv6 address.
|
||||
pass
|
||||
|
||||
return hostport, default_port
|
|
@ -103,7 +103,7 @@ class HttpPusher(object):
|
|||
if "url" not in self.data:
|
||||
raise PusherConfigException("'url' required in data for HTTP pusher")
|
||||
self.url = self.data["url"]
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.http_client = hs.get_proxied_http_client()
|
||||
self.data_minus_url = {}
|
||||
self.data_minus_url.update(self.data)
|
||||
del self.data_minus_url["url"]
|
||||
|
|
|
@ -381,7 +381,7 @@ class CasTicketServlet(RestServlet):
|
|||
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
|
||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||
self._http_client = hs.get_simple_http_client()
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
|
|
|
@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource):
|
|||
treq_args={"browser_like_redirects": True},
|
||||
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
|
||||
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
|
||||
http_proxy=os.getenv("http_proxy"),
|
||||
https_proxy=os.getenv("HTTPS_PROXY"),
|
||||
)
|
||||
self.media_repo = media_repo
|
||||
self.primary_base_path = media_repo.primary_base_path
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
# Imports required for the default HomeServer() implementation
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
|
||||
from twisted.enterprise import adbapi
|
||||
from twisted.mail.smtp import sendmail
|
||||
|
@ -168,6 +169,7 @@ class HomeServer(object):
|
|||
"filtering",
|
||||
"http_client_context_factory",
|
||||
"simple_http_client",
|
||||
"proxied_http_client",
|
||||
"media_repository",
|
||||
"media_repository_resource",
|
||||
"federation_transport_client",
|
||||
|
@ -311,6 +313,13 @@ class HomeServer(object):
|
|||
def build_simple_http_client(self):
|
||||
return SimpleHttpClient(self)
|
||||
|
||||
def build_proxied_http_client(self):
|
||||
return SimpleHttpClient(
|
||||
self,
|
||||
http_proxy=os.getenv("http_proxy"),
|
||||
https_proxy=os.getenv("HTTPS_PROXY"),
|
||||
)
|
||||
|
||||
def build_room_creation_handler(self):
|
||||
return RoomCreationHandler(self)
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ import synapse.handlers.message
|
|||
import synapse.handlers.room
|
||||
import synapse.handlers.room_member
|
||||
import synapse.handlers.set_password
|
||||
import synapse.http.client
|
||||
import synapse.rest.media.v1.media_repository
|
||||
import synapse.server_notices.server_notices_manager
|
||||
import synapse.server_notices.server_notices_sender
|
||||
|
@ -38,6 +39,14 @@ class HomeServer(object):
|
|||
pass
|
||||
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
|
||||
pass
|
||||
def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
|
||||
"""Fetch an HTTP client implementation which doesn't do any blacklisting
|
||||
or support any HTTP_PROXY settings"""
|
||||
pass
|
||||
def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
|
||||
"""Fetch an HTTP client implementation which doesn't do any blacklisting
|
||||
but does support HTTP_PROXY settings"""
|
||||
pass
|
||||
def get_deactivate_account_handler(
|
||||
self,
|
||||
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
|
||||
|
|
|
@ -20,6 +20,23 @@ from zope.interface import implementer
|
|||
from OpenSSL import SSL
|
||||
from OpenSSL.SSL import Connection
|
||||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
||||
from twisted.internet.ssl import Certificate, trustRootFromCertificates
|
||||
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
|
||||
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
|
||||
|
||||
|
||||
def get_test_https_policy():
|
||||
"""Get a test IPolicyForHTTPS which trusts the test CA cert
|
||||
|
||||
Returns:
|
||||
IPolicyForHTTPS
|
||||
"""
|
||||
ca_file = get_test_ca_cert_file()
|
||||
with open(ca_file) as stream:
|
||||
content = stream.read()
|
||||
cert = Certificate.loadPEM(content)
|
||||
trust_root = trustRootFromCertificates([cert])
|
||||
return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
|
||||
|
||||
|
||||
def get_test_ca_cert_file():
|
||||
|
|
|
@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
|||
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
|
||||
)
|
||||
|
||||
# grab a hold of the TLS connection, in case it gets torn down
|
||||
server_tls_connection = server_tls_protocol._tlsConnection
|
||||
|
||||
# fish the test server back out of the server-side TLS protocol.
|
||||
http_protocol = server_tls_protocol.wrappedProtocol
|
||||
|
||||
# give the reactor a pump to get the TLS juices flowing.
|
||||
self.reactor.pump((0.1,))
|
||||
|
||||
# check the SNI
|
||||
server_name = server_tls_protocol._tlsConnection.get_servername()
|
||||
server_name = server_tls_connection.get_servername()
|
||||
self.assertEqual(
|
||||
server_name,
|
||||
expected_sni,
|
||||
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||
)
|
||||
|
||||
# fish the test server back out of the server-side TLS protocol.
|
||||
return server_tls_protocol.wrappedProtocol
|
||||
return http_protocol
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _make_get_request(self, uri):
|
||||
|
|
|
@ -0,0 +1,334 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import treq
|
||||
|
||||
from twisted.internet import interfaces # noqa: F401
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
from twisted.web.http import HTTPChannel
|
||||
|
||||
from synapse.http.proxyagent import ProxyAgent
|
||||
|
||||
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
|
||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HTTPFactory = Factory.forProtocol(HTTPChannel)
|
||||
|
||||
|
||||
class MatrixFederationAgentTests(TestCase):
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
|
||||
def _make_connection(
|
||||
self, client_factory, server_factory, ssl=False, expected_sni=None
|
||||
):
|
||||
"""Builds a test server, and completes the outgoing client connection
|
||||
|
||||
Args:
|
||||
client_factory (interfaces.IProtocolFactory): the the factory that the
|
||||
application is trying to use to make the outbound connection. We will
|
||||
invoke it to build the client Protocol
|
||||
|
||||
server_factory (interfaces.IProtocolFactory): a factory to build the
|
||||
server-side protocol
|
||||
|
||||
ssl (bool): If true, we will expect an ssl connection and wrap
|
||||
server_factory with a TLSMemoryBIOFactory
|
||||
|
||||
expected_sni (bytes|None): the expected SNI value
|
||||
|
||||
Returns:
|
||||
IProtocol: the server Protocol returned by server_factory
|
||||
"""
|
||||
if ssl:
|
||||
server_factory = _wrap_server_factory_for_tls(server_factory)
|
||||
|
||||
server_protocol = server_factory.buildProtocol(None)
|
||||
|
||||
# now, tell the client protocol factory to build the client protocol,
|
||||
# and wire the output of said protocol up to the server via
|
||||
# a FakeTransport.
|
||||
#
|
||||
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||
# stubbing that out here.
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
client_protocol.makeConnection(
|
||||
FakeTransport(server_protocol, self.reactor, client_protocol)
|
||||
)
|
||||
|
||||
# tell the server protocol to send its stuff back to the client, too
|
||||
server_protocol.makeConnection(
|
||||
FakeTransport(client_protocol, self.reactor, server_protocol)
|
||||
)
|
||||
|
||||
if ssl:
|
||||
http_protocol = server_protocol.wrappedProtocol
|
||||
tls_connection = server_protocol._tlsConnection
|
||||
else:
|
||||
http_protocol = server_protocol
|
||||
tls_connection = None
|
||||
|
||||
# give the reactor a pump to get the TLS juices flowing (if needed)
|
||||
self.reactor.advance(0)
|
||||
|
||||
if expected_sni is not None:
|
||||
server_name = tls_connection.get_servername()
|
||||
self.assertEqual(
|
||||
server_name,
|
||||
expected_sni,
|
||||
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||
)
|
||||
|
||||
return http_protocol
|
||||
|
||||
def test_http_request(self):
|
||||
agent = ProxyAgent(self.reactor)
|
||||
|
||||
self.reactor.lookups["test.com"] = "1.2.3.4"
|
||||
d = agent.request(b"GET", b"http://test.com")
|
||||
|
||||
# there should be a pending TCP connection
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, "1.2.3.4")
|
||||
self.assertEqual(port, 80)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(
|
||||
client_factory, _get_test_protocol_factory()
|
||||
)
|
||||
|
||||
# the FakeTransport is async, so we need to pump the reactor
|
||||
self.reactor.advance(0)
|
||||
|
||||
# now there should be a pending request
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b"GET")
|
||||
self.assertEqual(request.path, b"/")
|
||||
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||
request.write(b"result")
|
||||
request.finish()
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
resp = self.successResultOf(d)
|
||||
body = self.successResultOf(treq.content(resp))
|
||||
self.assertEqual(body, b"result")
|
||||
|
||||
def test_https_request(self):
|
||||
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
|
||||
|
||||
self.reactor.lookups["test.com"] = "1.2.3.4"
|
||||
d = agent.request(b"GET", b"https://test.com/abc")
|
||||
|
||||
# there should be a pending TCP connection
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, "1.2.3.4")
|
||||
self.assertEqual(port, 443)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(
|
||||
client_factory,
|
||||
_get_test_protocol_factory(),
|
||||
ssl=True,
|
||||
expected_sni=b"test.com",
|
||||
)
|
||||
|
||||
# the FakeTransport is async, so we need to pump the reactor
|
||||
self.reactor.advance(0)
|
||||
|
||||
# now there should be a pending request
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b"GET")
|
||||
self.assertEqual(request.path, b"/abc")
|
||||
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||
request.write(b"result")
|
||||
request.finish()
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
resp = self.successResultOf(d)
|
||||
body = self.successResultOf(treq.content(resp))
|
||||
self.assertEqual(body, b"result")
|
||||
|
||||
def test_http_request_via_proxy(self):
|
||||
agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
|
||||
|
||||
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||
d = agent.request(b"GET", b"http://test.com")
|
||||
|
||||
# there should be a pending TCP connection
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, "1.2.3.5")
|
||||
self.assertEqual(port, 8888)
|
||||
|
||||
# make a test server, and wire up the client
|
||||
http_server = self._make_connection(
|
||||
client_factory, _get_test_protocol_factory()
|
||||
)
|
||||
|
||||
# the FakeTransport is async, so we need to pump the reactor
|
||||
self.reactor.advance(0)
|
||||
|
||||
# now there should be a pending request
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b"GET")
|
||||
self.assertEqual(request.path, b"http://test.com")
|
||||
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||
request.write(b"result")
|
||||
request.finish()
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
resp = self.successResultOf(d)
|
||||
body = self.successResultOf(treq.content(resp))
|
||||
self.assertEqual(body, b"result")
|
||||
|
||||
def test_https_request_via_proxy(self):
|
||||
agent = ProxyAgent(
|
||||
self.reactor,
|
||||
contextFactory=get_test_https_policy(),
|
||||
https_proxy=b"proxy.com",
|
||||
)
|
||||
|
||||
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||
d = agent.request(b"GET", b"https://test.com/abc")
|
||||
|
||||
# there should be a pending TCP connection
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||
self.assertEqual(host, "1.2.3.5")
|
||||
self.assertEqual(port, 1080)
|
||||
|
||||
# make a test HTTP server, and wire up the client
|
||||
proxy_server = self._make_connection(
|
||||
client_factory, _get_test_protocol_factory()
|
||||
)
|
||||
|
||||
# fish the transports back out so that we can do the old switcheroo
|
||||
s2c_transport = proxy_server.transport
|
||||
client_protocol = s2c_transport.other
|
||||
c2s_transport = client_protocol.transport
|
||||
|
||||
# the FakeTransport is async, so we need to pump the reactor
|
||||
self.reactor.advance(0)
|
||||
|
||||
# now there should be a pending CONNECT request
|
||||
self.assertEqual(len(proxy_server.requests), 1)
|
||||
|
||||
request = proxy_server.requests[0]
|
||||
self.assertEqual(request.method, b"CONNECT")
|
||||
self.assertEqual(request.path, b"test.com:443")
|
||||
|
||||
# tell the proxy server not to close the connection
|
||||
proxy_server.persistent = True
|
||||
|
||||
# this just stops the http Request trying to do a chunked response
|
||||
# request.setHeader(b"Content-Length", b"0")
|
||||
request.finish()
|
||||
|
||||
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
|
||||
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
|
||||
ssl_protocol = ssl_factory.buildProtocol(None)
|
||||
http_server = ssl_protocol.wrappedProtocol
|
||||
|
||||
ssl_protocol.makeConnection(
|
||||
FakeTransport(client_protocol, self.reactor, ssl_protocol)
|
||||
)
|
||||
c2s_transport.other = ssl_protocol
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
server_name = ssl_protocol._tlsConnection.get_servername()
|
||||
expected_sni = b"test.com"
|
||||
self.assertEqual(
|
||||
server_name,
|
||||
expected_sni,
|
||||
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||
)
|
||||
|
||||
# now there should be a pending request
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
|
||||
request = http_server.requests[0]
|
||||
self.assertEqual(request.method, b"GET")
|
||||
self.assertEqual(request.path, b"/abc")
|
||||
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||
request.write(b"result")
|
||||
request.finish()
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
resp = self.successResultOf(d)
|
||||
body = self.successResultOf(treq.content(resp))
|
||||
self.assertEqual(body, b"result")
|
||||
|
||||
|
||||
def _wrap_server_factory_for_tls(factory, sanlist=None):
|
||||
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
||||
|
||||
The resultant factory will create a TLS server which presents a certificate
|
||||
signed by our test CA, valid for the domains in `sanlist`
|
||||
|
||||
Args:
|
||||
factory (interfaces.IProtocolFactory): protocol factory to wrap
|
||||
sanlist (iterable[bytes]): list of domains the cert should be valid for
|
||||
|
||||
Returns:
|
||||
interfaces.IProtocolFactory
|
||||
"""
|
||||
if sanlist is None:
|
||||
sanlist = [b"DNS:test.com"]
|
||||
|
||||
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
|
||||
return TLSMemoryBIOFactory(
|
||||
connection_creator, isClient=False, wrappedFactory=factory
|
||||
)
|
||||
|
||||
|
||||
def _get_test_protocol_factory():
|
||||
"""Get a protocol Factory which will build an HTTPChannel
|
||||
|
||||
Returns:
|
||||
interfaces.IProtocolFactory
|
||||
"""
|
||||
server_factory = Factory.forProtocol(HTTPChannel)
|
||||
|
||||
# Request.finish expects the factory to have a 'log' method.
|
||||
server_factory.log = _log_request
|
||||
|
||||
return server_factory
|
||||
|
||||
|
||||
def _log_request(request):
|
||||
"""Implements Factory.log, which is expected by Request.finish"""
|
||||
logger.info("Completed request %s", request)
|
|
@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
config = self.default_config()
|
||||
config["start_pushers"] = True
|
||||
|
||||
hs = self.setup_test_homeserver(config=config, simple_http_client=m)
|
||||
hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
|
||||
|
||||
return hs
|
||||
|
||||
|
|
|
@ -395,11 +395,24 @@ class FakeTransport(object):
|
|||
self.disconnecting = True
|
||||
if self._protocol:
|
||||
self._protocol.connectionLost(reason)
|
||||
self.disconnected = True
|
||||
|
||||
# if we still have data to write, delay until that is done
|
||||
if self.buffer:
|
||||
logger.info(
|
||||
"FakeTransport: Delaying disconnect until buffer is flushed"
|
||||
)
|
||||
else:
|
||||
self.disconnected = True
|
||||
|
||||
def abortConnection(self):
|
||||
logger.info("FakeTransport: abortConnection()")
|
||||
self.loseConnection()
|
||||
|
||||
if not self.disconnecting:
|
||||
self.disconnecting = True
|
||||
if self._protocol:
|
||||
self._protocol.connectionLost(None)
|
||||
|
||||
self.disconnected = True
|
||||
|
||||
def pauseProducing(self):
|
||||
if not self.producer:
|
||||
|
@ -430,6 +443,9 @@ class FakeTransport(object):
|
|||
self._reactor.callLater(0.0, _produce)
|
||||
|
||||
def write(self, byt):
|
||||
if self.disconnecting:
|
||||
raise Exception("Writing to disconnecting FakeTransport")
|
||||
|
||||
self.buffer = self.buffer + byt
|
||||
|
||||
# always actually do the write asynchronously. Some protocols (notably the
|
||||
|
@ -474,6 +490,10 @@ class FakeTransport(object):
|
|||
if self.buffer and self.autoflush:
|
||||
self._reactor.callLater(0.0, self.flush)
|
||||
|
||||
if not self.buffer and self.disconnecting:
|
||||
logger.info("FakeTransport: Buffer now empty, completing disconnect")
|
||||
self.disconnected = True
|
||||
|
||||
|
||||
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue