271 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			271 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			Python
		
	
	
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 | 
						|
 | 
						|
import base64
 | 
						|
import logging
 | 
						|
from typing import Optional
 | 
						|
 | 
						|
import attr
 | 
						|
from zope.interface import implementer
 | 
						|
 | 
						|
from twisted.internet import defer, protocol
 | 
						|
from twisted.internet.error import ConnectError
 | 
						|
from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
 | 
						|
from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
 | 
						|
from twisted.web import http
 | 
						|
 | 
						|
# Module-level logger, named after this module; used for debug output while
# talking to the proxy.
logger = logging.getLogger(__name__)
 | 
						|
 | 
						|
 | 
						|
# Error raised when the proxy answers our CONNECT request with a status other
# than 200 (see HTTPConnectSetupClient.handleStatus).
#
# NOTE(review): documented with a comment rather than a docstring on purpose:
# some Twisted versions appear to derive ConnectError's string form from
# __doc__, so adding a docstring could change the rendered error message.
# Confirm against the Twisted version in use before adding one.
class ProxyConnectError(ConnectError):
    pass
 | 
						|
 | 
						|
 | 
						|
@attr.s
class ProxyCredentials:
    """Credentials to present to an HTTP proxy via HTTP Basic authentication."""

    # Raw credential bytes which get base64-encoded into the header value.
    # Presumably of the form b"username:password", as the Basic scheme
    # expects -- confirm against callers.
    username_password = attr.ib(type=bytes)

    def as_proxy_authorization_value(self) -> bytes:
        """
        Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').

        Returns:
            The credentials, base64-encoded on a single line and prefixed with
            the ``Basic`` authorization scheme.
        """
        # Use b64encode rather than encodebytes: encodebytes inserts a newline
        # after every 76 characters of output and always appends a trailing
        # b"\n". Either would end up inside the header value and corrupt the
        # CONNECT request (most visibly with long usernames/passwords).
        return b"Basic " + base64.b64encode(self.username_password)
 | 
						|
 | 
						|
 | 
						|
@implementer(IStreamClientEndpoint)
class HTTPConnectProxyEndpoint:
    """Endpoint which reaches its destination through an HTTP proxy using CONNECT.

    This wraps an existing endpoint for the proxy itself. When connect() is
    called, we connect to the proxy endpoint using a factory which issues the
    CONNECT request; only once that has completed is the protocol factory that
    was handed to connect() put in charge of the connection.

    Args:
        reactor: the Twisted reactor to use for the connection
        proxy_endpoint: the endpoint to use to connect to the proxy
        host: hostname that we want to CONNECT to
        port: port that we want to connect to
        proxy_creds: credentials to authenticate at proxy
    """

    def __init__(
        self,
        reactor: IReactorCore,
        proxy_endpoint: IStreamClientEndpoint,
        host: bytes,
        port: int,
        proxy_creds: Optional[ProxyCredentials],
    ):
        # Where (and with which credentials) to reach the proxy.
        self._reactor = reactor
        self._proxy_endpoint = proxy_endpoint
        self._proxy_creds = proxy_creds

        # The destination named in the CONNECT request.
        self._host = host
        self._port = port

    def __repr__(self):
        endpoint_desc = str(self._proxy_endpoint)
        return "<HTTPConnectProxyEndpoint %s>" % (endpoint_desc,)

    # Mypy encounters a false positive here: it complains that ClientFactory
    # is incompatible with IProtocolFactory. But ClientFactory inherits from
    # Factory, which implements IProtocolFactory. So I think this is a bug
    # in mypy-zope.
    def connect(self, protocolFactory: ClientFactory):  # type: ignore[override]
        """Connect to the proxy and wait for the CONNECT exchange to finish.

        Args:
            protocolFactory: factory for the protocol which should take over
                the connection once the proxy has accepted the CONNECT.

        Returns:
            The Deferred returned by the proxy endpoint's connect(), chained
            onto the proxied factory's ``on_connection`` Deferred.
        """
        proxied_factory = HTTPProxiedClientFactory(
            dst_host=self._host,
            dst_port=self._port,
            wrapped_factory=protocolFactory,
            proxy_creds=self._proxy_creds,
        )

        def _wait_for_connect(_conn):
            # The TCP connection to the proxy being up is not enough: hand back
            # the Deferred tracking the CONNECT so callers wait for that too.
            return proxied_factory.on_connection

        connecting = self._proxy_endpoint.connect(proxied_factory)
        connecting.addCallback(_wait_for_connect)
        return connecting
 | 
						|
 | 
						|
 | 
						|
class HTTPProxiedClientFactory(protocol.ClientFactory):
    """ClientFactory wrapper which performs an HTTP proxy CONNECT on connect.

    The protocol it builds carries out the CONNECT first and afterwards passes
    the connection on to the protocol built by the original ClientFactory.

    Args:
        dst_host: hostname that we want to CONNECT to
        dst_port: port that we want to connect to
        wrapped_factory: The original Factory
        proxy_creds: credentials to authenticate at proxy
    """

    def __init__(
        self,
        dst_host: bytes,
        dst_port: int,
        wrapped_factory: ClientFactory,
        proxy_creds: Optional[ProxyCredentials],
    ):
        self.wrapped_factory = wrapped_factory
        self.proxy_creds = proxy_creds
        self.dst_host = dst_host
        self.dst_port = dst_port
        # Handed to each protocol we build, which fires it once the CONNECT
        # has completed; errbacked here if the connection fails or is lost
        # before then.
        self.on_connection: "defer.Deferred[None]" = defer.Deferred()

    def _fail_if_pending(self, reason):
        """Errback ``on_connection`` with ``reason`` unless it has already fired."""
        if not self.on_connection.called:
            self.on_connection.errback(reason)

    def startedConnecting(self, connector):
        # Nothing to do ourselves: just let the wrapped factory know.
        return self.wrapped_factory.startedConnecting(connector)

    def buildProtocol(self, addr):
        """Build the wrapped factory's protocol, wrapped in an HTTPConnectProtocol.

        Raises:
            TypeError: if the wrapped factory does not produce a protocol.
        """
        inner_protocol = self.wrapped_factory.buildProtocol(addr)
        if inner_protocol is None:
            raise TypeError("buildProtocol produced None instead of a Protocol")

        return HTTPConnectProtocol(
            host=self.dst_host,
            port=self.dst_port,
            wrapped_protocol=inner_protocol,
            connected_deferred=self.on_connection,
            proxy_creds=self.proxy_creds,
        )

    def clientConnectionFailed(self, connector, reason):
        logger.debug("Connection to proxy failed: %s", reason)
        self._fail_if_pending(reason)
        return self.wrapped_factory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        logger.debug("Connection to proxy lost: %s", reason)
        self._fail_if_pending(reason)
        return self.wrapped_factory.clientConnectionLost(connector, reason)
 | 
						|
 | 
						|
 | 
						|
class HTTPConnectProtocol(protocol.Protocol):
    """Protocol which performs a CONNECT handshake before handing over to another
    Protocol.

    Args:
        host: The original HTTP(s) hostname or IPv4 or IPv6 address literal
            to put in the CONNECT request

        port: The original HTTP(s) port to put in the CONNECT request

        wrapped_protocol: the original protocol (probably HTTPChannel or
            TLSMemoryBIOProtocol, but could be anything really)

        connected_deferred: a Deferred which will be callbacked with
            wrapped_protocol when the CONNECT completes

        proxy_creds: credentials to authenticate at proxy
    """

    def __init__(
        self,
        host: bytes,
        port: int,
        wrapped_protocol: Protocol,
        connected_deferred: defer.Deferred,
        proxy_creds: Optional[ProxyCredentials],
    ):
        self.host = host
        self.port = port
        self.wrapped_protocol = wrapped_protocol
        self.connected_deferred = connected_deferred
        self.proxy_creds = proxy_creds

        # The helper client talks HTTP to the proxy; once it reports success we
        # switch the connection over to the wrapped protocol.
        setup_client = HTTPConnectSetupClient(host, port, proxy_creds)
        setup_client.on_connected.addCallback(self.proxyConnected)
        self.http_setup_client = setup_client

    def connectionMade(self):
        # Let the setup client drive the transport first.
        self.http_setup_client.makeConnection(self.transport)

    def connectionLost(self, reason=connectionDone):
        wrapped = self.wrapped_protocol
        if wrapped.connected:
            wrapped.connectionLost(reason)

        self.http_setup_client.connectionLost(reason)

        # If the CONNECT never completed, whoever is waiting on it needs to
        # hear about the failure.
        pending = self.connected_deferred
        if not pending.called:
            pending.errback(reason)

    def proxyConnected(self, _):
        """Hand the connection over to the wrapped protocol.

        Called (via the setup client's ``on_connected``) once the proxy's
        response headers have been received.
        """
        wrapped = self.wrapped_protocol
        wrapped.makeConnection(self.transport)

        self.connected_deferred.callback(wrapped)

        # Anything still sitting in the setup client's line buffer belongs to
        # the wrapped protocol, so pass it along.
        leftover = self.http_setup_client.clearLineBuffer()
        if leftover:
            wrapped.dataReceived(leftover)

    def dataReceived(self, data: bytes):
        # Once the wrapped protocol is connected it receives everything;
        # until then we are still mid-handshake and the setup client does.
        if self.wrapped_protocol.connected:
            receiver = self.wrapped_protocol
        else:
            receiver = self.http_setup_client
        return receiver.dataReceived(data)
 | 
						|
 | 
						|
 | 
						|
class HTTPConnectSetupClient(http.HTTPClient):
    """HTTPClient protocol which sends a CONNECT to a proxy and reads the response.

    Args:
        host: The hostname to send in the CONNECT message
        port: The port to send in the CONNECT message
        proxy_creds: credentials to authenticate at proxy
    """

    def __init__(
        self,
        host: bytes,
        port: int,
        proxy_creds: Optional[ProxyCredentials],
    ):
        self.host, self.port = host, port
        self.proxy_creds = proxy_creds
        # Fired (with None) when the end of the proxy's response headers is
        # reached.
        self.on_connected: "defer.Deferred[None]" = defer.Deferred()

    def connectionMade(self):
        """Send the CONNECT request, with proxy credentials if we have any."""
        logger.debug("Connected to proxy, sending CONNECT")

        destination = b"%s:%d" % (self.host, self.port)
        self.sendCommand(b"CONNECT", destination)

        # Only send a Proxy-Authorization header if credentials were supplied.
        creds = self.proxy_creds
        if creds:
            self.sendHeader(
                b"Proxy-Authorization", creds.as_proxy_authorization_value()
            )

        self.endHeaders()

    def handleStatus(self, version: bytes, status: bytes, message: bytes):
        """Check the status line of the proxy's response.

        Raises:
            ProxyConnectError: if the status is anything other than 200.
        """
        logger.debug("Got Status: %s %s %s", status, message, version)
        if status == b"200":
            return
        raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")

    def handleEndHeaders(self):
        logger.debug("End Headers")
        # Headers are done: signal whoever is waiting on on_connected.
        self.on_connected.callback(None)

    def handleResponse(self, body):
        # The response body is deliberately ignored.
        pass
 |